Refactoring
Make alpha channel check monadic to streamline error handling
This commit is contained in:
parent
e99e97e23f
commit
970275b257
@ -1,14 +1,13 @@
|
||||
import atexit
|
||||
from _operator import itemgetter
|
||||
from functools import partial, lru_cache, singledispatch
|
||||
from functools import partial, lru_cache
|
||||
from itertools import chain, filterfalse
|
||||
from operator import itemgetter
|
||||
from operator import itemgetter, attrgetter
|
||||
from typing import List, Union, Iterable, Any
|
||||
|
||||
import fitz
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from funcy import merge, compose, rcompose, keep, lkeep
|
||||
from funcy import merge, compose, rcompose, keep, lkeep, notnone, iffy, rpartial
|
||||
from pymonad.either import Right, Left, Either
|
||||
from pymonad.tools import curry, identity
|
||||
|
||||
@ -19,11 +18,10 @@ from image_prediction.info import Info
|
||||
from image_prediction.stitching.stitching import stitch_pairs
|
||||
from image_prediction.stitching.utils import validate_box
|
||||
from image_prediction.utils import get_logger
|
||||
from image_prediction.utils.generic import bottom, left, right, lift
|
||||
from image_prediction.utils.generic import bottom, left, right, lift, wrap_right
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
Doc = fitz.fitz.Document
|
||||
Pge = fitz.fitz.Page
|
||||
Img = Image.Image
|
||||
@ -56,7 +54,9 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
metadata = extract_valid_metadata(self.doc, page)
|
||||
|
||||
either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata)
|
||||
valid_image_metadata_pairs = lkeep(take_good_log_bad, either_image_metadata_pair_or_error_per_image)
|
||||
valid_image_metadata_pairs = lkeep(
|
||||
rpartial(take_good_log_bad, format_context), either_image_metadata_pair_or_error_per_image
|
||||
)
|
||||
valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance)
|
||||
|
||||
clear_caches()
|
||||
@ -67,20 +67,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
return metadatum_to_image_metadata_pair(self.doc, metadatum)
|
||||
|
||||
|
||||
def take_good_log_bad(pair: Either) -> Union[ImageMetadataPair, None]:
|
||||
return pair.either(log_error_context, identity)
|
||||
|
||||
|
||||
def log_error_context(context: dict) -> None:
|
||||
logger.warning(f"Skipping bad image. {format_context(context)}")
|
||||
return None
|
||||
|
||||
|
||||
def format_context(context: dict) -> str:
|
||||
return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}"
|
||||
|
||||
|
||||
def extract_pages(doc: Doc, page_range: range):
|
||||
def extract_pages(doc: Doc, page_range: range) -> Iterable[Pge]:
|
||||
page_range = range(page_range.start + 1, page_range.stop + 1)
|
||||
pages = map(doc.load_page, page_range)
|
||||
|
||||
@ -106,6 +93,19 @@ def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]:
|
||||
)(page)
|
||||
|
||||
|
||||
def take_good_log_bad(item: Either, log_formatter=identity) -> Any:
    """Unwrap a successful `Either`, or log a failed one.

    Args:
        item: The `Either` to unwrap.
        log_formatter: Callable applied to the failure payload before it is
            logged. Defaults to `identity`, i.e. the payload is logged as is.

    Returns:
        The wrapped value when `item` is a `Right`; otherwise whatever
        `log_error_context` returns (`None`).
    """
    # `rpartial` fixes the formatter as the trailing argument, leaving the
    # failure payload as the single argument `Either.either` passes in.
    return item.either(rpartial(log_error_context, log_formatter), identity)
|
||||
|
||||
|
||||
def format_context(context: dict) -> str:
    """Render an error context as a human-readable log fragment.

    Args:
        context: Mapping that provides a "reason" string and a "metadata"
            value; the metadata is rendered through `EnumFormatter`.

    Returns:
        A string of the form "Reason: <reason>. Metadata: <metadata>", where
        trailing periods of the reason are stripped to avoid a doubled full stop.
    """
    return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}"
|
||||
|
||||
|
||||
def log_error_context(context: dict, formatter=identity) -> None:
    """Log a warning about a skipped image.

    Args:
        context: Failure payload describing why the image was skipped.
        formatter: Callable that turns `context` into the text appended to
            the warning. Defaults to `identity`.

    Returns:
        Always `None`, so the function can serve as the failure branch of
        `Either.either` and have the result filtered out afterwards.
    """
    logger.warning(f"Skipping bad image. {formatter(context)}")
    return None
