diff --git a/image_prediction/exceptions.py b/image_prediction/exceptions.py index c91d67b..1b88f0d 100644 --- a/image_prediction/exceptions.py +++ b/image_prediction/exceptions.py @@ -28,3 +28,7 @@ class IncorrectInstantiation(RuntimeError): class IntentionalTestException(RuntimeError): pass + + +class InvalidBox(Exception): + pass diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index 2d79fac..a76e071 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -11,12 +11,20 @@ from tqdm import tqdm from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair from image_prediction.info import Info from image_prediction.stitching.stitching import stitch_pairs +from image_prediction.stitching.utils import validate_box_coords, validate_box_size class ParsablePDFImageExtractor(ImageExtractor): - def __init__(self, verbose=False): + def __init__(self, verbose=False, tolerance=0): + """ + + Args: + verbose: Whether to show progressbar + tolerance: The tolerance in pixels for the distance images beyond which they will not be stitched together + """ self.doc: fitz.fitz.Document = None self.verbose = verbose + self.tolerance = tolerance def extract(self, pdf: bytes, page_range: range = None): self.doc = fitz.Document(stream=pdf) @@ -30,7 +38,7 @@ class ParsablePDFImageExtractor(ImageExtractor): ) ) - return image_metadata_pairs + yield from image_metadata_pairs def __process_images_on_page(self, page: fitz.fitz.Page): images = get_images_on_page(self.doc, page) @@ -40,14 +48,15 @@ class ParsablePDFImageExtractor(ImageExtractor): image_metadata_pairs = starmap( ImageMetadataPair, filter(compose(all, curry(map)(truth)), zip(images, metadata)) ) - image_metadata_pairs = stitch_pairs(list(image_metadata_pairs)) + image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance) - return image_metadata_pairs + yield from image_metadata_pairs def extract_pages(doc, page_range): page_range = range(page_range.start + 1, page_range.stop + 1) pages = map(doc.load_page, page_range) + return pages @@ -55,14 +64,27 @@ def get_images_on_page(doc, page: fitz.Page): image_infos = get_image_infos(page) xrefs = map(itemgetter("xref"), image_infos) images = map(partial(load_image_from_xref, doc), xrefs) + return images def get_metadata_for_images_on_page(page: fitz.Page): image_infos = get_image_infos(page) metadata = map(get_image_metadata, image_infos) + metadata = validate_coords_and_passthrough(metadata) + metadata = filter(tiny, metadata) + metadata = validate_size_and_passthrough(metadata) metadata = map(partial(merge, get_page_metadata(page)), metadata) - return metadata + + yield from metadata + + +def validate_coords_and_passthrough(metadata): + yield from map(validate_box_coords, metadata) + + +def validate_size_and_passthrough(metadata): + yield from map(validate_box_size, metadata) def load_image_from_xref(doc, xref): @@ -77,7 +99,10 @@ def get_image_infos(page: fitz.Page): def get_image_metadata(image_info): x1, y1, x2, y2 = map(rounder, image_info["bbox"]) - width, height = itemgetter("width", "height")(image_info) + + width = abs(x2 - x1) + height = abs(y2 - y1) + return { Info.WIDTH: width, Info.HEIGHT: height, @@ -88,6 +113,10 @@ def get_image_metadata(image_info): } +def tiny(metadata): + return metadata[Info.WIDTH] * metadata[Info.HEIGHT] + + def get_page_metadata(page): page_width, page_height = map(rounder, page.mediabox_size) diff --git a/image_prediction/stitching/utils.py b/image_prediction/stitching/utils.py index 67a7ebe..e5bed7b 100644 --- a/image_prediction/stitching/utils.py +++ b/image_prediction/stitching/utils.py @@ -1,5 +1,8 @@ +import json from itertools import chain +from image_prediction.exceptions import InvalidBox +from image_prediction.formatter.formatters.enum import EnumFormatter from image_prediction.info import Info @@ -31,5 +34,34 @@ def make_length_getter(dim): def validate_box(box): - assert box[Info.X2] - box[Info.X1] == box[Info.WIDTH] - assert box[Info.Y2] - box[Info.Y1] == box[Info.HEIGHT] + validate_box_coords(box) + validate_box_size(box) + return box + + +def validate_box_coords(box): + + x_diff = box[Info.WIDTH] - (box[Info.X2] - box[Info.X1]) + y_diff = box[Info.HEIGHT] - (box[Info.Y2] - box[Info.Y1]) + + if x_diff: + raise InvalidBox(f"Width and x-coordinates differ by {x_diff} units: {format_box(box)}") + if y_diff: + raise InvalidBox(f"Width and y-coordinates differ by {y_diff} units: {format_box(box)}") + + return box + + +def validate_box_size(box): + + if not box[Info.WIDTH]: + raise InvalidBox(f"Zero width box: {format_box(box)}") + + if not box[Info.HEIGHT]: + raise InvalidBox(f"Zero height box: {format_box(box)}") + + return box + + +def format_box(box): + return json.dumps(EnumFormatter()(box), indent=2) diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index a44c572..c2b4bb0 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -35,7 +35,7 @@ def process_pdf(pipeline, pdf_path, page_range=None): def main(args): - pipeline = load_pipeline(verbose=True) + pipeline = load_pipeline(verbose=True, tolerance=3) if os.path.isfile(args.input): pdf_paths = [args.input]