diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index 31e8662..59ac8a5 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -1,32 +1,29 @@ import atexit +import json +import traceback +from _operator import itemgetter from functools import partial, lru_cache -from itertools import chain, filterfalse -from operator import itemgetter, attrgetter -from typing import List, Union, Iterable, Any +from itertools import chain, starmap, filterfalse +from operator import itemgetter, truth +from typing import Iterable, Iterator, List, Union import fitz import numpy as np from PIL import Image -from funcy import merge, compose, rcompose, keep, lkeep, notnone, iffy, rpartial -from pymonad.either import Right, Left, Either -from pymonad.tools import curry, identity +from funcy import merge, pluck, compose, rcompose, remove, keep +from image_prediction.config import CONFIG from image_prediction.exceptions import InvalidBox, BadXref from image_prediction.formatter.formatters.enum import EnumFormatter from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair from image_prediction.info import Info from image_prediction.stitching.stitching import stitch_pairs from image_prediction.stitching.utils import validate_box +from image_prediction.transformer.transformers.response import compute_geometric_quotient from image_prediction.utils import get_logger -from image_prediction.utils.generic import bottom, left, right, lift, wrap_right logger = get_logger() -Doc = fitz.fitz.Document -Pge = fitz.fitz.Page -Img = Image.Image -Pxm = fitz.fitz.Pixmap - class ParsablePDFImageExtractor(ImageExtractor): def __init__(self, verbose=False, tolerance=0): @@ -37,7 +34,7 @@ class ParsablePDFImageExtractor(ImageExtractor): tolerance: The tolerance in pixels for the distance between images, beyond which they will not be stitched together """ - self.doc: Union[Doc, None] = None + self.doc: fitz.fitz.Document = None self.verbose = verbose self.tolerance = tolerance @@ -50,41 +47,51 @@ class ParsablePDFImageExtractor(ImageExtractor): yield from image_metadata_pairs - def __process_images_on_page(self, page: Pge): + def __process_images_on_page(self, page: fitz.fitz.Page): metadata = extract_valid_metadata(self.doc, page) - - either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata) - valid_image_metadata_pairs = lkeep( - rpartial(take_good_log_bad, format_context), either_image_metadata_pair_or_error_per_image - ) - valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance) + images = get_images_on_page(self.doc, metadata) clear_caches() - yield from valid_image_metadata_pairs_stitched + image_metadata_pairs = starmap(ImageMetadataPair, filter(all, zip(images, metadata))) + # TODO: In the future, consider to introduce an image validator as a pipeline component rather than doing the + # validation here. Invalid images can then be split into a different stream and joined with the intact images + # again for the formatting step. + image_metadata_pairs = self.__filter_valid_images(image_metadata_pairs) + image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance) - def __metadatum_to_image_metadata_pair(self, metadatum: dict) -> Either: - return metadatum_to_image_metadata_pair(self.doc, metadatum) + yield from image_metadata_pairs + + @staticmethod + def __filter_valid_images(image_metadata_pairs: Iterable[ImageMetadataPair]) -> Iterator[ImageMetadataPair]: + def validate(image: Image.Image, metadata: dict): + try: + # TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148) + image.resize((100, 100)).convert("RGB") + return ImageMetadataPair(image, metadata) + except (OSError, Exception) as err: + metadata = json.dumps(EnumFormatter()(metadata), indent=2) + logger.warning(f"Invalid image encountered. Image metadata:\n{metadata}\n\n{traceback.format_exc()}") + return None + + return filter(truth, starmap(validate, image_metadata_pairs)) -def extract_pages(doc: Doc, page_range: range) -> Iterable[Pge]: +def extract_pages(doc, page_range): page_range = range(page_range.start + 1, page_range.stop + 1) pages = map(doc.load_page, page_range) yield from pages -def validate_image(image: Img) -> Either: - try: - # TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148) - image.resize((100, 100)).convert("RGB") - return Right(image) - except (OSError, Exception): - logger.warning(f"Invalid image encountered.") - return Left("Invalid image.") +def get_images_on_page(doc, metadata): + xrefs = pluck(Info.XREF, metadata) + images = map(partial(xref_to_image, doc), xrefs) + + yield from images -def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]: +def extract_valid_metadata(doc: fitz.fitz.Document, page: fitz.fitz.Page): return compose( list, partial(add_alpha_channel_info, doc), @@ -93,159 +100,25 @@ def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]: )(page) -def take_good_log_bad(item: Either, log_formatter=identity) -> Any: - return item.either(rpartial(log_error_context, log_formatter), identity) - - -def format_context(context: dict) -> str: - return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}" - - -def log_error_context(context: dict, formatter=identity) -> None: - logger.warning(f"Skipping bad image. {formatter(context)}") - return None - - -def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either: - image: Either = eith_extract_image(doc, metadatum[Info.XREF]).bind(validate_image) - image_metadata_pair: Either = eith_make_image_metadata_pair(image, Right(metadatum)) - return image_metadata_pair - - -def add_alpha_channel_info(doc: Doc, metadata: Iterable[dict]) -> Iterable[dict]: - def add_alpha_value_to_metadatum(metadatum: dict) -> dict: - alpha = metadatum_to_alpha_value(metadatum) - return {**metadatum, Info.ALPHA: alpha} - - xref_to_alpha = partial(has_alpha_channel, doc) - metadatum_to_alpha_value = compose(xref_to_alpha, itemgetter(Info.XREF)) - - yield from map(add_alpha_value_to_metadatum, metadata) - - -def filter_valid_metadata(metadata: Iterable[dict]) -> Iterable[dict]: - yield from compose(filter_out_tiny_images, filter_out_invalid_metadata)(metadata) - - -def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]: - metadata = compose( - partial(add_page_metadata, page), - lift(get_image_metadata), - get_image_infos, - )(page) +def get_metadata_for_images_on_page(page: fitz.Page): + metadata = map(get_image_metadata, get_image_infos(page)) + metadata = add_page_metadata(page, metadata) yield from metadata -@lru_cache(maxsize=None) -@curry(2) -def eith_extract_image(doc: Doc, xref: int) -> Either: - try: - return Right(extract_image(doc, xref)) - except BadXref: - return Left("Bad xref.") +def filter_valid_metadata(metadata): + yield from compose( + # TODO: Disabled for now, since atm since the backend needs atm the metadata and the hash of every image, even + # scanned pages. In the future, this should be resolved differently, e.g. by filtering all page-sized images + # and giving the user the ability to reclassify false positives with a separate call. + # filter_out_page_sized_images, + filter_out_tiny_images, + filter_out_invalid_metadata, + )(metadata) -def eith_make_image_metadata_pair(image: Either, metadata: Either) -> Either: - """Reference: haskell.org/tutorial/monads.html""" - - def context(value): - return {"reason": value, "metadata": metadata.either(bottom, identity)} - - # Explicitly we are doing the following. (1) and (2) are equivalent. - - # a := Image - # b := Metadata - # c := ImageMetadataPair - # m := Either monad - - # fmt: off - # 1) - # pair: Either = ( - # Right(make_image_metadata_pair) # m (a -> b -> c) - # .amap(image) # m (a -> b -> c) <*> m a = m (b -> c) - # .amap(metadata) # m (b -> c) <*> m b = m c - # ) - - # 2) - # pair: Either = ( - # image.bind(right(make_image_metadata_pair)) # m a >>= m (a -> b -> c) = m (b -> c) - # .amap(metadata) # m (b -> c) <*> m b = m c - # ) - # fmt: on - - # Syntactic sugar variant with details hidden - pair: Either = Either.apply(make_image_metadata_pair).to_arguments(image, metadata) - - return pair.either(left(context), right(identity)) - - -@curry(2) -def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair: - return ImageMetadataPair(image, metadatum) - - -def extract_image(doc: Doc, xref: int) -> Any: - return compose(pixmap_to_image, extract_pixmap)(doc, xref) - - -def pixmap_to_image(pixmap: Pxm) -> Img: - array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape((pixmap.h, pixmap.w, pixmap.n)) - array = normalize_channels(array) - return Image.fromarray(array) - - -def extract_pixmap(doc: Doc, xref: int) -> Pxm: - try: - return fitz.Pixmap(doc, xref) - except ValueError as err: - msg = f"Cross reference {xref} is invalid, skipping extraction." - logger.error(err) - logger.debug(msg) - raise BadXref(msg) from err - - -def has_alpha_channel(doc: Doc, xref: int) -> bool: - - _get_image_handle = wrap_right(get_image_handle, success_condition=notnone)(doc) - _extract_pixmap = wrap_right(extract_pixmap)(doc) - - def get_soft_mask_reference(cross_reference: int) -> Either: - def error(value) -> str: - return f"Invalid soft mask {value} for cross reference {cross_reference}." - - logger.debug(f"Getting soft mask handle for cross reference {cross_reference}.") - pass_on_if_not_none = iffy(notnone, right(identity), left(error)) - return _get_image_handle(cross_reference).then(itemgetter("smask")).either(left(identity), pass_on_if_not_none) - - def mask_exists(soft_mask_reference: int) -> Either: - logger.debug(f"Checking if soft mask exists for soft mask reference {soft_mask_reference}.") - return _get_image_handle(soft_mask_reference).then(notnone) - - def image_has_alpha_channel(reference: int) -> Either: - logger.debug(f"Checking if image with reference {reference} has alpha channel.") - return _extract_pixmap(reference).then(attrgetter("alpha")).then(bool) - - logger.debug(f"Checking if image with cross reference {xref} has alpha channel.") - - cross_reference = Right(xref) - soft_mask_reference = cross_reference.bind(get_soft_mask_reference) - - return any( - take_good_log_bad(reference.bind(check)) - for reference, check in [ - (soft_mask_reference, mask_exists), - (soft_mask_reference, image_has_alpha_channel), - (cross_reference, image_has_alpha_channel), - ] - ) - - -def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]: - yield from filterfalse(tiny, metadata) - - -def filter_out_invalid_metadata(metadata: Iterable[dict]) -> Iterable[dict]: +def filter_out_invalid_metadata(metadata): def __validate_box(box): try: return validate_box(box) @@ -255,7 +128,50 @@ def filter_out_invalid_metadata(metadata: Iterable[dict]) -> Iterable[dict]: yield from keep(__validate_box, metadata) -def get_image_metadata(image_info: dict) -> dict: +def filter_out_page_sized_images(metadata): + yield from remove(breaches_image_to_page_quotient, metadata) + + +def filter_out_tiny_images(metadata): + yield from filterfalse(tiny, metadata) + + +@lru_cache(maxsize=None) +def get_image_infos(page: fitz.Page) -> List[dict]: + return page.get_image_info(xrefs=True) + + +@lru_cache(maxsize=None) +def xref_to_image(doc, xref) -> Union[Image.Image, None]: + # NOTE: image extraction is done via pixmap to array, as this method is twice as fast as extraction via bytestream + try: + pixmap = fitz.Pixmap(doc, xref) + array = convert_pixmap_to_array(pixmap) + return Image.fromarray(array) + except ValueError: + logger.debug(f"Xref {xref} is invalid, skipping extraction ...") + return + + +def convert_pixmap_to_array(pixmap: fitz.fitz.Pixmap): + array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape(pixmap.h, pixmap.w, pixmap.n) + array = _normalize_channels(array) + return array + + +def _normalize_channels(array: np.ndarray): + if array.shape[-1] == 1: + array = array[:, :, 0] + elif array.shape[-1] == 4: + array = array[..., :3] + elif array.shape[-1] != 3: + logger.warning(f"Unexpected image format: {array.shape}.") + raise ValueError(f"Unexpected image format: {array.shape}.") + + return array + + +def get_image_metadata(image_info): xref, coords = itemgetter("xref", "bbox")(image_info) x1, y1, x2, y2 = map(rounder, coords) @@ -274,36 +190,30 @@ def get_image_metadata(image_info: dict) -> dict: } -@lru_cache(maxsize=None) -def get_image_infos(page: Pge) -> List[dict]: - return page.get_image_info(xrefs=True) - - -def add_page_metadata(page: Pge, metadata: Iterable[dict]) -> Iterable[dict]: +def add_page_metadata(page, metadata): yield from map(partial(merge, get_page_metadata(page)), metadata) -def normalize_channels(array: np.ndarray): - if not array.ndim == 3: - array = np.expand_dims(array, axis=-1) +def add_alpha_channel_info(doc, metadata): + def add_alpha_value_to_metadatum(metadatum): + alpha = metadatum_to_alpha_value(metadatum) + return {**metadatum, Info.ALPHA: alpha} - if array.shape[-1] == 4: - array = array[..., :3] - elif array.shape[-1] == 1: - array = np.concatenate([array, array, array], axis=-1) - elif array.shape[-1] != 3: - logger.warning(f"Unexpected image format: {array.shape}.") - raise ValueError(f"Unexpected image format: {array.shape}.") + xref_to_alpha = partial(has_alpha_channel, doc) + metadatum_to_alpha_value = compose(xref_to_alpha, itemgetter(Info.XREF)) - return array + yield from map(add_alpha_value_to_metadatum, metadata) @lru_cache(maxsize=None) -def get_image_handle(doc: Doc, xref: int) -> Union[dict, None]: +def load_image_handle_from_xref(doc, xref): return doc.extract_image(xref) -def get_page_metadata(page: Pge) -> dict: +rounder = rcompose(round, int) + + +def get_page_metadata(page): page_width, page_height = map(rounder, page.mediabox_size) return { @@ -313,17 +223,38 @@ def get_page_metadata(page: Pge) -> dict: } -rounder = rcompose(round, int) +def has_alpha_channel(doc, xref): + + maybe_image = load_image_handle_from_xref(doc, xref) + maybe_smask = maybe_image["smask"] if maybe_image else None + + if maybe_smask: + return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)]) + else: + try: + return bool(fitz.Pixmap(doc, xref).alpha) + except ValueError: + logger.debug(f"Encountered invalid xref `{xref}` in {doc.metadata.get('title', '')}.") + return False -def tiny(metadatum: dict) -> bool: - return metadatum[Info.WIDTH] * metadatum[Info.HEIGHT] <= 4 +def tiny(metadata): + return metadata[Info.WIDTH] * metadata[Info.HEIGHT] <= 4 -def clear_caches() -> None: +def clear_caches(): get_image_infos.cache_clear() - get_image_handle.cache_clear() - eith_extract_image.cache_clear() + load_image_handle_from_xref.cache_clear() + xref_to_image.cache_clear() atexit.register(clear_caches) + + +def breaches_image_to_page_quotient(metadatum): + page_width, page_height, x1, x2, y1, y2, width, height = itemgetter( + Info.PAGE_WIDTH, Info.PAGE_HEIGHT, Info.X1, Info.X2, Info.Y1, Info.Y2, Info.WIDTH, Info.HEIGHT + )(metadatum) + geometric_quotient = compute_geometric_quotient(page_width, page_height, x2, x1, y2, y1) + quotient_breached = bool(geometric_quotient > CONFIG.filters.image_to_page_quotient.max) + return quotient_breached diff --git a/test/unit_tests/image_extractor_test.py b/test/unit_tests/image_extractor_test.py index 1c34103..92a705a 100644 --- a/test/unit_tests/image_extractor_test.py +++ b/test/unit_tests/image_extractor_test.py @@ -1,4 +1,3 @@ -import json import random from operator import itemgetter @@ -8,23 +7,18 @@ import pytest from PIL import Image from funcy import first, rest -from image_prediction.exceptions import BadXref from image_prediction.extraction import extract_images_from_pdf -from image_prediction.formatter.formatters.enum import EnumFormatter from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractors.parsable import ( extract_pages, has_alpha_channel, get_image_infos, - ParsablePDFImageExtractor, extract_valid_metadata, - eith_extract_image, - extract_image, + xref_to_image, ) from image_prediction.info import Info -from image_prediction.locations import TEST_DATA_DIR from test.utils.comparison import metadata_equal, image_sets_equal -from test.utils.generation.pdf import add_image, pdf_stream, stream_pdf_bytes +from test.utils.generation.pdf import add_image, pdf_stream @pytest.mark.parametrize("extractor_type", ["mock"]) @@ -95,7 +89,4 @@ def test_bad_xref_handling(bad_xref_pdf, dvc_test_data): metadata = extract_valid_metadata(doc, first(doc)) xref = first(metadata)[Info.XREF] - with pytest.raises(BadXref): - extract_image(doc, xref) - - assert eith_extract_image(doc, xref).is_left() + assert not xref_to_image(doc, xref)