Refactoring

- Rename
- Add typehints everywhere
This commit is contained in:
Matthias Bisping 2023-02-07 10:11:35 +01:00
parent 865e0819a1
commit 3cea4dad2d
2 changed files with 53 additions and 42 deletions

View File

@ -3,12 +3,12 @@ from _operator import itemgetter
from functools import partial, lru_cache from functools import partial, lru_cache
from itertools import chain, filterfalse from itertools import chain, filterfalse
from operator import itemgetter from operator import itemgetter
from typing import List, Union from typing import List, Union, Iterable
import fitz import fitz
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from funcy import merge, compose, rcompose, keep from funcy import merge, compose, rcompose, keep, lkeep
from pymonad.either import Right, Left, Either from pymonad.either import Right, Left, Either
from pymonad.tools import curry, identity from pymonad.tools import curry, identity
@ -24,6 +24,11 @@ from image_prediction.utils.generic import bottom, left, right
logger = get_logger() logger = get_logger()
Doc = fitz.fitz.Document
Pag = fitz.fitz.Page
Img = Image.Image
class ParsablePDFImageExtractor(ImageExtractor): class ParsablePDFImageExtractor(ImageExtractor):
def __init__(self, verbose=False, tolerance=0): def __init__(self, verbose=False, tolerance=0):
""" """
@ -33,7 +38,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
tolerance: The tolerance in pixels for the distance between images, beyond which they will not be stitched tolerance: The tolerance in pixels for the distance between images, beyond which they will not be stitched
together together
""" """
self.doc: Union[fitz.fitz.Document, None] = None self.doc: Union[Doc, None] = None
self.verbose = verbose self.verbose = verbose
self.tolerance = tolerance self.tolerance = tolerance
@ -46,36 +51,42 @@ class ParsablePDFImageExtractor(ImageExtractor):
yield from image_metadata_pairs yield from image_metadata_pairs
def __process_images_on_page(self, page: fitz.fitz.Page): def __process_images_on_page(self, page: Pag):
metadata = extract_valid_metadata(self.doc, page) metadata = extract_valid_metadata(self.doc, page)
maybe_image_metadata_pairs = map(partial(metadatum_to_image_metadata_pair, self.doc), metadata) either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata)
image_metadata_pairs = keep(take_right, maybe_image_metadata_pairs) valid_image_metadata_pairs = lkeep(take_good_log_bad, either_image_metadata_pair_or_error_per_image)
valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance)
clear_caches() clear_caches()
image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance) yield from valid_image_metadata_pairs_stitched
yield from image_metadata_pairs def __metadatum_to_image_metadata_pair(self, metadatum: dict) -> Either:
return metadatum_to_image_metadata_pair(self.doc, metadatum)
def take_right(pair: Either): def take_good_log_bad(pair: Either) -> Union[ImageMetadataPair, None]:
if pair.is_right(): return pair.either(log_error_context, identity)
return pair.either(bottom, identity)
logger.warning(f"Skipping bad image. {pair.either(format_context, bottom)}")
def format_context(context): def log_error_context(context: dict) -> None:
logger.warning(f"Skipping bad image. {format_context(context)}")
return None
def format_context(context: dict) -> str:
return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}" return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}"
def extract_pages(doc, page_range): def extract_pages(doc: Doc, page_range: range):
page_range = range(page_range.start + 1, page_range.stop + 1) page_range = range(page_range.start + 1, page_range.stop + 1)
pages = map(doc.load_page, page_range) pages = map(doc.load_page, page_range)
yield from pages yield from pages
def validate_image(image: Image.Image) -> Either: def validate_image(image: Img) -> Either:
try: try:
# TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148) # TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148)
image.resize((100, 100)).convert("RGB") image.resize((100, 100)).convert("RGB")
@ -85,7 +96,7 @@ def validate_image(image: Image.Image) -> Either:
return Left("Invalid image.") return Left("Invalid image.")
def extract_valid_metadata(doc: fitz.fitz.Document, page: fitz.fitz.Page): def extract_valid_metadata(doc: Doc, page: Pag) -> List[dict]:
return compose( return compose(
list, list,
partial(add_alpha_channel_info, doc), partial(add_alpha_channel_info, doc),
@ -94,14 +105,14 @@ def extract_valid_metadata(doc: fitz.fitz.Document, page: fitz.fitz.Page):
)(page) )(page)
def metadatum_to_image_metadata_pair(doc, metadatum: dict) -> Either: def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either:
maybe_image = xref_to_maybe_image(doc, metadatum[Info.XREF]).bind(validate_image) image: Either = xref_to_image(doc, metadatum[Info.XREF]).bind(validate_image)
maybe_image_metadata_pair = make_maybe_image_metadata_pair(maybe_image, Right(metadatum)) image_metadata_pair: Either = make_eithered_image_metadata_pair(image, Right(metadatum))
return maybe_image_metadata_pair return image_metadata_pair
def add_alpha_channel_info(doc, metadata): def add_alpha_channel_info(doc: Doc, metadata: Iterable[dict]) -> Iterable[dict]:
def add_alpha_value_to_metadatum(metadatum): def add_alpha_value_to_metadatum(metadatum: dict) -> dict:
alpha = metadatum_to_alpha_value(metadatum) alpha = metadatum_to_alpha_value(metadatum)
return {**metadatum, Info.ALPHA: alpha} return {**metadatum, Info.ALPHA: alpha}
@ -111,11 +122,11 @@ def add_alpha_channel_info(doc, metadata):
yield from map(add_alpha_value_to_metadatum, metadata) yield from map(add_alpha_value_to_metadatum, metadata)
def filter_valid_metadata(metadata): def filter_valid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
yield from compose(filter_out_tiny_images, filter_out_invalid_metadata)(metadata) yield from compose(filter_out_tiny_images, filter_out_invalid_metadata)(metadata)
def get_metadata_for_images_on_page(page: fitz.Page): def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]:
metadata = map(get_image_metadata, get_image_infos(page)) metadata = map(get_image_metadata, get_image_infos(page))
metadata = add_page_metadata(page, metadata) metadata = add_page_metadata(page, metadata)
@ -123,14 +134,14 @@ def get_metadata_for_images_on_page(page: fitz.Page):
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def xref_to_maybe_image(doc, xref) -> Either: def xref_to_image(doc: Doc, xref: int) -> Either:
try: try:
return Right(extract_image(doc, xref)) return Right(extract_image(doc, xref))
except BadXref: except BadXref:
return Left("Bad xref.") return Left("Bad xref.")
def make_maybe_image_metadata_pair(image: Either, metadata: Either): def make_eithered_image_metadata_pair(image: Either, metadata: Either) -> Either:
"""Reference: haskell.org/tutorial/monads.html""" """Reference: haskell.org/tutorial/monads.html"""
def context(value): def context(value):
@ -160,11 +171,11 @@ def make_maybe_image_metadata_pair(image: Either, metadata: Either):
@curry(2) @curry(2)
def make_image_metadata_pair(image: Image.Image, metadatum: dict) -> ImageMetadataPair: def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair:
return ImageMetadataPair(image, metadatum) return ImageMetadataPair(image, metadatum)
def extract_image(doc, xref) -> Image.Image: def extract_image(doc: Doc, xref: int) -> Img:
try: try:
pixmap = fitz.Pixmap(doc, xref) pixmap = fitz.Pixmap(doc, xref)
except ValueError as err: except ValueError as err:
@ -177,12 +188,12 @@ def extract_image(doc, xref) -> Image.Image:
return Image.fromarray(array) return Image.fromarray(array)
def has_alpha_channel(doc, xref): def has_alpha_channel(doc: Doc, xref: int):
maybe_image = load_image_handle_from_xref(doc, xref) maybe_image = load_image_handle_from_xref(doc, xref)
maybe_smask = maybe_image["smask"] if maybe_image else None maybe_smask = maybe_image["smask"] if maybe_image else None
if maybe_smask: if maybe_smask: # Use monad
return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)]) return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)])
else: else:
try: try:
@ -192,11 +203,11 @@ def has_alpha_channel(doc, xref):
return False return False
def filter_out_tiny_images(metadata): def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]:
yield from filterfalse(tiny, metadata) yield from filterfalse(tiny, metadata)
def filter_out_invalid_metadata(metadata): def filter_out_invalid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
def __validate_box(box): def __validate_box(box):
try: try:
return validate_box(box) return validate_box(box)
@ -206,7 +217,7 @@ def filter_out_invalid_metadata(metadata):
yield from keep(__validate_box, metadata) yield from keep(__validate_box, metadata)
def get_image_metadata(image_info): def get_image_metadata(image_info: dict) -> dict:
xref, coords = itemgetter("xref", "bbox")(image_info) xref, coords = itemgetter("xref", "bbox")(image_info)
x1, y1, x2, y2 = map(rounder, coords) x1, y1, x2, y2 = map(rounder, coords)
@ -226,11 +237,11 @@ def get_image_metadata(image_info):
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_image_infos(page: fitz.Page) -> List[dict]: def get_image_infos(page: Pag) -> List[dict]:
return page.get_image_info(xrefs=True) return page.get_image_info(xrefs=True)
def add_page_metadata(page, metadata): def add_page_metadata(page: Pag, metadata: Iterable[dict]) -> Iterable[dict]:
yield from map(partial(merge, get_page_metadata(page)), metadata) yield from map(partial(merge, get_page_metadata(page)), metadata)
@ -250,11 +261,11 @@ def normalize_channels(array: np.ndarray):
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref): def load_image_handle_from_xref(doc: Doc, xref: int) -> Union[dict, None]: # TODO: use Monad
return doc.extract_image(xref) return doc.extract_image(xref)
def get_page_metadata(page): def get_page_metadata(page: Pag) -> dict:
page_width, page_height = map(rounder, page.mediabox_size) page_width, page_height = map(rounder, page.mediabox_size)
return { return {
@ -267,14 +278,14 @@ def get_page_metadata(page):
rounder = rcompose(round, int) rounder = rcompose(round, int)
def tiny(metadatum): def tiny(metadatum: dict) -> bool:
return metadatum[Info.WIDTH] * metadatum[Info.HEIGHT] <= 4 return metadatum[Info.WIDTH] * metadatum[Info.HEIGHT] <= 4
def clear_caches(): def clear_caches() -> None:
get_image_infos.cache_clear() get_image_infos.cache_clear()
load_image_handle_from_xref.cache_clear() load_image_handle_from_xref.cache_clear()
xref_to_maybe_image.cache_clear() xref_to_image.cache_clear()
atexit.register(clear_caches) atexit.register(clear_caches)

View File

@ -18,7 +18,7 @@ from image_prediction.image_extractor.extractors.parsable import (
get_image_infos, get_image_infos,
ParsablePDFImageExtractor, ParsablePDFImageExtractor,
extract_valid_metadata, extract_valid_metadata,
xref_to_maybe_image, xref_to_image,
extract_image, extract_image,
) )
from image_prediction.info import Info from image_prediction.info import Info
@ -98,4 +98,4 @@ def test_bad_xref_handling(bad_xref_pdf, dvc_test_data):
with pytest.raises(BadXref): with pytest.raises(BadXref):
extract_image(doc, xref) extract_image(doc, xref)
assert xref_to_maybe_image(doc, xref).is_left() assert xref_to_image(doc, xref).is_left()