Refactoring

Make alpha channel check monadic to streamline error handling
This commit is contained in:
Matthias Bisping 2023-02-09 15:35:37 +01:00
parent e99e97e23f
commit 970275b257
2 changed files with 117 additions and 43 deletions

View File

@ -1,14 +1,13 @@
import atexit import atexit
from _operator import itemgetter from functools import partial, lru_cache
from functools import partial, lru_cache, singledispatch
from itertools import chain, filterfalse from itertools import chain, filterfalse
from operator import itemgetter from operator import itemgetter, attrgetter
from typing import List, Union, Iterable, Any from typing import List, Union, Iterable, Any
import fitz import fitz
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from funcy import merge, compose, rcompose, keep, lkeep from funcy import merge, compose, rcompose, keep, lkeep, notnone, iffy, rpartial
from pymonad.either import Right, Left, Either from pymonad.either import Right, Left, Either
from pymonad.tools import curry, identity from pymonad.tools import curry, identity
@ -19,11 +18,10 @@ from image_prediction.info import Info
from image_prediction.stitching.stitching import stitch_pairs from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import validate_box from image_prediction.stitching.utils import validate_box
from image_prediction.utils import get_logger from image_prediction.utils import get_logger
from image_prediction.utils.generic import bottom, left, right, lift from image_prediction.utils.generic import bottom, left, right, lift, wrap_right
logger = get_logger() logger = get_logger()
Doc = fitz.fitz.Document Doc = fitz.fitz.Document
Pge = fitz.fitz.Page Pge = fitz.fitz.Page
Img = Image.Image Img = Image.Image
@ -56,7 +54,9 @@ class ParsablePDFImageExtractor(ImageExtractor):
metadata = extract_valid_metadata(self.doc, page) metadata = extract_valid_metadata(self.doc, page)
either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata) either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata)
valid_image_metadata_pairs = lkeep(take_good_log_bad, either_image_metadata_pair_or_error_per_image) valid_image_metadata_pairs = lkeep(
rpartial(take_good_log_bad, format_context), either_image_metadata_pair_or_error_per_image
)
valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance) valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance)
clear_caches() clear_caches()
@ -67,20 +67,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
return metadatum_to_image_metadata_pair(self.doc, metadatum) return metadatum_to_image_metadata_pair(self.doc, metadatum)
def take_good_log_bad(pair: Either) -> Union[ImageMetadataPair, None]: def extract_pages(doc: Doc, page_range: range) -> Iterable[Pge]:
return pair.either(log_error_context, identity)
def log_error_context(context: dict) -> None:
logger.warning(f"Skipping bad image. {format_context(context)}")
return None
def format_context(context: dict) -> str:
return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}"
def extract_pages(doc: Doc, page_range: range):
page_range = range(page_range.start + 1, page_range.stop + 1) page_range = range(page_range.start + 1, page_range.stop + 1)
pages = map(doc.load_page, page_range) pages = map(doc.load_page, page_range)
@ -106,6 +93,19 @@ def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]:
)(page) )(page)
def take_good_log_bad(item: Either, log_formatter=identity) -> Any:
    """Unwrap an Either: pass the good value through, log the bad one.

    Args:
        item: The Either to unwrap.
        log_formatter: Callable applied to the failure payload before it is
            logged; defaults to ``identity``.

    Returns:
        The wrapped value unchanged on the success branch, otherwise whatever
        ``log_error_context`` returns for the failure payload.
    """

    def on_failure(context):
        # Same call shape as rpartial(log_error_context, log_formatter)(context).
        return log_error_context(context, log_formatter)

    return item.either(on_failure, identity)
def format_context(context: dict) -> str:
    """Render an error context as a single-line message.

    Reads ``context['reason']`` (with trailing periods stripped) and
    ``context['metadata']`` (passed through a fresh ``EnumFormatter`` instance).
    """
    reason = context["reason"].rstrip(".")
    metadata = EnumFormatter()(context["metadata"])
    return f"Reason: {reason}. Metadata: {metadata}"
def log_error_context(context: dict, formatter=identity) -> None:
    """Log a warning about a skipped image and return ``None``.

    Args:
        context: Failure payload describing why the image is skipped.
        formatter: Callable turning ``context`` into the text appended to the
            warning; defaults to ``identity``.
    """
    details = formatter(context)
    logger.warning(f"Skipping bad image. {details}")
    # Explicit None: callers use the return value as the "bad" result.
    return None
