Matthias Bisping 89989543d8 [WIP] Monadic refactoring
Integrate image validation step into monadic chain.

For now, this loses the error information (failures collapse to Nothing). Refactoring to
an Either monad would bring it back.
2023-02-06 16:12:41 +01:00

246 lines
7.1 KiB
Python

import atexit
from _operator import itemgetter
from functools import partial, lru_cache
from itertools import chain, filterfalse
from operator import itemgetter
from typing import List, Union
import fitz
import numpy as np
from PIL import Image
from funcy import merge, compose, rcompose, keep
from pymonad.maybe import Maybe, Nothing, Just
from pymonad.tools import curry
from image_prediction.exceptions import InvalidBox, BadXref
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import validate_box
from image_prediction.utils import get_logger
# Module-level logger obtained from the project's ``get_logger`` helper; shared by all functions below.
logger = get_logger()
class ParsablePDFImageExtractor(ImageExtractor):
    """Extract embedded images, paired with their metadata, from a PDF via PyMuPDF (fitz).

    For every page, image metadata is collected and filtered, each image is decoded and
    validated, and the surviving image/metadata pairs are passed through ``stitch_pairs``.
    """

    def __init__(self, verbose=False, tolerance=0):
        """
        Args:
            verbose: Whether to show progressbar
            tolerance: The tolerance in pixels for the distance between images, beyond which they will not be stitched
                together
        """
        # NOTE(review): ``verbose`` is stored but not read anywhere in this module -- confirm it is used elsewhere.
        self.doc: Union[fitz.fitz.Document, None] = None
        self.verbose = verbose
        self.tolerance = tolerance

    def extract(self, pdf: bytes, page_range: Union[range, None] = None):
        """Yield the (stitched) image/metadata pairs for the selected pages of ``pdf``.

        Args:
            pdf: Raw PDF bytes; opened as ``self.doc``.
            page_range: Page selection forwarded to ``extract_pages``. ``None`` processes every
                page of the document; an empty range processes no pages.
        """
        self.doc = fitz.Document(stream=pdf)
        # Compare against None rather than relying on truthiness: ``range(0)`` is falsy and must
        # not silently fall back to "all pages".
        pages = self.doc if page_range is None else extract_pages(self.doc, page_range)
        yield from chain.from_iterable(map(self.__process_images_on_page, pages))

    def __process_images_on_page(self, page: fitz.fitz.Page):
        """Yield the stitched image/metadata pairs found on a single ``page``."""
        metadata = extract_valid_metadata(self.doc, page)
        maybe_pairs = map(partial(metadatum_to_image_metadata_pair, self.doc), metadata)
        # Images that failed extraction or validation come back as Nothing and are dropped here.
        image_metadata_pairs = [pair.value for pair in maybe_pairs if pair.is_just()]
        # The lru caches are keyed on doc/page objects; empty them once this page is done.
        clear_caches()
        yield from stitch_pairs(image_metadata_pairs, tolerance=self.tolerance)
def extract_pages(doc, page_range):
    """Lazily load the pages of ``doc`` selected by ``page_range``.

    Every index in ``page_range`` is shifted by +1 before being passed to ``doc.load_page``.
    The range is iterated directly, so its step (including negative steps) is honoured.

    Args:
        doc: Object exposing ``load_page(index)`` (a ``fitz.Document`` in this module).
        page_range: ``range`` of page indices.

    Yields:
        The loaded pages, in ``page_range`` order.
    """
    # NOTE(review): PyMuPDF's ``Document.load_page`` is documented as 0-based, so the +1 offset
    # makes ``range(0, n)`` load the 2nd..(n+1)th pages -- confirm callers pass indices accordingly.
    yield from map(doc.load_page, (index + 1 for index in page_range))
def validate_image(image: Image.Image) -> Maybe:
    """Return ``Just(image)`` if the image can be resized and converted to RGB, else ``Nothing``.

    Any exception raised by the check marks the image as invalid. The reason is logged, because
    the Maybe chain itself does not carry it.
    """
    try:
        # TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148)
        image.resize((100, 100)).convert("RGB")
    except Exception as err:  # deliberate catch-all: any failure of the heuristic means "invalid"
        logger.warning(f"Invalid image encountered: {err!r}")
        return Nothing
    return Just(image)
def extract_valid_metadata(doc: fitz.fitz.Document, page: fitz.fitz.Page):
    """Return, as a list, the metadata of the images on ``page`` that pass validation.

    Pipeline: collect per-image metadata (merged with page metadata), drop invalid boxes and
    tiny images, then tag each remaining metadatum with its ``Info.ALPHA`` flag.
    """
    raw_metadata = get_metadata_for_images_on_page(page)
    valid_metadata = filter_valid_metadata(raw_metadata)
    return list(add_alpha_channel_info(doc, valid_metadata))
