Refactoring

- Rename
- Refactor image extraction functions
This commit is contained in:
Matthias Bisping 2023-02-07 14:32:29 +01:00
parent 76b1b0ca24
commit e99e97e23f
2 changed files with 38 additions and 25 deletions

View File

@ -1,9 +1,9 @@
import atexit import atexit
from _operator import itemgetter from _operator import itemgetter
from functools import partial, lru_cache from functools import partial, lru_cache, singledispatch
from itertools import chain, filterfalse from itertools import chain, filterfalse
from operator import itemgetter from operator import itemgetter
from typing import List, Union, Iterable from typing import List, Union, Iterable, Any
import fitz import fitz
import numpy as np import numpy as np
@ -19,14 +19,15 @@ from image_prediction.info import Info
from image_prediction.stitching.stitching import stitch_pairs from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import validate_box from image_prediction.stitching.utils import validate_box
from image_prediction.utils import get_logger from image_prediction.utils import get_logger
from image_prediction.utils.generic import bottom, left, right from image_prediction.utils.generic import bottom, left, right, lift
logger = get_logger() logger = get_logger()
Doc = fitz.fitz.Document Doc = fitz.fitz.Document
Pag = fitz.fitz.Page Pge = fitz.fitz.Page
Img = Image.Image Img = Image.Image
Pxm = fitz.fitz.Pixmap
class ParsablePDFImageExtractor(ImageExtractor): class ParsablePDFImageExtractor(ImageExtractor):
@ -51,7 +52,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
yield from image_metadata_pairs yield from image_metadata_pairs
def __process_images_on_page(self, page: Pag): def __process_images_on_page(self, page: Pge):
metadata = extract_valid_metadata(self.doc, page) metadata = extract_valid_metadata(self.doc, page)
either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata) either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata)
@ -96,7 +97,7 @@ def validate_image(image: Img) -> Either:
return Left("Invalid image.") return Left("Invalid image.")
def extract_valid_metadata(doc: Doc, page: Pag) -> List[dict]: def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]:
return compose( return compose(
list, list,
partial(add_alpha_channel_info, doc), partial(add_alpha_channel_info, doc),
@ -106,8 +107,8 @@ def extract_valid_metadata(doc: Doc, page: Pag) -> List[dict]:
def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either: def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either:
image: Either = xref_to_image(doc, metadatum[Info.XREF]).bind(validate_image) image: Either = eith_extract_image(doc, metadatum[Info.XREF]).bind(validate_image)
image_metadata_pair: Either = make_eithered_image_metadata_pair(image, Right(metadatum)) image_metadata_pair: Either = eith_make_image_metadata_pair(image, Right(metadatum))
return image_metadata_pair return image_metadata_pair
@ -127,21 +128,24 @@ def filter_valid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]: def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]:
metadata = map(get_image_metadata, get_image_infos(page)) metadata = compose(
metadata = add_page_metadata(page, metadata) partial(add_page_metadata, page),
lift(get_image_metadata),
get_image_infos,
)(page)
yield from metadata yield from metadata
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def xref_to_image(doc: Doc, xref: int) -> Either: def eith_extract_image(doc: Doc, xref: int) -> Either:
try: try:
return Right(extract_image(doc, xref)) return Right(extract_image(doc, xref))
except BadXref: except BadXref:
return Left("Bad xref.") return Left("Bad xref.")
def make_eithered_image_metadata_pair(image: Either, metadata: Either) -> Either: def eith_make_image_metadata_pair(image: Either, metadata: Either) -> Either:
"""Reference: haskell.org/tutorial/monads.html""" """Reference: haskell.org/tutorial/monads.html"""
def context(value): def context(value):
@ -180,18 +184,25 @@ def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair:
return ImageMetadataPair(image, metadatum) return ImageMetadataPair(image, metadatum)
def extract_image(doc: Doc, xref: int) -> Img: @singledispatch
def extract_image(doc: Doc, xref: int) -> Any:
return compose(pixmap_to_image, extract_pixmap)(doc, xref)
def pixmap_to_image(pixmap: Pxm) -> Img:
array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape((pixmap.h, pixmap.w, pixmap.n))
array = normalize_channels(array)
return Image.fromarray(array)
def extract_pixmap(doc: Doc, xref: int) -> Pxm:
try: try:
pixmap = fitz.Pixmap(doc, xref) return fitz.Pixmap(doc, xref)
except ValueError as err: except ValueError as err:
msg = f"Xref {xref} is invalid, skipping extraction." msg = f"Xref {xref} is invalid, skipping extraction."
logger.debug(msg) logger.debug(msg)
raise BadXref(msg) from err raise BadXref(msg) from err
array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape((pixmap.h, pixmap.w, pixmap.n))
array = normalize_channels(array)
return Image.fromarray(array)
def has_alpha_channel(doc: Doc, xref: int): def has_alpha_channel(doc: Doc, xref: int):
@ -199,7 +210,9 @@ def has_alpha_channel(doc: Doc, xref: int):
maybe_smask = maybe_image["smask"] if maybe_image else None maybe_smask = maybe_image["smask"] if maybe_image else None
if maybe_smask: # TODO: Use monad. if maybe_smask: # TODO: Use monad.
return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)]) return any(
[load_image_handle_from_xref(doc, maybe_smask) is not None, bool(extract_pixmap(doc, maybe_smask).alpha)]
)
else: else:
try: try:
return bool(fitz.Pixmap(doc, xref).alpha) return bool(fitz.Pixmap(doc, xref).alpha)
@ -242,11 +255,11 @@ def get_image_metadata(image_info: dict) -> dict:
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_image_infos(page: Pag) -> List[dict]: def get_image_infos(page: Pge) -> List[dict]:
return page.get_image_info(xrefs=True) return page.get_image_info(xrefs=True)
def add_page_metadata(page: Pag, metadata: Iterable[dict]) -> Iterable[dict]: def add_page_metadata(page: Pge, metadata: Iterable[dict]) -> Iterable[dict]:
yield from map(partial(merge, get_page_metadata(page)), metadata) yield from map(partial(merge, get_page_metadata(page)), metadata)
@ -270,7 +283,7 @@ def load_image_handle_from_xref(doc: Doc, xref: int) -> Union[dict, None]: # TO
return doc.extract_image(xref) return doc.extract_image(xref)
def get_page_metadata(page: Pag) -> dict: def get_page_metadata(page: Pge) -> dict:
page_width, page_height = map(rounder, page.mediabox_size) page_width, page_height = map(rounder, page.mediabox_size)
return { return {
@ -290,7 +303,7 @@ def tiny(metadatum: dict) -> bool:
def clear_caches() -> None: def clear_caches() -> None:
get_image_infos.cache_clear() get_image_infos.cache_clear()
load_image_handle_from_xref.cache_clear() load_image_handle_from_xref.cache_clear()
xref_to_image.cache_clear() eith_extract_image.cache_clear()
atexit.register(clear_caches) atexit.register(clear_caches)

View File

@ -18,7 +18,7 @@ from image_prediction.image_extractor.extractors.parsable import (
get_image_infos, get_image_infos,
ParsablePDFImageExtractor, ParsablePDFImageExtractor,
extract_valid_metadata, extract_valid_metadata,
xref_to_image, eith_extract_image,
extract_image, extract_image,
) )
from image_prediction.info import Info from image_prediction.info import Info
@ -98,4 +98,4 @@ def test_bad_xref_handling(bad_xref_pdf, dvc_test_data):
with pytest.raises(BadXref): with pytest.raises(BadXref):
extract_image(doc, xref) extract_image(doc, xref)
assert xref_to_image(doc, xref).is_left() assert eith_extract_image(doc, xref).is_left()