diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index d852e8d..7d26b29 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -3,12 +3,12 @@ from _operator import itemgetter from functools import partial, lru_cache from itertools import chain, filterfalse from operator import itemgetter -from typing import List, Union +from typing import List, Union, Iterable import fitz import numpy as np from PIL import Image -from funcy import merge, compose, rcompose, keep +from funcy import merge, compose, rcompose, keep, lkeep from pymonad.either import Right, Left, Either from pymonad.tools import curry, identity @@ -24,6 +24,11 @@ from image_prediction.utils.generic import bottom, left, right logger = get_logger() +Doc = fitz.fitz.Document +Pag = fitz.fitz.Page +Img = Image.Image + + class ParsablePDFImageExtractor(ImageExtractor): def __init__(self, verbose=False, tolerance=0): """ @@ -33,7 +38,7 @@ class ParsablePDFImageExtractor(ImageExtractor): tolerance: The tolerance in pixels for the distance between images, beyond which they will not be stitched together """ - self.doc: Union[fitz.fitz.Document, None] = None + self.doc: Union[Doc, None] = None self.verbose = verbose self.tolerance = tolerance @@ -46,36 +51,42 @@ class ParsablePDFImageExtractor(ImageExtractor): yield from image_metadata_pairs - def __process_images_on_page(self, page: fitz.fitz.Page): + def __process_images_on_page(self, page: Pag): metadata = extract_valid_metadata(self.doc, page) - maybe_image_metadata_pairs = map(partial(metadatum_to_image_metadata_pair, self.doc), metadata) - image_metadata_pairs = keep(take_right, maybe_image_metadata_pairs) + either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata) + valid_image_metadata_pairs = lkeep(take_good_log_bad, either_image_metadata_pair_or_error_per_image) + valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance) + clear_caches() - image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance) + yield from valid_image_metadata_pairs_stitched - yield from image_metadata_pairs + def __metadatum_to_image_metadata_pair(self, metadatum: dict) -> Either: + return metadatum_to_image_metadata_pair(self.doc, metadatum) -def take_right(pair: Either): - if pair.is_right(): - return pair.either(bottom, identity) - logger.warning(f"Skipping bad image. {pair.either(format_context, bottom)}") +def take_good_log_bad(pair: Either) -> Union[ImageMetadataPair, None]: + return pair.either(log_error_context, identity) -def format_context(context): +def log_error_context(context: dict) -> None: + logger.warning(f"Skipping bad image. {format_context(context)}") + return None + + +def format_context(context: dict) -> str: return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}" -def extract_pages(doc, page_range): +def extract_pages(doc: Doc, page_range: range): page_range = range(page_range.start + 1, page_range.stop + 1) pages = map(doc.load_page, page_range) yield from pages -def validate_image(image: Image.Image) -> Either: +def validate_image(image: Img) -> Either: try: # TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148) image.resize((100, 100)).convert("RGB") @@ -85,7 +96,7 @@ def validate_image(image: Image.Image) -> Either: return Left("Invalid image.") -def extract_valid_metadata(doc: fitz.fitz.Document, page: fitz.fitz.Page): +def extract_valid_metadata(doc: Doc, page: Pag) -> List[dict]: return compose( list, partial(add_alpha_channel_info, doc), @@ -94,14 +105,14 @@ def extract_valid_metadata(doc: fitz.fitz.Document, page: fitz.fitz.Page): )(page) -def metadatum_to_image_metadata_pair(doc, metadatum: dict) -> Either: - maybe_image = xref_to_maybe_image(doc, metadatum[Info.XREF]).bind(validate_image) - maybe_image_metadata_pair = make_maybe_image_metadata_pair(maybe_image, Right(metadatum)) - return maybe_image_metadata_pair +def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either: + image: Either = xref_to_image(doc, metadatum[Info.XREF]).bind(validate_image) + image_metadata_pair: Either = make_eithered_image_metadata_pair(image, Right(metadatum)) + return image_metadata_pair -def add_alpha_channel_info(doc, metadata): - def add_alpha_value_to_metadatum(metadatum): +def add_alpha_channel_info(doc: Doc, metadata: Iterable[dict]) -> Iterable[dict]: + def add_alpha_value_to_metadatum(metadatum: dict) -> dict: alpha = metadatum_to_alpha_value(metadatum) return {**metadatum, Info.ALPHA: alpha} @@ -111,11 +122,11 @@ def add_alpha_channel_info(doc, metadata): yield from map(add_alpha_value_to_metadatum, metadata) -def filter_valid_metadata(metadata): +def filter_valid_metadata(metadata: Iterable[dict]) -> Iterable[dict]: yield from compose(filter_out_tiny_images, filter_out_invalid_metadata)(metadata) -def get_metadata_for_images_on_page(page: fitz.Page): +def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]: metadata = map(get_image_metadata, get_image_infos(page)) metadata = add_page_metadata(page, metadata) @@ -123,14 +134,14 @@ def get_metadata_for_images_on_page(page: fitz.Page): @lru_cache(maxsize=None) -def xref_to_maybe_image(doc, xref) -> Either: +def xref_to_image(doc: Doc, xref: int) -> Either: try: return Right(extract_image(doc, xref)) except BadXref: return Left("Bad xref.") -def make_maybe_image_metadata_pair(image: Either, metadata: Either): +def make_eithered_image_metadata_pair(image: Either, metadata: Either) -> Either: """Reference: haskell.org/tutorial/monads.html""" def context(value): @@ -160,11 +171,11 @@ def make_maybe_image_metadata_pair(image: Either, metadata: Either): @curry(2) -def make_image_metadata_pair(image: Image.Image, metadatum: dict) -> ImageMetadataPair: +def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair: return ImageMetadataPair(image, metadatum) -def extract_image(doc, xref) -> Image.Image: +def extract_image(doc: Doc, xref: int) -> Img: try: pixmap = fitz.Pixmap(doc, xref) except ValueError as err: @@ -177,12 +188,12 @@ def extract_image(doc, xref) -> Image.Image: return Image.fromarray(array) -def has_alpha_channel(doc, xref): +def has_alpha_channel(doc: Doc, xref: int): maybe_image = load_image_handle_from_xref(doc, xref) maybe_smask = maybe_image["smask"] if maybe_image else None - if maybe_smask: + if maybe_smask: # Use monad return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)]) else: try: @@ -192,11 +203,11 @@ def has_alpha_channel(doc, xref): return False -def filter_out_tiny_images(metadata): +def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]: yield from filterfalse(tiny, metadata) -def filter_out_invalid_metadata(metadata): +def filter_out_invalid_metadata(metadata: Iterable[dict]) -> Iterable[dict]: def __validate_box(box): try: return validate_box(box) @@ -206,7 +217,7 @@ def filter_out_invalid_metadata(metadata): yield from keep(__validate_box, metadata) -def get_image_metadata(image_info): +def get_image_metadata(image_info: dict) -> dict: xref, coords = itemgetter("xref", "bbox")(image_info) x1, y1, x2, y2 = map(rounder, coords) @@ -226,11 +237,11 @@ def get_image_metadata(image_info): @lru_cache(maxsize=None) -def get_image_infos(page: fitz.Page) -> List[dict]: +def get_image_infos(page: Pag) -> List[dict]: return page.get_image_info(xrefs=True) -def add_page_metadata(page, metadata): +def add_page_metadata(page: Pag, metadata: Iterable[dict]) -> Iterable[dict]: yield from map(partial(merge, get_page_metadata(page)), metadata) @@ -250,11 +261,11 @@ def normalize_channels(array: np.ndarray): @lru_cache(maxsize=None) -def load_image_handle_from_xref(doc, xref): +def load_image_handle_from_xref(doc: Doc, xref: int) -> Union[dict, None]: # TODO: use Monad return doc.extract_image(xref) -def get_page_metadata(page): +def get_page_metadata(page: Pag) -> dict: page_width, page_height = map(rounder, page.mediabox_size) return { @@ -267,14 +278,14 @@ def get_page_metadata(page): rounder = rcompose(round, int) -def tiny(metadatum): +def tiny(metadatum: dict) -> bool: return metadatum[Info.WIDTH] * metadatum[Info.HEIGHT] <= 4 -def clear_caches(): +def clear_caches() -> None: get_image_infos.cache_clear() load_image_handle_from_xref.cache_clear() - xref_to_maybe_image.cache_clear() + xref_to_image.cache_clear() atexit.register(clear_caches) diff --git a/test/unit_tests/image_extractor_test.py b/test/unit_tests/image_extractor_test.py index 0e8b0c9..b27744b 100644 --- a/test/unit_tests/image_extractor_test.py +++ b/test/unit_tests/image_extractor_test.py @@ -18,7 +18,7 @@ from image_prediction.image_extractor.extractors.parsable import ( get_image_infos, ParsablePDFImageExtractor, extract_valid_metadata, - xref_to_maybe_image, + xref_to_image, extract_image, ) from image_prediction.info import Info @@ -98,4 +98,4 @@ def test_bad_xref_handling(bad_xref_pdf, dvc_test_data): with pytest.raises(BadXref): extract_image(doc, xref) - assert xref_to_maybe_image(doc, xref).is_left() + assert xref_to_image(doc, xref).is_left()