def metadatum_to_image_metadata_pair(doc, metadatum: dict) -> Maybe:
    """Build a Maybe ``ImageMetadataPair`` for ``metadatum``.

    The image at ``metadatum[Info.XREF]`` is extracted and then validated. If either step yields
    ``Nothing``, no pair is produced.
    """
    xref = metadatum[Info.XREF]
    checked_image = xref_to_maybe_image(doc, xref).bind(validate_image)
    return make_maybe_image_metadata_pair(checked_image, Just(metadatum))
def add_alpha_channel_info(doc, metadata):
    """Yield a copy of each metadatum with ``Info.ALPHA`` set from ``has_alpha_channel``.

    The input dicts are not mutated; a new dict is built for every item.
    """
    for metadatum in metadata:
        alpha = has_alpha_channel(doc, metadatum[Info.XREF])
        yield {**metadatum, Info.ALPHA: alpha}
def filter_valid_metadata(metadata):
    """Yield the metadata whose box validates and which are not classified as ``tiny``."""
    # Box validation runs first; the tiny-image filter sees only its (truthy) results.
    validated = filter_out_invalid_metadata(metadata)
    yield from filter_out_tiny_images(validated)
def get_metadata_for_images_on_page(page: fitz.Page):
    """Yield one metadata dict per image reported for ``page``, merged with the page's metadata."""
    per_image_metadata = (get_image_metadata(image_info) for image_info in get_image_infos(page))
    yield from add_page_metadata(page, per_image_metadata)
@lru_cache(maxsize=None)
def xref_to_maybe_image(doc, xref) -> Maybe:
    """Return ``Just(image)`` for the image object at ``xref``, or ``Nothing`` on ``BadXref``.

    Results are memoised per ``(doc, xref)`` until ``clear_caches`` runs.
    """
    # NOTE(review): only ``BadXref`` maps to ``Nothing``; a ``ValueError`` raised by
    # ``normalize_channels`` inside ``extract_image`` still propagates -- confirm this is intended.
    try:
        image = extract_image(doc, xref)
    except BadXref:
        return Nothing
    return Just(image)
def make_maybe_image_metadata_pair(image: Maybe, metadata: Maybe) -> Maybe:
    """Combine a Maybe image and a Maybe metadatum into a Maybe ``ImageMetadataPair``.

    This is applicative lifting of the curried ``make_image_metadata_pair`` over two Maybe values.
    In Haskell notation it is ``make_image_metadata_pair <$> image <*> metadata``.
    Presumably (per pymonad's ``amap``) a ``Nothing`` in either argument yields ``Nothing``;
    the caller filters results with ``is_just()``.
    """
    # ``amap`` is ``<*>``: apply the wrapped (curried) function to each wrapped argument in turn.
    return Just(make_image_metadata_pair).amap(image).amap(metadata)
@curry(2)
def make_image_metadata_pair(image: Image.Image, metadatum: dict) -> ImageMetadataPair:
    """Pair ``image`` with ``metadatum``.

    Curried with ``@curry(2)`` so it can be applied one argument at a time inside
    ``Maybe.amap`` chains. Returns a plain ``ImageMetadataPair``, not a Maybe.
    """
    return ImageMetadataPair(image, metadatum)
def extract_image(doc, xref) -> Image.Image:
    """Decode the image object at ``xref`` into a 3-channel ``PIL.Image``.

    Raises:
        BadXref: if ``fitz.Pixmap`` rejects ``xref`` with a ``ValueError``.
        ValueError: from ``normalize_channels`` for unsupported channel counts.
    """
    try:
        pixmap = fitz.Pixmap(doc, xref)
    except ValueError as err:
        message = f"Xref {xref} is invalid, skipping extraction."
        logger.debug(message)
        raise BadXref(message) from err

    # NOTE(review): ``pixmap.n`` presumably counts colour channels plus alpha, so a 4-channel CMYK
    # pixmap would be treated as RGBA by ``normalize_channels`` -- consider converting to RGB first.
    shape = (pixmap.h, pixmap.w, pixmap.n)
    pixels = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape(shape)
    return Image.fromarray(normalize_channels(pixels))
def has_alpha_channel(doc, xref):
    """Return whether the image object at ``xref`` carries transparency.

    If the image handle references an ``smask``, the image counts as transparent when either
    ``doc.extract_image(smask)`` returns something other than ``None`` or the smask's pixmap
    reports alpha. Otherwise the image's own pixmap ``alpha`` flag decides. Invalid xrefs on
    that path are logged at debug level and reported as ``False``.
    """
    image_handle = load_image_handle_from_xref(doc, xref)
    smask_xref = image_handle["smask"] if image_handle else None

    if not smask_xref:
        try:
            return bool(fitz.Pixmap(doc, xref).alpha)
        except ValueError:
            logger.debug(f"Encountered invalid xref `{xref}` in {doc.metadata.get('title', '<no title>')}.")
            return False

    # ``or`` short-circuits: the smask pixmap is only built when the first check fails, whereas
    # ``any([...])`` always constructed it, and a ValueError from it went uncaught.
    return doc.extract_image(smask_xref) is not None or bool(fitz.Pixmap(doc, smask_xref).alpha)
