fix: remove typing errors (mypy)

This commit is contained in:
iriley 2024-04-29 13:58:35 +02:00
parent b43033e6bf
commit 2c6232a1bf
12 changed files with 50 additions and 49 deletions

View File

@ -14,7 +14,7 @@ from cv_analysis.utils.postprocessing import remove_included
from cv_analysis.utils.structures import Rectangle
def detect_figures(image: np.array):
def detect_figures(image: np.ndarray):
max_area = image.shape[0] * image.shape[1] * 0.99
min_area = 5000
max_width_to_height_ratio = 6
@ -24,9 +24,10 @@ def detect_figures(image: np.array):
cnts = detect_large_coherent_structures(image)
cnts = filter(figure_filter, cnts)
rects = map(cv2.boundingRect, cnts)
rects = map(Rectangle.from_xywh, rects)
rects = remove_included(rects)
# rects = map(compose(Rectangle.from_xywh, cv2.boundingRect), (cnts))
bounding_rects = map(cv2.boundingRect, cnts)
rects: list[Rectangle] = remove_included(map(Rectangle.from_xywh, rects))
return rects

View File

@ -2,7 +2,7 @@ import cv2
import numpy as np
def detect_large_coherent_structures(image: np.array):
def detect_large_coherent_structures(image: np.ndarray):
"""Detects large coherent structures on an image.
Expects an image with binary color space (e.g. threshold applied).

View File

@ -48,7 +48,7 @@ def fill_in_component_area(image, rect):
return ~image
def parse_layout(image: np.array):
def parse_layout(image: np.ndarray):
image = image.copy()
image_ = image.copy()
@ -77,8 +77,7 @@ def parse_layout(image: np.array):
rects = list(map(Rectangle.from_xywh, rects))
rects = remove_included(rects)
rects = map(lambda r: r.xywh(), rects)
rects = connect_related_rects2(rects)
rects = connect_related_rects2(map(lambda r: r.xywh(), rects))
rects = list(map(Rectangle.from_xywh, rects))
rects = remove_included(rects)

View File

@ -2,7 +2,7 @@ from functools import partial
import cv2
import numpy as np
from iteration_utilities import first, starfilter
from iteration_utilities import first, starfilter # type: ignore
from cv_analysis.utils.filters import is_boxy, is_filled, is_large_enough
from cv_analysis.utils.visual_logging import vizlogger
@ -12,7 +12,7 @@ def is_likely_redaction(contour, hierarchy, min_area):
return is_filled(hierarchy) and is_boxy(contour) and is_large_enough(contour, min_area)
def find_redactions(image: np.array, min_normalized_area=200000):
def find_redactions(image: np.ndarray, min_normalized_area=200000):
vizlogger.debug(image, "redactions01_start.png")
min_normalized_area /= 200 # Assumes 200 DPI PDF -> image conversion resolution
@ -29,13 +29,12 @@ def find_redactions(image: np.array, min_normalized_area=200000):
contours, hierarchies = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
try:
contours = map(
return list(map(
first,
starfilter(
partial(is_likely_redaction, min_area=min_normalized_area),
zip(contours, hierarchies[0]),
),
)
return list(contours)
))
except:
return []

View File

@ -2,10 +2,10 @@ from dataclasses import asdict
from operator import itemgetter, truth
from typing import Callable, Generator
from funcy import flatten, lmap
from pdf2img.conversion import convert_pages_to_images
from pdf2img.default_objects.image import ImageInfo, ImagePlus
from pdf2img.default_objects.rectangle import RectanglePlus
from funcy import flatten, lmap # type: ignore
from pdf2img.conversion import convert_pages_to_images # type: ignore
from pdf2img.default_objects.image import ImageInfo, ImagePlus # type: ignore
from pdf2img.default_objects.rectangle import RectanglePlus # type: ignore
from cv_analysis.figure_detection.figure_detection import detect_figures
from cv_analysis.table_inference import infer_lines

View File

@ -4,7 +4,7 @@ from typing import Callable, Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
from cv2 import cv2
import cv2
from kn_utils.logging import logger # type: ignore
from numpy import ndarray as Array
from scipy.stats import norm # type: ignore
@ -148,7 +148,7 @@ def filter_fp_col_lines(line_list: list[int], filt_sums: Array) -> list[int]:
return line_list
def get_lines_either(table_array: Array, horizontal=True) -> Array:
def get_lines_either(table_array: Array, horizontal=True) -> list[int]:
key = "row" if horizontal else "col"
sums = np.mean(table_array, axis=int(horizontal))
@ -162,9 +162,7 @@ def get_lines_either(table_array: Array, horizontal=True) -> Array:
filtered_sums = filter_array(filtered_sums, FILTERS[key][2])
filtered_sums = filter_array(filtered_sums, FILTERS[key][3])
lines = list(
np.where((filtered_sums[1:-1] > filtered_sums[:-2]) * (filtered_sums[1:-1] > filtered_sums[2:]))[0] + 1
)
lines = list(np.where((filtered_sums[1:-1] > filtered_sums[:-2]) * (filtered_sums[1:-1] > filtered_sums[2:]))[0] + 1)
if not horizontal:
lines = filter_fp_col_lines(lines, filtered_sums)
@ -176,7 +174,7 @@ def img_bytes_to_array(img_bytes: bytes) -> Array:
return img_np
def infer_lines(img: Array) -> dict[str, list[dict[str, int]] | list[dict[str, int]]]:
def infer_lines(img: Array) -> dict[str, dict[str, int] | list[dict[str, int]]]:
cv2.imwrite("/tmp/table.png", img)
_, img = cv2.threshold(img, 220, 255, cv2.THRESH_BINARY)
cv2.imwrite("/tmp/table_bin.png", img)

View File

@ -1,5 +1,5 @@
import numpy as np
from cv2 import cv2
import cv2
from funcy import lfilter, lmap # type: ignore
from cv_analysis.layout_parsing import parse_layout
@ -208,7 +208,7 @@ def detect_endpoints(image: np.ndarray, is_horizontal: bool) -> list[tuple[int,
return corrected
def parse_lines(image: np.ndarray, show=False) -> list[dict[str, list[int]]]:
def parse_lines(image: np.ndarray, show=False) -> list[dict[str, float]]:
image = preprocess(image)
# kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2, 2))
# image = cv2.dilate(image, kernel, iterations=4)
@ -218,7 +218,7 @@ def parse_lines(image: np.ndarray, show=False) -> list[dict[str, list[int]]]:
horizontal_endpoints = detect_endpoints(horizontal_line_img, is_horizontal=True)
vertical_endpoints = detect_endpoints(vertical_line_img, is_horizontal=False)
def format_quad(quad: tuple[int, int, int, int], max_x: int, max_y: int) -> tuple[int, int, int, int]:
def format_quad(quad: tuple[int, int, int, int], max_x: int, max_y: int) -> dict[str, float]:
x1, y1, x2, y2 = quad
if x1 > (x2 + 5):
x1, y1, x2, y2 = x2, y2, x1, y1

View File

@ -3,8 +3,8 @@ from operator import itemgetter
from pathlib import Path
from typing import Union
import fitz
from kn_utils.logging import logger
import fitz # type: ignore
from kn_utils.logging import logger # type: ignore
def annotate_pdf(
@ -53,7 +53,7 @@ def mirror_on_x_axis(bbox, page_height):
@singledispatch
def provide_byte_stream(pdf: Union[bytes, Path, str]) -> bytes:
def provide_byte_stream(pdf: Union[bytes, Path, str]) -> None:
pass

View File

@ -1,14 +1,16 @@
import json
from dataclasses import dataclass
from functools import partial
from operator import itemgetter
from typing import Iterable, Tuple
from typing import Tuple, SupportsIndex
import fitz
import fitz # type: ignore
import numpy as np
from funcy import compose, lfilter
from kn_utils.logging import logger
from funcy import compose, lfilter # type: ignore
from kn_utils.logging import logger # type: ignore
from numpy import ndarray as Array
BBoxType = tuple[int | float, int | float, int| float, int | float]
@dataclass
class PageInfo:
@ -24,10 +26,10 @@ class PageInfo:
def transform_image_coordinates_to_pdf_coordinates(
bbox: Iterable[int | float],
bbox: BBoxType,
rotation_matrix: fitz.Matrix,
transformation_matrix: fitz.Matrix,
dpi: int = None,
dpi: int | None = None,
) -> Tuple:
x1, y1, x2, y2 = map(lambda x: (x / dpi) * 72, bbox) if dpi else bbox # Convert to points, can be done before
rect = fitz.Rect(x1, y1, x2, y2)
@ -36,19 +38,23 @@ def transform_image_coordinates_to_pdf_coordinates(
return rect.x0, rect.y0, rect.x1, rect.y1
def rescale_to_pdf(bbox: Iterable[int | float], page_info: PageInfo) -> Iterable[float]:
def rescale_to_pdf(bbox: BBoxType, page_info: PageInfo) -> tuple[float, float, float, float]:
round3 = lambda x: tuple(map(lambda y: round(y, 3), x))
pdf_h, pdf_w = page_info.height, page_info.width
if page_info.rotation in {90, 270}:
pdf_h, pdf_w = pdf_w, pdf_h
pix_h, pix_w = page_info.image_height, page_info.image_width
ratio_h, ratio_w = pdf_h / pix_h, pdf_w / pix_w
round3 = lambda x: tuple(map(lambda y: round(y, 3), x))
ratio_w, ratio_h, pdf_w, pdf_h, pix_w, pix_h = round3((ratio_w, ratio_h, pdf_w, pdf_h, pix_w, pix_h))
new_bbox = round3((bbox[0] * ratio_w, bbox[1] * ratio_h, bbox[2] * ratio_w, bbox[3] * ratio_h))
return new_bbox
return round3((bbox[0] * ratio_w, bbox[1] * ratio_h, bbox[2] * ratio_w, bbox[3] * ratio_h))
def transform_table_lines_by_page_info(bboxes: dict, offsets: tuple, page_info: PageInfo) -> dict:
transform = partial(rescale_to_pdf, page_info=page_info)
logger.debug(f"{offsets=}")
@ -65,14 +71,11 @@ def transform_table_lines_by_page_info(bboxes: dict, offsets: tuple, page_info:
convert = compose(pack, apply_offsets, transform, unpack)
table_lines = bboxes.get("tableLines", [])
transformed_lines = list(map(convert, table_lines))
bboxes["tableLines"] = transformed_lines # lfilter(lambda b: b['y1']==b['y2'], transformed_lines)
import json
bboxes["tableLines"] = list(map(convert, table_lines))
for i in range(len(table_lines)):
logger.debug(json.dumps(table_lines[i], indent=4))
logger.debug(json.dumps(transformed_lines[i], indent=4))
logger.debug("")
logger.debug(json.dumps(bboxes["tableLines"][i], indent=4))
return bboxes
@ -92,7 +95,7 @@ def extract_images_from_pdf(
boxes = page_dict["boxes"]
boxes = filter(lambda box_obj: box_obj["label"] == "table", boxes)
page = fh[page_num]
page: fitz.Page = fh[page_num]
page.wrap_contents()
page_image = page.get_pixmap(dpi=200)
@ -101,7 +104,8 @@ def extract_images_from_pdf(
page.rotation_matrix,
page.transformation_matrix,
dpi,
*page.rect[-2:],
page.rect[-2],
page.rect[-1],
page_image.w,
page_image.h,
page.rotation,

View File

@ -10,7 +10,7 @@ def remove_overlapping(rectangles: Iterable[Rectangle]) -> List[Rectangle]:
def overlap(a: Rectangle, rect2: Rectangle) -> float:
return a.intersection(rect2) > 0
def does_not_overlap(rect: Rectangle, rectangles: Iterable[Rectangle]) -> list:
def does_not_overlap(rect: Rectangle, rectangles: Iterable[Rectangle]) -> bool:
return not any(overlap(rect, rect2) for rect2 in rectangles if not rect == rect2)
rectangles = list(filter(partial(does_not_overlap, rectangles=rectangles), rectangles))

View File

@ -2,7 +2,7 @@ from json import dumps
from typing import Iterable
import numpy as np
from funcy import identity
from funcy import identity # type: ignore
class Rectangle:

View File

@ -1,6 +1,6 @@
import os
from pyinfra.config.loader import load_settings
from pyinfra.config.loader import load_settings # type: ignore
from cv_analysis.config import get_config
from cv_analysis.utils.display import save_image