Merge in RR/image-prediction from adjust-falsy-loglevel to master
Squashed commit of the following:
commit 66794acb1a64be6341f98c7c0ce0bc202634a9f4
Author: Julius Unverfehrt <julius.unverfehrt@iqser.com>
Date: Fri Feb 10 10:15:41 2023 +0100
replace trace log level with debug
- the trace method is not supported by the built-in logging module
330 lines
10 KiB
Python
330 lines
10 KiB
Python
import atexit
|
|
from functools import partial, lru_cache
|
|
from itertools import chain, filterfalse
|
|
from operator import itemgetter, attrgetter
|
|
from typing import List, Union, Iterable, Any
|
|
|
|
import fitz
|
|
import numpy as np
|
|
from PIL import Image
|
|
from funcy import merge, compose, rcompose, keep, lkeep, notnone, iffy, rpartial
|
|
from pymonad.either import Right, Left, Either
|
|
from pymonad.tools import curry, identity
|
|
|
|
from image_prediction.exceptions import InvalidBox, BadXref
|
|
from image_prediction.formatter.formatters.enum import EnumFormatter
|
|
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
|
from image_prediction.info import Info
|
|
from image_prediction.stitching.stitching import stitch_pairs
|
|
from image_prediction.stitching.utils import validate_box
|
|
from image_prediction.utils import get_logger
|
|
from image_prediction.utils.generic import bottom, left, right, lift, wrap_right
|
|
|
|
logger = get_logger()

# Short aliases for the PyMuPDF / Pillow types used in annotations throughout this module. They are taken from
# the top-level ``fitz`` namespace, matching the ``fitz.Document``, ``fitz.Page`` and ``fitz.Pixmap`` references
# used elsewhere in this file, rather than from the nested ``fitz.fitz`` submodule, which is an implementation
# detail that newer PyMuPDF releases may not expose.
Doc = fitz.Document
Pge = fitz.Page
Img = Image.Image
Pxm = fitz.Pixmap
|
|
|
|
|
|
class ParsablePDFImageExtractor(ImageExtractor):
    """Image extractor for PDFs that PyMuPDF (``fitz``) can open and parse.

    For every processed page, image metadata is collected and filtered, each image is extracted and validated,
    bad images are logged and skipped, and the remaining image-metadata pairs are stitched together when they
    lie within ``tolerance`` pixels of each other.
    """

    def __init__(self, verbose=False, tolerance=0):
        """Configure the extractor; the PDF itself is only opened in ``extract``.

        Args:
            verbose: Whether to show progressbar. NOTE(review): not read anywhere in this class.
            tolerance: The tolerance in pixels for the distance between images, beyond which they will not be
                stitched together.
        """
        self.doc: Union[Doc, None] = None
        self.verbose = verbose
        self.tolerance = tolerance

    def extract(self, pdf: bytes, page_range: Union[range, None] = None):
        """Yield the stitched image-metadata pairs found in ``pdf``.

        Args:
            pdf: The raw bytes of the PDF document.
            page_range: Pages to process, forwarded to ``extract_pages``. ``None`` (the default) processes every
                page of the document; an empty range processes no pages.

        Yields:
            The pairs produced by ``stitch_pairs``, page by page.
        """
        self.doc = fitz.Document(stream=pdf)

        # Test for None explicitly: an empty ``range`` is falsy, so a plain truthiness check would silently
        # fall back to processing the whole document.
        pages = self.doc if page_range is None else extract_pages(self.doc, page_range)

        yield from chain.from_iterable(map(self.__process_images_on_page, pages))

    def __process_images_on_page(self, page: Pge):
        """Yield the valid, stitched image-metadata pairs of a single page.

        Images whose extraction or validation fails are logged (formatted by ``format_context``) and dropped.
        """
        metadata = extract_valid_metadata(self.doc, page)

        either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata)
        # take_good_log_bad returns None for Left values, which lkeep then discards.
        valid_image_metadata_pairs = lkeep(
            rpartial(take_good_log_bad, format_context), either_image_metadata_pair_or_error_per_image
        )
        valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance)

        # The module-level caches are keyed by page / document; drop them once this page has been processed.
        clear_caches()

        yield from valid_image_metadata_pairs_stitched

    def __metadatum_to_image_metadata_pair(self, metadatum: dict) -> Either:
        """Bind the current document to ``metadatum_to_image_metadata_pair``."""
        return metadatum_to_image_metadata_pair(self.doc, metadatum)