|
||||
|
||||
|
||||
def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either:
|
||||
image: Either = eith_extract_image(doc, metadatum[Info.XREF]).bind(validate_image)
|
||||
image_metadata_pair: Either = eith_make_image_metadata_pair(image, Right(metadatum))
|
||||
@ -138,6 +138,7 @@ def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]:
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
@curry(2)
|
||||
def eith_extract_image(doc: Doc, xref: int) -> Either:
|
||||
try:
|
||||
return Right(extract_image(doc, xref))
|
||||
@ -184,7 +185,6 @@ def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair:
|
||||
return ImageMetadataPair(image, metadatum)
|
||||
|
||||
|
||||
def extract_image(doc: Doc, xref: int) -> Any:
    """Extract the image stored under `xref` in `doc`.

    The pixmap is obtained via `extract_pixmap(doc, xref)` and then converted
    with `pixmap_to_image`.

    Args:
        doc: Document to read from.
        xref: Cross reference of the image object.

    Returns:
        Whatever `pixmap_to_image` produces for the extracted pixmap.
    """
    # `compose` applies right-to-left: extract the pixmap first, then convert it.
    return compose(pixmap_to_image, extract_pixmap)(doc, xref)
|
||||
|
||||
@ -199,26 +199,44 @@ def extract_pixmap(doc: Doc, xref: int) -> Pxm:
|
||||
try:
|
||||
return fitz.Pixmap(doc, xref)
|
||||
except ValueError as err:
|
||||
msg = f"Xref {xref} is invalid, skipping extraction."
|
||||
logger.debug(msg)
|
||||
msg = f"Cross reference {xref} is invalid, skipping extraction."
|
||||
logger.error(err)
|
||||
logger.trace(msg)
|
||||
raise BadXref(msg) from err
|
||||
|
||||
|
||||
def has_alpha_channel(doc: Doc, xref: int) -> bool:
    """Tell whether the image under `xref` carries transparency information.

    Three checks are evaluated lazily, and the first truthy one wins:
      1. the image references a soft mask whose handle can be loaded,
      2. the soft mask's pixmap has an alpha channel,
      3. the image's own pixmap has an alpha channel.

    Failures in any step are carried as `Left` values and logged by
    `take_good_log_bad`; they count as "no alpha channel" for that check.

    Args:
        doc: Document containing the image.
        xref: Cross reference of the image.

    Returns:
        `True` if any of the checks succeeds with a truthy result, else `False`.
    """

    # Both helpers are curried by `wrap_right`; binding `doc` here leaves
    # functions of a single cross reference that return an `Either`.
    _get_image_handle = wrap_right(get_image_handle, success_condition=notnone)(doc)
    _extract_pixmap = wrap_right(extract_pixmap)(doc)

    def get_soft_mask_reference(cross_reference: int) -> Either:
        def error(value) -> str:
            return f"Invalid soft mask {value} for cross reference {cross_reference}."

        logger.trace(f"Getting soft mask handle for cross reference {cross_reference}.")
        # Truthiness rather than `notnone`: PyMuPDF documents `smask` as 0 when
        # the image has no soft mask, and 0 is never a valid cross reference.
        pass_on_if_present = iffy(bool, right(identity), left(error))
        # `.get` instead of item access: a handle without an "smask" key (e.g.
        # an empty dict) is treated as "no soft mask" instead of raising.
        return (
            _get_image_handle(cross_reference)
            .then(lambda handle: handle.get("smask"))
            .either(left(identity), pass_on_if_present)
        )

    def mask_exists(soft_mask_reference: int) -> Either:
        logger.trace(f"Checking if soft mask exists for soft mask reference {soft_mask_reference}.")
        return _get_image_handle(soft_mask_reference).then(notnone)

    def image_has_alpha_channel(reference: int) -> Either:
        logger.trace(f"Checking if image with reference {reference} has alpha channel.")
        return _extract_pixmap(reference).then(attrgetter("alpha")).then(bool)

    cross_reference = Right(xref)
    soft_mask_reference = cross_reference.bind(get_soft_mask_reference)

    # Generator (not list) so later checks are skipped once one is truthy.
    return any(
        take_good_log_bad(reference.bind(check))
        for reference, check in [
            (soft_mask_reference, mask_exists),
            (soft_mask_reference, image_has_alpha_channel),
            (cross_reference, image_has_alpha_channel),
        ]
    )
|
||||
|
||||
|
||||
def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]:
|
||||
@ -279,7 +297,7 @@ def normalize_channels(array: np.ndarray):
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
def get_image_handle(doc: Doc, xref: int) -> Union[dict, None]:
    """Return the result of `doc.extract_image(xref)`, memoised per argument pair.

    The cache is unbounded (`maxsize=None`), so it has to be emptied
    explicitly via `get_image_handle.cache_clear()`.

    Args:
        doc: Document to query.
        xref: Cross reference of the image object.

    Returns:
        The value returned by `doc.extract_image(xref)`.
    """
    return doc.extract_image(xref)
|
||||
|
||||
|
||||
@ -302,7 +320,7 @@ def tiny(metadatum: dict) -> bool:
|
||||
|
||||
def clear_caches() -> None:
    """Empty the memoisation caches of the module's `lru_cache`-backed helpers."""
    get_image_infos.cache_clear()
    get_image_handle.cache_clear()
    eith_extract_image.cache_clear()
