Refactoring
Make alpha channel check monadic to streamline error handling
This commit is contained in:
parent
e99e97e23f
commit
970275b257
@ -1,14 +1,13 @@
|
|||||||
import atexit
|
import atexit
|
||||||
from _operator import itemgetter
|
from functools import partial, lru_cache
|
||||||
from functools import partial, lru_cache, singledispatch
|
|
||||||
from itertools import chain, filterfalse
|
from itertools import chain, filterfalse
|
||||||
from operator import itemgetter
|
from operator import itemgetter, attrgetter
|
||||||
from typing import List, Union, Iterable, Any
|
from typing import List, Union, Iterable, Any
|
||||||
|
|
||||||
import fitz
|
import fitz
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from funcy import merge, compose, rcompose, keep, lkeep
|
from funcy import merge, compose, rcompose, keep, lkeep, notnone, iffy, rpartial
|
||||||
from pymonad.either import Right, Left, Either
|
from pymonad.either import Right, Left, Either
|
||||||
from pymonad.tools import curry, identity
|
from pymonad.tools import curry, identity
|
||||||
|
|
||||||
@ -19,11 +18,10 @@ from image_prediction.info import Info
|
|||||||
from image_prediction.stitching.stitching import stitch_pairs
|
from image_prediction.stitching.stitching import stitch_pairs
|
||||||
from image_prediction.stitching.utils import validate_box
|
from image_prediction.stitching.utils import validate_box
|
||||||
from image_prediction.utils import get_logger
|
from image_prediction.utils import get_logger
|
||||||
from image_prediction.utils.generic import bottom, left, right, lift
|
from image_prediction.utils.generic import bottom, left, right, lift, wrap_right
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
Doc = fitz.fitz.Document
|
Doc = fitz.fitz.Document
|
||||||
Pge = fitz.fitz.Page
|
Pge = fitz.fitz.Page
|
||||||
Img = Image.Image
|
Img = Image.Image
|
||||||
@ -56,7 +54,9 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
|||||||
metadata = extract_valid_metadata(self.doc, page)
|
metadata = extract_valid_metadata(self.doc, page)
|
||||||
|
|
||||||
either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata)
|
either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata)
|
||||||
valid_image_metadata_pairs = lkeep(take_good_log_bad, either_image_metadata_pair_or_error_per_image)
|
valid_image_metadata_pairs = lkeep(
|
||||||
|
rpartial(take_good_log_bad, format_context), either_image_metadata_pair_or_error_per_image
|
||||||
|
)
|
||||||
valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance)
|
valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance)
|
||||||
|
|
||||||
clear_caches()
|
clear_caches()
|
||||||
@ -67,20 +67,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
|||||||
return metadatum_to_image_metadata_pair(self.doc, metadatum)
|
return metadatum_to_image_metadata_pair(self.doc, metadatum)
|
||||||
|
|
||||||
|
|
||||||
def take_good_log_bad(pair: Either) -> Union[ImageMetadataPair, None]:
|
def extract_pages(doc: Doc, page_range: range) -> Iterable[Pge]:
|
||||||
return pair.either(log_error_context, identity)
|
|
||||||
|
|
||||||
|
|
||||||
def log_error_context(context: dict) -> None:
|
|
||||||
logger.warning(f"Skipping bad image. {format_context(context)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def format_context(context: dict) -> str:
|
|
||||||
return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}"
|
|
||||||
|
|
||||||
|
|
||||||
def extract_pages(doc: Doc, page_range: range):
|
|
||||||
page_range = range(page_range.start + 1, page_range.stop + 1)
|
page_range = range(page_range.start + 1, page_range.stop + 1)
|
||||||
pages = map(doc.load_page, page_range)
|
pages = map(doc.load_page, page_range)
|
||||||
|
|
||||||
@ -106,6 +93,19 @@ def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]:
|
|||||||
)(page)
|
)(page)
|
||||||
|
|
||||||
|
|
||||||
|
def take_good_log_bad(item: Either, log_formatter=identity) -> Any:
|
||||||
|
return item.either(rpartial(log_error_context, log_formatter), identity)
|
||||||
|
|
||||||
|
|
||||||
|
def format_context(context: dict) -> str:
|
||||||
|
return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}"
|
||||||
|
|
||||||
|
|
||||||
|
def log_error_context(context: dict, formatter=identity) -> None:
|
||||||
|
logger.warning(f"Skipping bad image. {formatter(context)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either:
|
def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either:
|
||||||
image: Either = eith_extract_image(doc, metadatum[Info.XREF]).bind(validate_image)
|
image: Either = eith_extract_image(doc, metadatum[Info.XREF]).bind(validate_image)
|
||||||
image_metadata_pair: Either = eith_make_image_metadata_pair(image, Right(metadatum))
|
image_metadata_pair: Either = eith_make_image_metadata_pair(image, Right(metadatum))
|
||||||
@ -138,6 +138,7 @@ def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]:
|
|||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
|
@curry(2)
|
||||||
def eith_extract_image(doc: Doc, xref: int) -> Either:
|
def eith_extract_image(doc: Doc, xref: int) -> Either:
|
||||||
try:
|
try:
|
||||||
return Right(extract_image(doc, xref))
|
return Right(extract_image(doc, xref))
|
||||||
@ -184,7 +185,6 @@ def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair:
|
|||||||
return ImageMetadataPair(image, metadatum)
|
return ImageMetadataPair(image, metadatum)
|
||||||
|
|
||||||
|
|
||||||
@singledispatch
|
|
||||||
def extract_image(doc: Doc, xref: int) -> Any:
|
def extract_image(doc: Doc, xref: int) -> Any:
|
||||||
return compose(pixmap_to_image, extract_pixmap)(doc, xref)
|
return compose(pixmap_to_image, extract_pixmap)(doc, xref)
|
||||||
|
|
||||||
@ -199,26 +199,44 @@ def extract_pixmap(doc: Doc, xref: int) -> Pxm:
|
|||||||
try:
|
try:
|
||||||
return fitz.Pixmap(doc, xref)
|
return fitz.Pixmap(doc, xref)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
msg = f"Xref {xref} is invalid, skipping extraction."
|
msg = f"Cross reference {xref} is invalid, skipping extraction."
|
||||||
logger.debug(msg)
|
logger.error(err)
|
||||||
|
logger.trace(msg)
|
||||||
raise BadXref(msg) from err
|
raise BadXref(msg) from err
|
||||||
|
|
||||||
|
|
||||||
def has_alpha_channel(doc: Doc, xref: int):
|
def has_alpha_channel(doc: Doc, xref: int) -> bool:
|
||||||
|
|
||||||
maybe_image = load_image_handle_from_xref(doc, xref)
|
_get_image_handle = wrap_right(get_image_handle, success_condition=notnone)(doc)
|
||||||
maybe_smask = maybe_image["smask"] if maybe_image else None
|
_extract_pixmap = wrap_right(extract_pixmap)(doc)
|
||||||
|
|
||||||
if maybe_smask: # TODO: Use monad.
|
def get_soft_mask_reference(cross_reference: int) -> Either:
|
||||||
return any(
|
def error(value) -> str:
|
||||||
[load_image_handle_from_xref(doc, maybe_smask) is not None, bool(extract_pixmap(doc, maybe_smask).alpha)]
|
return f"Invalid soft mask {value} for cross reference {cross_reference}."
|
||||||
)
|
|
||||||
else:
|
logger.trace(f"Getting soft mask handle for cross reference {cross_reference}.")
|
||||||
try:
|
pass_on_if_not_none = iffy(notnone, right(identity), left(error))
|
||||||
return bool(fitz.Pixmap(doc, xref).alpha)
|
return _get_image_handle(cross_reference).then(itemgetter("smask")).either(left(identity), pass_on_if_not_none)
|
||||||
except ValueError:
|
|
||||||
logger.debug(f"Encountered invalid xref `{xref}` in {doc.metadata.get('title', '<no title>')}.")
|
def mask_exists(soft_mask_reference: int) -> Either:
|
||||||
return False
|
logger.trace(f"Checking if soft mask exists for soft mask reference {soft_mask_reference}.")
|
||||||
|
return _get_image_handle(soft_mask_reference).then(notnone)
|
||||||
|
|
||||||
|
def image_has_alpha_channel(reference: int) -> Either:
|
||||||
|
logger.trace(f"Checking if image with reference {reference} has alpha channel.")
|
||||||
|
return _extract_pixmap(reference).then(attrgetter("alpha")).then(bool)
|
||||||
|
|
||||||
|
cross_reference = Right(xref)
|
||||||
|
soft_mask_reference = cross_reference.bind(get_soft_mask_reference)
|
||||||
|
|
||||||
|
return any(
|
||||||
|
take_good_log_bad(reference.bind(check))
|
||||||
|
for reference, check in [
|
||||||
|
(soft_mask_reference, mask_exists),
|
||||||
|
(soft_mask_reference, image_has_alpha_channel),
|
||||||
|
(cross_reference, image_has_alpha_channel),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]:
|
def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]:
|
||||||
@ -279,7 +297,7 @@ def normalize_channels(array: np.ndarray):
|
|||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def load_image_handle_from_xref(doc: Doc, xref: int) -> Union[dict, None]: # TODO: Use monad.
|
def get_image_handle(doc: Doc, xref: int) -> Union[dict, None]:
|
||||||
return doc.extract_image(xref)
|
return doc.extract_image(xref)
|
||||||
|
|
||||||
|
|
||||||
@ -302,7 +320,7 @@ def tiny(metadatum: dict) -> bool:
|
|||||||
|
|
||||||
def clear_caches() -> None:
|
def clear_caches() -> None:
|
||||||
get_image_infos.cache_clear()
|
get_image_infos.cache_clear()
|
||||||
load_image_handle_from_xref.cache_clear()
|
get_image_handle.cache_clear()
|
||||||
eith_extract_image.cache_clear()
|
eith_extract_image.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,15 @@
|
|||||||
|
from functools import wraps
|
||||||
|
from inspect import signature
|
||||||
from itertools import starmap
|
from itertools import starmap
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
from funcy import iterate, first, curry, map
|
from funcy import iterate, first, curry, map
|
||||||
from pymonad.either import Left, Right
|
from pymonad.either import Left, Right, Either
|
||||||
|
from pymonad.tools import curry as pmcurry
|
||||||
|
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
def until(cond, func, *args, **kwargs):
|
def until(cond, func, *args, **kwargs):
|
||||||
@ -17,12 +25,60 @@ def starlift(fn):
|
|||||||
|
|
||||||
|
|
||||||
def bottom(*args, **kwargs):
|
def bottom(*args, **kwargs):
|
||||||
return None
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def top(*args, **kwargs):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def left(fn):
|
def left(fn):
|
||||||
return lambda x: Left(fn(x))
|
@wraps(fn)
|
||||||
|
def inner(x):
|
||||||
|
return Left(fn(x))
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
def right(fn):
|
def right(fn):
|
||||||
return lambda x: Right(fn(x))
|
@wraps(fn)
|
||||||
|
def inner(x):
|
||||||
|
return Right(fn(x))
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_left(fn, success_condition=top, error_message=None) -> Callable:
|
||||||
|
return wrap_either(Left, Right, success_condition=success_condition, error_message=error_message)(fn)
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_right(fn, success_condition=top, error_message=None) -> Callable:
|
||||||
|
return wrap_either(Right, Left, success_condition=success_condition, error_message=error_message)(fn)
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_either(success_type, failure_type, success_condition=top, error_message=None) -> Callable:
|
||||||
|
@wraps(wrap_either)
|
||||||
|
def wrapper(fn) -> Callable:
|
||||||
|
|
||||||
|
n_params = len(signature(fn).parameters)
|
||||||
|
|
||||||
|
@pmcurry(n_params)
|
||||||
|
@wraps(fn)
|
||||||
|
def wrapper(*args, **kwargs) -> Either:
|
||||||
|
try:
|
||||||
|
result = fn(*args, **kwargs)
|
||||||
|
if success_condition(result):
|
||||||
|
return success_type(result)
|
||||||
|
else:
|
||||||
|
return failure_type({"error": error_message, "result": result})
|
||||||
|
except Exception as err:
|
||||||
|
logger.error(err)
|
||||||
|
return failure_type({"error": error_message or err, "result": Void})
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class Void:
|
||||||
|
pass
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user