|
|
|
|
|
|
def extract_pages(doc: Doc, page_range: range) -> Iterable[Pge]:
    """Lazily load the pages of ``doc`` selected by ``page_range``.

    Each index from ``page_range.start`` up to (excluding) ``page_range.stop`` is shifted up by one before it is
    passed to ``doc.load_page``. The range's step is ignored.

    NOTE(review): PyMuPDF documents ``Document.load_page`` as taking a 0-based page number, so the +1 shift
    means ``page_range`` is effectively treated as 1-based here — confirm this matches what callers pass.
    """
    load_page = doc.load_page
    for index in range(page_range.start, page_range.stop):
        yield load_page(index + 1)
|
|
|
|
|
|
def validate_image(image: Img) -> Either:
    """Check that ``image`` can actually be processed.

    Returns:
        ``Right(image)`` if the image can be resized and converted to RGB, otherwise ``Left("Invalid image.")``.
    """
    try:
        # TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148)
        image.resize((100, 100)).convert("RGB")
    except Exception:  # any failure (OSError included) marks the image as invalid
        logger.warning("Invalid image encountered.")
        return Left("Invalid image.")
    return Right(image)
|
|
|
|
|
|
def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]:
    """Return the metadata of all usable images on ``page``.

    Metadata is gathered for every image on the page. Entries with invalid boxes and tiny images are dropped,
    and the remaining entries are annotated with alpha-channel information.
    """
    raw_metadata = get_metadata_for_images_on_page(page)
    valid_metadata = filter_valid_metadata(raw_metadata)
    return list(add_alpha_channel_info(doc, valid_metadata))
|
|
|
|
|
|
def take_good_log_bad(item: Either, log_formatter=identity) -> Any:
    """Unwrap a ``Right`` value; log a ``Left`` value and return ``None`` for it.

    Args:
        item: The Either to unwrap.
        log_formatter: Applied to the ``Left`` value before it is logged.
    """

    def log_bad(error_value):
        return log_error_context(error_value, log_formatter)

    return item.either(log_bad, identity)
|
|
|
|
|
|
def format_context(context: dict) -> str:
    """Render an error context (``reason`` and ``metadata`` keys) as a single log message."""
    reason = context["reason"].rstrip(".")
    formatted_metadata = EnumFormatter()(context["metadata"])
    return f"Reason: {reason}. Metadata: {formatted_metadata}"
|
|
|
|
|
|
def log_error_context(context: dict, formatter=identity) -> None:
    """Log a warning that an image is being skipped, describing ``context`` via ``formatter``.

    ``context`` may also be a plain message string (e.g. the Left values produced in ``has_alpha_channel``),
    in which case the default ``identity`` formatter logs it as-is.
    """
    details = formatter(context)
    logger.warning(f"Skipping bad image. {details}")
|
|
|
|
|
|
def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either:
    """Extract and validate the image referenced by ``metadatum`` and pair it with that metadatum.

    Returns:
        The Either built by ``eith_make_image_metadata_pair``: the pair on success, or a Left carrying the error
        context on failure.
    """
    xref = metadatum[Info.XREF]
    validated_image = eith_extract_image(doc, xref).bind(validate_image)
    return eith_make_image_metadata_pair(validated_image, Right(metadatum))
|
|
|
|
|
|
def add_alpha_channel_info(doc: Doc, metadata: Iterable[dict]) -> Iterable[dict]:
    """Lazily yield a copy of each metadatum extended with an ``Info.ALPHA`` flag.

    The flag is computed by ``has_alpha_channel`` for the metadatum's ``Info.XREF``; the input dicts are not
    modified.
    """
    for metadatum in metadata:
        alpha = has_alpha_channel(doc, metadatum[Info.XREF])
        yield {**metadatum, Info.ALPHA: alpha}