def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either: def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either:
image: Either = eith_extract_image(doc, metadatum[Info.XREF]).bind(validate_image) image: Either = eith_extract_image(doc, metadatum[Info.XREF]).bind(validate_image)
image_metadata_pair: Either = eith_make_image_metadata_pair(image, Right(metadatum)) image_metadata_pair: Either = eith_make_image_metadata_pair(image, Right(metadatum))
@ -138,6 +138,7 @@ def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]:
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
@curry(2)
def eith_extract_image(doc: Doc, xref: int) -> Either: def eith_extract_image(doc: Doc, xref: int) -> Either:
try: try:
return Right(extract_image(doc, xref)) return Right(extract_image(doc, xref))
@ -184,7 +185,6 @@ def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair:
return ImageMetadataPair(image, metadatum) return ImageMetadataPair(image, metadatum)
@singledispatch
def extract_image(doc: Doc, xref: int) -> Any: def extract_image(doc: Doc, xref: int) -> Any:
return compose(pixmap_to_image, extract_pixmap)(doc, xref) return compose(pixmap_to_image, extract_pixmap)(doc, xref)
@ -199,26 +199,44 @@ def extract_pixmap(doc: Doc, xref: int) -> Pxm:
try: try:
return fitz.Pixmap(doc, xref) return fitz.Pixmap(doc, xref)
except ValueError as err: except ValueError as err:
msg = f"Xref {xref} is invalid, skipping extraction." msg = f"Cross reference {xref} is invalid, skipping extraction."
logger.debug(msg) logger.error(err)
logger.trace(msg)
raise BadXref(msg) from err raise BadXref(msg) from err
def has_alpha_channel(doc: Doc, xref: int): def has_alpha_channel(doc: Doc, xref: int) -> bool:
maybe_image = load_image_handle_from_xref(doc, xref) _get_image_handle = wrap_right(get_image_handle, success_condition=notnone)(doc)
maybe_smask = maybe_image["smask"] if maybe_image else None _extract_pixmap = wrap_right(extract_pixmap)(doc)
if maybe_smask: # TODO: Use monad. def get_soft_mask_reference(cross_reference: int) -> Either:
return any( def error(value) -> str:
[load_image_handle_from_xref(doc, maybe_smask) is not None, bool(extract_pixmap(doc, maybe_smask).alpha)] return f"Invalid soft mask {value} for cross reference {cross_reference}."
)
else: logger.trace(f"Getting soft mask handle for cross reference {cross_reference}.")
try: pass_on_if_not_none = iffy(notnone, right(identity), left(error))
return bool(fitz.Pixmap(doc, xref).alpha) return _get_image_handle(cross_reference).then(itemgetter("smask")).either(left(identity), pass_on_if_not_none)
except ValueError:
logger.debug(f"Encountered invalid xref `{xref}` in {doc.metadata.get('title', '<no title>')}.") def mask_exists(soft_mask_reference: int) -> Either:
return False logger.trace(f"Checking if soft mask exists for soft mask reference {soft_mask_reference}.")
return _get_image_handle(soft_mask_reference).then(notnone)
def image_has_alpha_channel(reference: int) -> Either:
logger.trace(f"Checking if image with reference {reference} has alpha channel.")
return _extract_pixmap(reference).then(attrgetter("alpha")).then(bool)
cross_reference = Right(xref)
soft_mask_reference = cross_reference.bind(get_soft_mask_reference)
return any(
take_good_log_bad(reference.bind(check))
for reference, check in [
(soft_mask_reference, mask_exists),
(soft_mask_reference, image_has_alpha_channel),
(cross_reference, image_has_alpha_channel),
]
)
def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]: def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]:
@ -279,7 +297,7 @@ def normalize_channels(array: np.ndarray):
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def load_image_handle_from_xref(doc: Doc, xref: int) -> Union[dict, None]: # TODO: Use monad. def get_image_handle(doc: Doc, xref: int) -> Union[dict, None]:
return doc.extract_image(xref) return doc.extract_image(xref)
@ -302,7 +320,7 @@ def tiny(metadatum: dict) -> bool:
def clear_caches() -> None: def clear_caches() -> None:
get_image_infos.cache_clear() get_image_infos.cache_clear()
load_image_handle_from_xref.cache_clear() get_image_handle.cache_clear()
eith_extract_image.cache_clear() eith_extract_image.cache_clear()

