Refactoring

Make alpha channel check monadic to streamline error handling
This commit is contained in:
Matthias Bisping 2023-02-09 15:35:37 +01:00
parent e99e97e23f
commit 970275b257
2 changed files with 117 additions and 43 deletions

View File

@ -1,14 +1,13 @@
import atexit
from _operator import itemgetter
from functools import partial, lru_cache, singledispatch
from functools import partial, lru_cache
from itertools import chain, filterfalse
from operator import itemgetter
from operator import itemgetter, attrgetter
from typing import List, Union, Iterable, Any
import fitz
import numpy as np
from PIL import Image
from funcy import merge, compose, rcompose, keep, lkeep
from funcy import merge, compose, rcompose, keep, lkeep, notnone, iffy, rpartial
from pymonad.either import Right, Left, Either
from pymonad.tools import curry, identity
@ -19,11 +18,10 @@ from image_prediction.info import Info
from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import validate_box
from image_prediction.utils import get_logger
from image_prediction.utils.generic import bottom, left, right, lift
from image_prediction.utils.generic import bottom, left, right, lift, wrap_right
logger = get_logger()
Doc = fitz.fitz.Document
Pge = fitz.fitz.Page
Img = Image.Image
@ -56,7 +54,9 @@ class ParsablePDFImageExtractor(ImageExtractor):
metadata = extract_valid_metadata(self.doc, page)
either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata)
valid_image_metadata_pairs = lkeep(take_good_log_bad, either_image_metadata_pair_or_error_per_image)
valid_image_metadata_pairs = lkeep(
rpartial(take_good_log_bad, format_context), either_image_metadata_pair_or_error_per_image
)
valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance)
clear_caches()
@ -67,20 +67,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
return metadatum_to_image_metadata_pair(self.doc, metadatum)
def take_good_log_bad(pair: Either) -> Union[ImageMetadataPair, None]:
return pair.either(log_error_context, identity)
def log_error_context(context: dict) -> None:
logger.warning(f"Skipping bad image. {format_context(context)}")
return None
def format_context(context: dict) -> str:
return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}"
def extract_pages(doc: Doc, page_range: range):
def extract_pages(doc: Doc, page_range: range) -> Iterable[Pge]:
page_range = range(page_range.start + 1, page_range.stop + 1)
pages = map(doc.load_page, page_range)
@ -106,6 +93,19 @@ def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]:
)(page)
def take_good_log_bad(item: Either, log_formatter=identity) -> Any:
return item.either(rpartial(log_error_context, log_formatter), identity)
def format_context(context: dict) -> str:
return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}"
def log_error_context(context: dict, formatter=identity) -> None:
logger.warning(f"Skipping bad image. {formatter(context)}")
return None
def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either:
image: Either = eith_extract_image(doc, metadatum[Info.XREF]).bind(validate_image)
image_metadata_pair: Either = eith_make_image_metadata_pair(image, Right(metadatum))
@ -138,6 +138,7 @@ def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]:
@lru_cache(maxsize=None)
@curry(2)
def eith_extract_image(doc: Doc, xref: int) -> Either:
try:
return Right(extract_image(doc, xref))
@ -184,7 +185,6 @@ def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair:
return ImageMetadataPair(image, metadatum)
@singledispatch
def extract_image(doc: Doc, xref: int) -> Any:
return compose(pixmap_to_image, extract_pixmap)(doc, xref)
@ -199,26 +199,44 @@ def extract_pixmap(doc: Doc, xref: int) -> Pxm:
try:
return fitz.Pixmap(doc, xref)
except ValueError as err:
msg = f"Xref {xref} is invalid, skipping extraction."
logger.debug(msg)
msg = f"Cross reference {xref} is invalid, skipping extraction."
logger.error(err)
logger.trace(msg)
raise BadXref(msg) from err
def has_alpha_channel(doc: Doc, xref: int):
def has_alpha_channel(doc: Doc, xref: int) -> bool:
maybe_image = load_image_handle_from_xref(doc, xref)
maybe_smask = maybe_image["smask"] if maybe_image else None
_get_image_handle = wrap_right(get_image_handle, success_condition=notnone)(doc)
_extract_pixmap = wrap_right(extract_pixmap)(doc)
if maybe_smask: # TODO: Use monad.
return any(
[load_image_handle_from_xref(doc, maybe_smask) is not None, bool(extract_pixmap(doc, maybe_smask).alpha)]
)
else:
try:
return bool(fitz.Pixmap(doc, xref).alpha)
except ValueError:
logger.debug(f"Encountered invalid xref `{xref}` in {doc.metadata.get('title', '<no title>')}.")
return False
def get_soft_mask_reference(cross_reference: int) -> Either:
def error(value) -> str:
return f"Invalid soft mask {value} for cross reference {cross_reference}."
logger.trace(f"Getting soft mask handle for cross reference {cross_reference}.")
pass_on_if_not_none = iffy(notnone, right(identity), left(error))
return _get_image_handle(cross_reference).then(itemgetter("smask")).either(left(identity), pass_on_if_not_none)
def mask_exists(soft_mask_reference: int) -> Either:
logger.trace(f"Checking if soft mask exists for soft mask reference {soft_mask_reference}.")
return _get_image_handle(soft_mask_reference).then(notnone)
def image_has_alpha_channel(reference: int) -> Either:
logger.trace(f"Checking if image with reference {reference} has alpha channel.")
return _extract_pixmap(reference).then(attrgetter("alpha")).then(bool)
cross_reference = Right(xref)
soft_mask_reference = cross_reference.bind(get_soft_mask_reference)
return any(
take_good_log_bad(reference.bind(check))
for reference, check in [
(soft_mask_reference, mask_exists),
(soft_mask_reference, image_has_alpha_channel),
(cross_reference, image_has_alpha_channel),
]
)
def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]:
@ -279,7 +297,7 @@ def normalize_channels(array: np.ndarray):
@lru_cache(maxsize=None)
def load_image_handle_from_xref(doc: Doc, xref: int) -> Union[dict, None]: # TODO: Use monad.
def get_image_handle(doc: Doc, xref: int) -> Union[dict, None]:
return doc.extract_image(xref)
@ -302,7 +320,7 @@ def tiny(metadatum: dict) -> bool:
def clear_caches() -> None:
get_image_infos.cache_clear()
load_image_handle_from_xref.cache_clear()
get_image_handle.cache_clear()
eith_extract_image.cache_clear()