|
|
|
|
|
|
def filter_valid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
    """Lazily drop metadata with invalid boxes, then metadata of tiny images."""
    with_valid_boxes = filter_out_invalid_metadata(metadata)
    yield from filter_out_tiny_images(with_valid_boxes)
|
|
|
|
|
|
def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]:
    """Lazily yield one metadata dict per image on ``page``, merged with the page-level metadata."""
    image_infos = get_image_infos(page)
    per_image_metadata = lift(get_image_metadata)(image_infos)
    yield from add_page_metadata(page, per_image_metadata)
|
|
|
|
|
|
@lru_cache(maxsize=None)
@curry(2)
def eith_extract_image(doc: Doc, xref: int) -> Either:
    """Extract the image with cross reference ``xref`` from ``doc``, wrapped in an Either.

    Results are memoised per ``(doc, xref)`` until ``clear_caches`` is called.

    Returns:
        ``Right(image)`` on success. ``Left("Bad xref.")`` if the cross reference cannot be loaded.
        ``Left("Unexpected image format.")`` if the pixel data cannot be turned into an image, e.g. a channel
        count that ``normalize_channels`` rejects.
    """
    try:
        return Right(extract_image(doc, xref))
    except BadXref:
        return Left("Bad xref.")
    except ValueError:
        # Raised while converting the pixmap (reshape / normalize_channels). Skip just this image instead of
        # letting the error propagate and abort the extraction of the whole document.
        return Left("Unexpected image format.")
|
|
|
|
|
|
def eith_make_image_metadata_pair(image: Either, metadata: Either) -> Either:
    """Combine an Either image and an Either metadatum into an Either ``ImageMetadataPair``.

    ``make_image_metadata_pair`` is applied applicatively to both arguments. A successful result is passed on
    via ``right(identity)``. A failure is turned, via ``left``, into an error context dict with the keys
    ``"reason"`` (the original Left value) and ``"metadata"`` (the unwrapped metadatum; ``bottom`` is applied
    if ``metadata`` is itself a Left).

    Reference: haskell.org/tutorial/monads.html
    """

    def to_error_context(reason):
        return {"reason": reason, "metadata": metadata.either(bottom, identity)}

    # The applicative call below is equivalent to either of the following explicit forms, with
    #   a := Image, b := Metadata, c := ImageMetadataPair, m := Either monad
    #
    #   1) Right(make_image_metadata_pair)   # m (a -> b -> c)
    #        .amap(image)                    # m (a -> b -> c) <*> m a = m (b -> c)
    #        .amap(metadata)                 # m (b -> c) <*> m b = m c
    #
    #   2) image.bind(right(make_image_metadata_pair))   # m a >>= (a -> m (b -> c)) = m (b -> c)
    #        .amap(metadata)                             # m (b -> c) <*> m b = m c
    combined: Either = Either.apply(make_image_metadata_pair).to_arguments(image, metadata)

    return combined.either(left(to_error_context), right(identity))
|
|
|
|
|
|
@curry(2)
def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair:
    """Bundle ``image`` and ``metadatum`` into an ``ImageMetadataPair``.

    Curried, so that it can be applied one argument at a time inside the Either applicative.
    """
    pair = ImageMetadataPair(image, metadatum)
    return pair
|
|
|
|
|
|
def extract_image(doc: Doc, xref: int) -> Any:
    """Load the pixmap for ``xref`` from ``doc`` and convert it to a PIL image.

    Raises:
        BadXref: propagated from ``extract_pixmap`` when the cross reference is invalid.
    """
    pixmap = extract_pixmap(doc, xref)
    return pixmap_to_image(pixmap)
|
|
|
|
|
|
def pixmap_to_image(pixmap: Pxm) -> Img:
    """Convert a PyMuPDF pixmap into a 3-channel PIL image.

    A pixmap whose colourspace has more than three components (e.g. CMYK) is first converted to RGB by PyMuPDF.
    Otherwise its four CMYK samples would be indistinguishable from RGBA in ``normalize_channels``, and the K
    channel would be dropped as if it were alpha, producing wrong colours.

    Raises:
        ValueError: from the reshape or ``normalize_channels`` for pixel layouts that cannot be handled.
    """
    if pixmap.colorspace is not None and pixmap.colorspace.n > 3:
        pixmap = fitz.Pixmap(fitz.csRGB, pixmap)

    array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape((pixmap.h, pixmap.w, pixmap.n))
    array = normalize_channels(array)
    return Image.fromarray(array)