def filter_out_tiny_images(metadata):
    """Yield only the metadata for which ``tiny`` returns a falsy value."""
    for metadatum in metadata:
        if not tiny(metadatum):
            yield metadatum
def filter_out_invalid_metadata(metadata):
    """Yield the result of ``validate_box`` for every metadatum that passes it.

    A metadatum is dropped when ``validate_box`` raises ``InvalidBox`` (logged at debug level)
    or returns a falsy value.
    """
    for metadatum in metadata:
        try:
            checked = validate_box(metadatum)
        except InvalidBox as err:
            logger.debug(f"Dropping invalid metadatum, reason: {err}")
            continue
        if checked:
            yield checked
def get_image_metadata(image_info):
    """Translate one ``page.get_image_info`` entry into this module's metadata dict.

    The ``bbox`` coordinates are rounded to ints with ``rounder``. Width and height are the
    absolute differences of the rounded coordinates.
    """
    xref = image_info["xref"]
    bbox = image_info["bbox"]
    x_start, y_start, x_end, y_end = [rounder(coordinate) for coordinate in bbox]

    metadatum = {}
    metadatum[Info.WIDTH] = abs(x_end - x_start)
    metadatum[Info.HEIGHT] = abs(y_end - y_start)
    metadatum[Info.X1] = x_start
    metadatum[Info.X2] = x_end
    metadatum[Info.Y1] = y_start
    metadatum[Info.Y2] = y_end
    metadatum[Info.XREF] = xref
    return metadatum
@lru_cache(maxsize=None)
def get_image_infos(page: fitz.Page) -> List[dict]:
    """Return ``page.get_image_info(xrefs=True)``, memoised per page until ``clear_caches`` runs.

    The cached list object is shared between calls, so callers should not mutate it.
    """
    image_infos = page.get_image_info(xrefs=True)
    return image_infos
def add_page_metadata(page, metadata):
    """Yield each metadatum merged on top of ``get_page_metadata(page)``.

    ``funcy.merge`` lets later mappings win, so keys in the metadatum override page-level keys.
    """
    page_metadata = get_page_metadata(page)
    for metadatum in metadata:
        yield merge(page_metadata, metadatum)
def normalize_channels(array: np.ndarray):
    """Return ``array`` as an H x W x 3 image array.

    * 2-D input gets a trailing channel axis first.
    * 1 channel: replicated to 3 (grey -> RGB).
    * 2 channels: treated as grey + alpha; alpha is dropped and grey replicated to 3.
    * 3 channels: returned unchanged (same object).
    * 4 channels: the last channel (assumed alpha) is dropped.

    Raises:
        ValueError: for any other channel count (the shape is also logged as a warning).
    """
    if array.ndim != 3:
        array = np.expand_dims(array, axis=-1)

    channels = array.shape[-1]
    if channels == 4:
        return array[..., :3]
    if channels == 2:
        # Presumably grey + alpha (what a grey pixmap with alpha produces): keep only the grey channel.
        array = array[..., :1]
        channels = 1
    if channels == 1:
        return np.repeat(array, 3, axis=-1)
    if channels != 3:
        logger.warning(f"Unexpected image format: {array.shape}.")
        raise ValueError(f"Unexpected image format: {array.shape}.")
    return array
@lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref):
    """Return ``doc.extract_image(xref)``, memoised per ``(doc, xref)`` until ``clear_caches`` runs."""
    image_handle = doc.extract_image(xref)
    return image_handle
def get_page_metadata(page):
    """Return the page's rounded mediabox width/height together with ``page.number``."""
    width, height = (rounder(side) for side in page.mediabox_size)
    page_metadata = {Info.PAGE_WIDTH: width, Info.PAGE_HEIGHT: height}
    page_metadata[Info.PAGE_IDX] = page.number
    return page_metadata
def rounder(value):
    """Round ``value`` with the built-in ``round`` and return the result as an ``int``."""
    return int(round(value))
def tiny(metadatum):
    """Return True when ``Info.WIDTH * Info.HEIGHT`` of the metadatum is at most 4."""
    area = metadatum[Info.WIDTH] * metadatum[Info.HEIGHT]
    return area <= 4
@atexit.register
def clear_caches():
    """Empty the caches of this module's ``lru_cache``-decorated helpers.

    Called after each page is processed. Also registered with ``atexit`` (``atexit.register``
    returns the function, so it works as a decorator) so the caches are emptied at exit.
    """
    for cached_function in (get_image_infos, load_image_handle_from_xref, xref_to_maybe_image):
        cached_function.cache_clear()