diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index 7ce3f6a..d3a5adf 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -1,9 +1,9 @@ import atexit from _operator import itemgetter -from functools import partial, lru_cache +from functools import partial, lru_cache, singledispatch from itertools import chain, filterfalse from operator import itemgetter -from typing import List, Union, Iterable +from typing import List, Union, Iterable, Any import fitz import numpy as np @@ -19,14 +19,15 @@ from image_prediction.info import Info from image_prediction.stitching.stitching import stitch_pairs from image_prediction.stitching.utils import validate_box from image_prediction.utils import get_logger -from image_prediction.utils.generic import bottom, left, right +from image_prediction.utils.generic import bottom, left, right, lift logger = get_logger() Doc = fitz.fitz.Document -Pag = fitz.fitz.Page +Pge = fitz.fitz.Page Img = Image.Image +Pxm = fitz.fitz.Pixmap class ParsablePDFImageExtractor(ImageExtractor): @@ -51,7 +52,7 @@ class ParsablePDFImageExtractor(ImageExtractor): yield from image_metadata_pairs - def __process_images_on_page(self, page: Pag): + def __process_images_on_page(self, page: Pge): metadata = extract_valid_metadata(self.doc, page) either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata) @@ -96,7 +97,7 @@ def validate_image(image: Img) -> Either: return Left("Invalid image.") -def extract_valid_metadata(doc: Doc, page: Pag) -> List[dict]: +def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]: return compose( list, partial(add_alpha_channel_info, doc), @@ -106,8 +107,8 @@ def extract_valid_metadata(doc: Doc, page: Pag) -> List[dict]: def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either: - image: Either = xref_to_image(doc, metadatum[Info.XREF]).bind(validate_image) - image_metadata_pair: Either = make_eithered_image_metadata_pair(image, Right(metadatum)) + image: Either = eith_extract_image(doc, metadatum[Info.XREF]).bind(validate_image) + image_metadata_pair: Either = eith_make_image_metadata_pair(image, Right(metadatum)) return image_metadata_pair @@ -127,21 +128,24 @@ def filter_valid_metadata(metadata: Iterable[dict]) -> Iterable[dict]: def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]: - metadata = map(get_image_metadata, get_image_infos(page)) - metadata = add_page_metadata(page, metadata) + metadata = compose( + partial(add_page_metadata, page), + lift(get_image_metadata), + get_image_infos, + )(page) yield from metadata @lru_cache(maxsize=None) -def xref_to_image(doc: Doc, xref: int) -> Either: +def eith_extract_image(doc: Doc, xref: int) -> Either: try: return Right(extract_image(doc, xref)) except BadXref: return Left("Bad xref.") -def make_eithered_image_metadata_pair(image: Either, metadata: Either) -> Either: +def eith_make_image_metadata_pair(image: Either, metadata: Either) -> Either: """Reference: haskell.org/tutorial/monads.html""" def context(value): @@ -180,18 +184,25 @@ def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair: return ImageMetadataPair(image, metadatum) -def extract_image(doc: Doc, xref: int) -> Img: +@singledispatch +def extract_image(doc: Doc, xref: int) -> Any: + return compose(pixmap_to_image, extract_pixmap)(doc, xref) + + +def pixmap_to_image(pixmap: Pxm) -> Img: + array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape((pixmap.h, pixmap.w, pixmap.n)) + array = normalize_channels(array) + return Image.fromarray(array) + + +def extract_pixmap(doc: Doc, xref: int) -> Pxm: try: - pixmap = fitz.Pixmap(doc, xref) + return fitz.Pixmap(doc, xref) except ValueError as err: msg = f"Xref {xref} is invalid, skipping extraction." logger.debug(msg) raise BadXref(msg) from err - array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape((pixmap.h, pixmap.w, pixmap.n)) - array = normalize_channels(array) - return Image.fromarray(array) - def has_alpha_channel(doc: Doc, xref: int): @@ -199,7 +210,9 @@ def has_alpha_channel(doc: Doc, xref: int): maybe_smask = maybe_image["smask"] if maybe_image else None if maybe_smask: # TODO: Use monad. - return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)]) + return any( + [load_image_handle_from_xref(doc, maybe_smask) is not None, bool(extract_pixmap(doc, maybe_smask).alpha)] + ) else: try: return bool(fitz.Pixmap(doc, xref).alpha) @@ -242,11 +255,11 @@ def get_image_metadata(image_info: dict) -> dict: @lru_cache(maxsize=None) -def get_image_infos(page: Pag) -> List[dict]: +def get_image_infos(page: Pge) -> List[dict]: return page.get_image_info(xrefs=True) -def add_page_metadata(page: Pag, metadata: Iterable[dict]) -> Iterable[dict]: +def add_page_metadata(page: Pge, metadata: Iterable[dict]) -> Iterable[dict]: yield from map(partial(merge, get_page_metadata(page)), metadata) @@ -270,7 +283,7 @@ def load_image_handle_from_xref(doc: Doc, xref: int) -> Union[dict, None]: # TO return doc.extract_image(xref) -def get_page_metadata(page: Pag) -> dict: +def get_page_metadata(page: Pge) -> dict: page_width, page_height = map(rounder, page.mediabox_size) return { @@ -290,7 +303,7 @@ def tiny(metadatum: dict) -> bool: def clear_caches() -> None: get_image_infos.cache_clear() load_image_handle_from_xref.cache_clear() - xref_to_image.cache_clear() + eith_extract_image.cache_clear() atexit.register(clear_caches) diff --git a/test/unit_tests/image_extractor_test.py b/test/unit_tests/image_extractor_test.py index b27744b..1c34103 100644 --- a/test/unit_tests/image_extractor_test.py +++ b/test/unit_tests/image_extractor_test.py @@ -18,7 +18,7 @@ from image_prediction.image_extractor.extractors.parsable import ( get_image_infos, ParsablePDFImageExtractor, extract_valid_metadata, - xref_to_image, + eith_extract_image, extract_image, ) from image_prediction.info import Info @@ -98,4 +98,4 @@ def test_bad_xref_handling(bad_xref_pdf, dvc_test_data): with pytest.raises(BadXref): extract_image(doc, xref) - assert xref_to_image(doc, xref).is_left() + assert eith_extract_image(doc, xref).is_left()