|
|
|
|
|
|
def extract_pixmap(doc: Doc, xref: int) -> Pxm:
    """Load the pixmap with cross reference ``xref`` from ``doc``.

    Raises:
        BadXref: if PyMuPDF rejects the cross reference with a ``ValueError``. The original error is logged
            and chained as the cause.
    """
    try:
        pixmap = fitz.Pixmap(doc, xref)
    except ValueError as error:
        message = f"Cross reference {xref} is invalid, skipping extraction."
        logger.error(error)
        logger.debug(message)
        raise BadXref(message) from error
    return pixmap
|
|
|
|
|
|
def has_alpha_channel(doc: Doc, xref: int) -> bool:
    """Heuristically decide whether the image with cross reference ``xref`` has transparency.

    The soft mask reference is looked up once. The following checks are then tried in order, stopping at the
    first one that yields a truthy value:

    1. the soft mask reference yields a non-None image handle;
    2. the soft mask's pixmap has an alpha channel;
    3. the image's own pixmap has an alpha channel.

    A check that produces a ``Left`` is logged by ``take_good_log_bad`` and counts as ``False``. If the soft
    mask lookup itself fails, both checks 1 and 2 log that failure.

    NOTE(review): ``get_soft_mask_reference`` only rejects a ``None`` "smask" entry. If ``doc.extract_image``
    reports "no soft mask" as ``0`` (PyMuPDF appears to do so — confirm), checks 1 and 2 run against xref 0
    for every unmasked image.
    """
    _get_image_handle = wrap_right(get_image_handle, success_condition=notnone)(doc)
    _extract_pixmap = wrap_right(extract_pixmap)(doc)

    def get_soft_mask_reference(cross_reference: int) -> Either:
        def error(value) -> str:
            return f"Invalid soft mask {value} for cross reference {cross_reference}."

        logger.debug(f"Getting soft mask handle for cross reference {cross_reference}.")
        # A missing (None) smask entry becomes a Left with an explanatory message.
        reject_missing_mask = iffy(notnone, right(identity), left(error))
        smask_or_error = _get_image_handle(cross_reference).then(itemgetter("smask"))
        # A failed handle lookup is kept as a Left unchanged.
        return smask_or_error.either(left(identity), reject_missing_mask)

    def mask_exists(soft_mask_reference: int) -> Either:
        logger.debug(f"Checking if soft mask exists for soft mask reference {soft_mask_reference}.")
        handle_or_error = _get_image_handle(soft_mask_reference)
        return handle_or_error.then(notnone)

    def image_has_alpha_channel(reference: int) -> Either:
        logger.debug(f"Checking if image with reference {reference} has alpha channel.")
        pixmap_or_error = _extract_pixmap(reference)
        return pixmap_or_error.then(attrgetter("alpha")).then(bool)

    logger.debug(f"Checking if image with cross reference {xref} has alpha channel.")

    cross_reference = Right(xref)
    soft_mask_reference = cross_reference.bind(get_soft_mask_reference)

    checks = (
        (soft_mask_reference, mask_exists),
        (soft_mask_reference, image_has_alpha_channel),
        (cross_reference, image_has_alpha_channel),
    )
    for reference, check in checks:
        if take_good_log_bad(reference.bind(check)):
            return True
    return False
|
|
|
|
|
|
def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]:
    """Lazily yield only the metadata whose image is not ``tiny``."""
    for metadatum in metadata:
        if not tiny(metadatum):
            yield metadatum
|
|
|
|
|
|
def filter_out_invalid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
    """Lazily yield the truthy results of ``validate_box`` for each metadatum.

    Metadata for which ``validate_box`` raises ``InvalidBox`` are dropped, with a debug log message.
    """

    def __validate_box(box):
        try:
            return validate_box(box)
        except InvalidBox as err:
            logger.debug(f"Dropping invalid metadatum, reason: {err}")
            return None

    for metadatum in metadata:
        validated = __validate_box(metadatum)
        if validated:
            yield validated