View File

@ -1,7 +1,15 @@
from functools import wraps
from inspect import signature
from itertools import starmap
from typing import Callable
from funcy import iterate, first, curry, map
from pymonad.either import Left, Right
from pymonad.either import Left, Right, Either
from pymonad.tools import curry as pmcurry
from image_prediction.utils import get_logger
logger = get_logger()
def until(cond, func, *args, **kwargs):
@ -17,12 +25,60 @@ def starlift(fn):
def bottom(*args, **kwargs):
return None
return False
def top(*args, **kwargs):
return True
def left(fn):
return lambda x: Left(fn(x))
@wraps(fn)
def inner(x):
return Left(fn(x))
return inner
def right(fn):
return lambda x: Right(fn(x))
@wraps(fn)
def inner(x):
return Right(fn(x))
return inner
def wrap_left(fn, success_condition=top, error_message=None) -> Callable:
return wrap_either(Left, Right, success_condition=success_condition, error_message=error_message)(fn)
def wrap_right(fn, success_condition=top, error_message=None) -> Callable:
return wrap_either(Right, Left, success_condition=success_condition, error_message=error_message)(fn)
def wrap_either(success_type, failure_type, success_condition=top, error_message=None) -> Callable:
@wraps(wrap_either)
def wrapper(fn) -> Callable:
n_params = len(signature(fn).parameters)
@pmcurry(n_params)
@wraps(fn)
def wrapper(*args, **kwargs) -> Either:
try:
result = fn(*args, **kwargs)
if success_condition(result):
return success_type(result)
else:
return failure_type({"error": error_message, "result": result})
except Exception as err:
logger.error(err)
return failure_type({"error": error_message or err, "result": Void})
return wrapper
return wrapper
class Void:
pass