View File

@ -1,7 +1,15 @@
from functools import wraps
from inspect import signature
from itertools import starmap from itertools import starmap
from typing import Callable
from funcy import iterate, first, curry, map from funcy import iterate, first, curry, map
from pymonad.either import Left, Right from pymonad.either import Left, Right, Either
from pymonad.tools import curry as pmcurry
from image_prediction.utils import get_logger
logger = get_logger()
def until(cond, func, *args, **kwargs): def until(cond, func, *args, **kwargs):
@ -17,12 +25,60 @@ def starlift(fn):
def bottom(*args, **kwargs): def bottom(*args, **kwargs):
return None return False
def top(*args, **kwargs):
    """Constant predicate: accept any arguments and always return ``True``."""
    return True
def left(fn): def left(fn):
return lambda x: Left(fn(x)) @wraps(fn)
def inner(x):
return Left(fn(x))
return inner
def right(fn): def right(fn):
return lambda x: Right(fn(x)) @wraps(fn)
def inner(x):
return Right(fn(x))
return inner
def wrap_left(fn, success_condition=top, error_message=None) -> Callable:
    """Wrap ``fn`` so that a successful result is returned as ``Left``.

    Delegates to ``wrap_either`` with ``Left`` as the success type and ``Right``
    as the failure type.

    Args:
        fn: The function to wrap.
        success_condition: Predicate applied to the result of ``fn``.
        error_message: Optional message stored in the failure payload.
    """
    decorate = wrap_either(Left, Right, success_condition=success_condition, error_message=error_message)
    return decorate(fn)
def wrap_right(fn, success_condition=top, error_message=None) -> Callable:
    """Wrap ``fn`` so that a successful result is returned as ``Right``.

    Delegates to ``wrap_either`` with ``Right`` as the success type and ``Left``
    as the failure type.

    Args:
        fn: The function to wrap.
        success_condition: Predicate applied to the result of ``fn``.
        error_message: Optional message stored in the failure payload.
    """
    decorate = wrap_either(Right, Left, success_condition=success_condition, error_message=error_message)
    return decorate(fn)
def wrap_either(success_type, failure_type, success_condition=top, error_message=None) -> Callable:
    """Build a decorator that makes a function return an Either-style value.

    The decorated function is curried over the number of parameters reported by
    ``inspect.signature(fn)``. When finally called:

    * if ``fn`` returns a result for which ``success_condition(result)`` is
      truthy, the result is wrapped in ``success_type``;
    * if the condition is falsy, ``failure_type`` wraps a dict with the keys
      ``"error"`` (``error_message``) and ``"result"`` (the rejected result);
    * if ``fn`` raises, the exception is logged and ``failure_type`` wraps a
      dict with ``"error"`` (``error_message`` or the exception) and
      ``"result"`` set to the ``Void`` class as a "no result" placeholder.

    Args:
        success_type: Callable used to wrap an accepted result.
        failure_type: Callable used to wrap the failure payload.
        success_condition: Predicate deciding whether a result is accepted.
        error_message: Optional message stored in the failure payload.

    Returns:
        A decorator taking ``fn`` and returning its curried, wrapped version.
    """

    # NOTE: deliberately not decorated with functools.wraps(wrap_either). That
    # would give this decorator the factory's name and __wrapped__, so that
    # inspect.signature reported the factory's parameters instead of (fn).
    def decorator(fn) -> Callable:
        n_params = len(signature(fn).parameters)

        @pmcurry(n_params)
        @wraps(fn)
        def wrapped(*args, **kwargs) -> Either:
            try:
                result = fn(*args, **kwargs)
                if success_condition(result):
                    return success_type(result)
                else:
                    return failure_type({"error": error_message, "result": result})
            except Exception as err:
                # Deliberate catch-all: any failure of fn becomes a failure value.
                logger.error(err)
                return failure_type({"error": error_message or err, "result": Void})

        return wrapped

    return decorator
class Void:
    """Placeholder type; ``wrap_either`` stores this class as the "result" when the wrapped function raised."""