|
|
|
|
|
|
def get_image_metadata(image_info: dict) -> dict:
    """Build the geometry metadata for one entry returned by ``page.get_image_info``.

    The bounding-box coordinates are rounded to integers. Width and height are the absolute differences of
    the rounded coordinates.
    """
    xref = image_info["xref"]
    bbox = image_info["bbox"]
    x1, y1, x2, y2 = (rounder(coordinate) for coordinate in bbox)

    return {
        Info.WIDTH: abs(x2 - x1),
        Info.HEIGHT: abs(y2 - y1),
        Info.X1: x1,
        Info.X2: x2,
        Info.Y1: y1,
        Info.Y2: y2,
        Info.XREF: xref,
    }
|
|
|
|
|
|
@lru_cache(maxsize=None)
def get_image_infos(page: Pge) -> List[dict]:
    """Return PyMuPDF's image info list (with xrefs) for ``page``, memoised until ``clear_caches`` runs.

    NOTE: the cached list object is shared between callers, so it must not be mutated.
    """
    infos = page.get_image_info(xrefs=True)
    return infos
|
|
|
|
|
|
def add_page_metadata(page: Pge, metadata: Iterable[dict]) -> Iterable[dict]:
    """Lazily merge ``page``'s size and index into each metadatum.

    On key clashes the metadatum's own value wins.
    """
    page_metadata = get_page_metadata(page)
    for metadatum in metadata:
        yield merge(page_metadata, metadatum)
|
|
|
|
|
|
def normalize_channels(array: np.ndarray):
    """Coerce an image array to three channels along its last axis.

    An array that is not 3-D first gets a trailing channel axis (intended for 2-D single-channel images).
    It is then handled by channel count:

    * 1 (grey): replicated into three channels;
    * 2 (grey + alpha, as PyMuPDF produces for grey images with transparency): alpha dropped, grey replicated;
    * 3: returned unchanged;
    * 4 (assumed RGBA): the fourth channel is dropped.

    Raises:
        ValueError: for any other channel count.
    """
    if array.ndim != 3:
        array = np.expand_dims(array, axis=-1)

    channels = array.shape[-1]
    if channels == 4:
        array = array[..., :3]
    elif channels in (1, 2):
        # Keep only the grey channel (this drops alpha for 2-channel input) and replicate it as R, G and B.
        array = np.repeat(array[..., :1], 3, axis=-1)
    elif channels != 3:
        logger.warning(f"Unexpected image format: {array.shape}.")
        raise ValueError(f"Unexpected image format: {array.shape}.")

    return array
|
|
|
|
|
|
@lru_cache(maxsize=None)
def get_image_handle(doc: Doc, xref: int) -> Union[dict, None]:
    """Return ``doc.extract_image(xref)``, memoised per ``(doc, xref)`` until ``clear_caches`` runs."""
    handle = doc.extract_image(xref)
    return handle
|
|
|
|
|
|
def get_page_metadata(page: Pge) -> dict:
    """Return the rounded media-box width and height of ``page`` together with its index."""
    width, height = (rounder(length) for length in page.mediabox_size)
    return {
        Info.PAGE_WIDTH: width,
        Info.PAGE_HEIGHT: height,
        Info.PAGE_IDX: page.number,
    }
|
|
|
|
|
|
def rounder(number, ndigits=None):
    """Return ``int(round(number, ndigits))``; ties round to even, as with the built-in ``round``."""
    return int(round(number, ndigits))
|
|
|
|
|
|
def tiny(metadatum: dict) -> bool:
    """Return True if the image's area (``Info.WIDTH`` * ``Info.HEIGHT``) is at most 4."""
    area = metadatum[Info.WIDTH] * metadatum[Info.HEIGHT]
    return area <= 4
|
|
|
|
|
|
@atexit.register
def clear_caches() -> None:
    """Drop all memoised image-info, image-handle and image-extraction results.

    Called after each page has been processed and, through the ``atexit`` registration, at interpreter exit.
    """
    for cached_function in (get_image_infos, get_image_handle, eith_extract_image):
        cached_function.cache_clear()
|