|
||||
|
||||
|
||||
|
||||
@ -1,7 +1,15 @@
|
||||
from functools import wraps
|
||||
from inspect import signature
|
||||
from itertools import starmap
|
||||
from typing import Callable
|
||||
|
||||
from funcy import iterate, first, curry, map
|
||||
from pymonad.either import Left, Right
|
||||
from pymonad.either import Left, Right, Either
|
||||
from pymonad.tools import curry as pmcurry
|
||||
|
||||
from image_prediction.utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def until(cond, func, *args, **kwargs):
|
||||
@ -17,12 +25,60 @@ def starlift(fn):
|
||||
|
||||
|
||||
def bottom(*args, **kwargs):
    """Constant function: accept and ignore any arguments, return `False`."""
    return False
|
||||
|
||||
|
||||
def top(*args, **kwargs):
    """Constant function: accept and ignore any arguments, return `True`."""
    return True
|
||||
|
||||
|
||||
def left(fn):
    """Lift `fn` so that its result is wrapped in a `Left`.

    Args:
        fn: Single-argument callable.

    Returns:
        A single-argument function that returns `Left(fn(x))`; `functools.wraps`
        carries over the metadata of `fn`.
    """

    @wraps(fn)
    def inner(x):
        return Left(fn(x))

    return inner
|
||||
|
||||
|
||||
def right(fn):
    """Lift `fn` so that its result is wrapped in a `Right`.

    Args:
        fn: Single-argument callable.

    Returns:
        A single-argument function that returns `Right(fn(x))`; `functools.wraps`
        carries over the metadata of `fn`.
    """

    @wraps(fn)
    def inner(x):
        return Right(fn(x))

    return inner
|
||||
|
||||
|
||||
def wrap_left(fn, success_condition=top, error_message=None) -> Callable:
    """Wrap `fn` so that a successful call yields a `Left` and a failed one a `Right`.

    Mirror image of `wrap_right`; see `wrap_either` for what counts as success.

    Args:
        fn: Function to wrap.
        success_condition: Predicate on the result of `fn`; defaults to `top`
            (every result counts as success).
        error_message: Optional message stored in the failure payload.

    Returns:
        The wrapped (curried) function produced by `wrap_either`.
    """
    return wrap_either(Left, Right, success_condition=success_condition, error_message=error_message)(fn)
|
||||
|
||||
|
||||
def wrap_right(fn, success_condition=top, error_message=None) -> Callable:
    """Wrap `fn` so that a successful call yields a `Right` and a failed one a `Left`.

    See `wrap_either` for what counts as success.

    Args:
        fn: Function to wrap.
        success_condition: Predicate on the result of `fn`; defaults to `top`
            (every result counts as success).
        error_message: Optional message stored in the failure payload.

    Returns:
        The wrapped (curried) function produced by `wrap_either`.
    """
    return wrap_either(Right, Left, success_condition=success_condition, error_message=error_message)(fn)
|
||||
|
||||
|
||||
def wrap_either(success_type, failure_type, success_condition=top, error_message=None) -> Callable:
    """Build a decorator that makes a function report its outcome as an `Either`.

    The decorated function is curried over the number of parameters of the
    original function. When finally called it returns

    * `success_type(result)` if the call succeeds and `success_condition(result)`
      is truthy,
    * `failure_type({"error": error_message, "result": result})` if the call
      succeeds but the condition is not met,
    * `failure_type({"error": error_message or err, "result": Void})` if any
      `Exception` is raised; the exception is logged at error level.

    Args:
        success_type: Constructor used for successful results (e.g. `Right`).
        failure_type: Constructor used for failures (e.g. `Left`).
        success_condition: Predicate deciding whether a result is a success;
            defaults to `top` (always true).
        error_message: Optional message placed in the failure payload.

    Returns:
        A decorator taking the function to wrap.
    """

    def decorator(fn) -> Callable:
        # Curry over every declared parameter so callers may bind leading
        # arguments up front, e.g. `wrap_right(f)(doc)`.
        n_params = len(signature(fn).parameters)

        @pmcurry(n_params)
        @wraps(fn)
        def wrapped(*args, **kwargs) -> Either:
            # The success check stays inside the `try` on purpose: an exception
            # raised by `success_condition` is reported as a failure as well.
            try:
                result = fn(*args, **kwargs)
                if success_condition(result):
                    return success_type(result)
                else:
                    return failure_type({"error": error_message, "result": result})
            except Exception as err:
                logger.error(err)
                # `Void` marks "no result at all", as opposed to a result of `None`.
                return failure_type({"error": error_message or err, "result": Void})

        return wrapped

    return decorator
|
||||
|
||||
|
||||
class Void:
    """Sentinel type placed in the "result" slot of a failure payload when the
    wrapped call raised and therefore produced no result."""

    pass
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user