From 970275b25708c05e4fbe78b52aa70d791d5ff17a Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Thu, 9 Feb 2023 15:35:37 +0100 Subject: [PATCH] Refactoring Make alpha channel check monadic to streamline error handling --- .../image_extractor/extractors/parsable.py | 96 +++++++++++-------- image_prediction/utils/generic.py | 64 ++++++++++++- 2 files changed, 117 insertions(+), 43 deletions(-) diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index d3a5adf..7136e8c 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -1,14 +1,13 @@ import atexit -from _operator import itemgetter -from functools import partial, lru_cache, singledispatch +from functools import partial, lru_cache from itertools import chain, filterfalse -from operator import itemgetter +from operator import itemgetter, attrgetter from typing import List, Union, Iterable, Any import fitz import numpy as np from PIL import Image -from funcy import merge, compose, rcompose, keep, lkeep +from funcy import merge, compose, rcompose, keep, lkeep, notnone, iffy, rpartial from pymonad.either import Right, Left, Either from pymonad.tools import curry, identity @@ -19,11 +18,10 @@ from image_prediction.info import Info from image_prediction.stitching.stitching import stitch_pairs from image_prediction.stitching.utils import validate_box from image_prediction.utils import get_logger -from image_prediction.utils.generic import bottom, left, right, lift +from image_prediction.utils.generic import bottom, left, right, lift, wrap_right logger = get_logger() - Doc = fitz.fitz.Document Pge = fitz.fitz.Page Img = Image.Image @@ -56,7 +54,9 @@ class ParsablePDFImageExtractor(ImageExtractor): metadata = extract_valid_metadata(self.doc, page) either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata) - valid_image_metadata_pairs = lkeep(take_good_log_bad, either_image_metadata_pair_or_error_per_image) + valid_image_metadata_pairs = lkeep( + rpartial(take_good_log_bad, format_context), either_image_metadata_pair_or_error_per_image + ) valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance) clear_caches() @@ -67,20 +67,7 @@ class ParsablePDFImageExtractor(ImageExtractor): return metadatum_to_image_metadata_pair(self.doc, metadatum) -def take_good_log_bad(pair: Either) -> Union[ImageMetadataPair, None]: - return pair.either(log_error_context, identity) - - -def log_error_context(context: dict) -> None: - logger.warning(f"Skipping bad image. {format_context(context)}") - return None - - -def format_context(context: dict) -> str: - return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}" - - -def extract_pages(doc: Doc, page_range: range): +def extract_pages(doc: Doc, page_range: range) -> Iterable[Pge]: page_range = range(page_range.start + 1, page_range.stop + 1) pages = map(doc.load_page, page_range) @@ -106,6 +93,19 @@ def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]: )(page) +def take_good_log_bad(item: Either, log_formatter=identity) -> Any: + return item.either(rpartial(log_error_context, log_formatter), identity) + + +def format_context(context: dict) -> str: + return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}" + + +def log_error_context(context: dict, formatter=identity) -> None: + logger.warning(f"Skipping bad image. {formatter(context)}") + return None + + def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either: image: Either = eith_extract_image(doc, metadatum[Info.XREF]).bind(validate_image) image_metadata_pair: Either = eith_make_image_metadata_pair(image, Right(metadatum)) @@ -138,6 +138,7 @@ def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]: @lru_cache(maxsize=None) +@curry(2) def eith_extract_image(doc: Doc, xref: int) -> Either: try: return Right(extract_image(doc, xref)) @@ -184,7 +185,6 @@ def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair: return ImageMetadataPair(image, metadatum) -@singledispatch def extract_image(doc: Doc, xref: int) -> Any: return compose(pixmap_to_image, extract_pixmap)(doc, xref) @@ -199,26 +199,44 @@ def extract_pixmap(doc: Doc, xref: int) -> Pxm: try: return fitz.Pixmap(doc, xref) except ValueError as err: - msg = f"Xref {xref} is invalid, skipping extraction." - logger.debug(msg) + msg = f"Cross reference {xref} is invalid, skipping extraction." + logger.error(err) + logger.trace(msg) raise BadXref(msg) from err -def has_alpha_channel(doc: Doc, xref: int): +def has_alpha_channel(doc: Doc, xref: int) -> bool: - maybe_image = load_image_handle_from_xref(doc, xref) - maybe_smask = maybe_image["smask"] if maybe_image else None + _get_image_handle = wrap_right(get_image_handle, success_condition=notnone)(doc) + _extract_pixmap = wrap_right(extract_pixmap)(doc) - if maybe_smask: # TODO: Use monad. - return any( - [load_image_handle_from_xref(doc, maybe_smask) is not None, bool(extract_pixmap(doc, maybe_smask).alpha)] - ) - else: - try: - return bool(fitz.Pixmap(doc, xref).alpha) - except ValueError: - logger.debug(f"Encountered invalid xref `{xref}` in {doc.metadata.get('title', '')}.") - return False + def get_soft_mask_reference(cross_reference: int) -> Either: + def error(value) -> str: + return f"Invalid soft mask {value} for cross reference {cross_reference}." + + logger.trace(f"Getting soft mask handle for cross reference {cross_reference}.") + pass_on_if_not_none = iffy(notnone, right(identity), left(error)) + return _get_image_handle(cross_reference).then(itemgetter("smask")).either(left(identity), pass_on_if_not_none) + + def mask_exists(soft_mask_reference: int) -> Either: + logger.trace(f"Checking if soft mask exists for soft mask reference {soft_mask_reference}.") + return _get_image_handle(soft_mask_reference).then(notnone) + + def image_has_alpha_channel(reference: int) -> Either: + logger.trace(f"Checking if image with reference {reference} has alpha channel.") + return _extract_pixmap(reference).then(attrgetter("alpha")).then(bool) + + cross_reference = Right(xref) + soft_mask_reference = cross_reference.bind(get_soft_mask_reference) + + return any( + take_good_log_bad(reference.bind(check)) + for reference, check in [ + (soft_mask_reference, mask_exists), + (soft_mask_reference, image_has_alpha_channel), + (cross_reference, image_has_alpha_channel), + ] + ) def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]: @@ -279,7 +297,7 @@ def normalize_channels(array: np.ndarray): @lru_cache(maxsize=None) -def load_image_handle_from_xref(doc: Doc, xref: int) -> Union[dict, None]: # TODO: Use monad. +def get_image_handle(doc: Doc, xref: int) -> Union[dict, None]: return doc.extract_image(xref) @@ -302,7 +320,7 @@ def tiny(metadatum: dict) -> bool: def clear_caches() -> None: get_image_infos.cache_clear() - load_image_handle_from_xref.cache_clear() + get_image_handle.cache_clear() eith_extract_image.cache_clear() diff --git a/image_prediction/utils/generic.py b/image_prediction/utils/generic.py index 39f75cf..ffdf7b7 100644 --- a/image_prediction/utils/generic.py +++ b/image_prediction/utils/generic.py @@ -1,7 +1,15 @@ +from functools import wraps +from inspect import signature from itertools import starmap +from typing import Callable from funcy import iterate, first, curry, map -from pymonad.either import Left, Right +from pymonad.either import Left, Right, Either +from pymonad.tools import curry as pmcurry + +from image_prediction.utils import get_logger + +logger = get_logger() def until(cond, func, *args, **kwargs): @@ -17,12 +25,60 @@ def starlift(fn): def bottom(*args, **kwargs): - return None + return False + + +def top(*args, **kwargs): + return True def left(fn): - return lambda x: Left(fn(x)) + @wraps(fn) + def inner(x): + return Left(fn(x)) + + return inner def right(fn): - return lambda x: Right(fn(x)) + @wraps(fn) + def inner(x): + return Right(fn(x)) + + return inner + + +def wrap_left(fn, success_condition=top, error_message=None) -> Callable: + return wrap_either(Left, Right, success_condition=success_condition, error_message=error_message)(fn) + + +def wrap_right(fn, success_condition=top, error_message=None) -> Callable: + return wrap_either(Right, Left, success_condition=success_condition, error_message=error_message)(fn) + + +def wrap_either(success_type, failure_type, success_condition=top, error_message=None) -> Callable: + @wraps(wrap_either) + def wrapper(fn) -> Callable: + + n_params = len(signature(fn).parameters) + + @pmcurry(n_params) + @wraps(fn) + def wrapper(*args, **kwargs) -> Either: + try: + result = fn(*args, **kwargs) + if success_condition(result): + return success_type(result) + else: + return failure_type({"error": error_message, "result": result}) + except Exception as err: + logger.error(err) + return failure_type({"error": error_message or err, "result": Void}) + + return wrapper + + return wrapper + + +class Void: + pass