Julius Unverfehrt ea301b4df2 Pull request #40: replace trace log level by debug
Merge in RR/image-prediction from adjust-falsy-loglevel to master

Squashed commit of the following:

commit 66794acb1a64be6341f98c7c0ce0bc202634a9f4
Author: Julius Unverfehrt <julius.unverfehrt@iqser.com>
Date:   Fri Feb 10 10:15:41 2023 +0100

    replace trace log level by debug

    - the trace method is not supported by the built-in logging module
2023-02-10 10:18:38 +01:00

330 lines
10 KiB
Python

import atexit
from functools import partial, lru_cache
from itertools import chain, filterfalse
from operator import itemgetter, attrgetter
from typing import List, Union, Iterable, Any
import fitz
import numpy as np
from PIL import Image
from funcy import merge, compose, rcompose, keep, lkeep, notnone, iffy, rpartial
from pymonad.either import Right, Left, Either
from pymonad.tools import curry, identity
from image_prediction.exceptions import InvalidBox, BadXref
from image_prediction.formatter.formatters.enum import EnumFormatter
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import validate_box
from image_prediction.utils import get_logger
from image_prediction.utils.generic import bottom, left, right, lift, wrap_right
logger = get_logger()

# Short aliases for the PyMuPDF and Pillow classes used in this module's type annotations.
# They come from the public ``fitz`` namespace (the same names this module uses at runtime as
# ``fitz.Document``, ``fitz.Page`` and ``fitz.Pixmap``) instead of the ``fitz.fitz`` submodule,
# which is not part of PyMuPDF's documented API.
Doc = fitz.Document
Pge = fitz.Page
Img = Image.Image
Pxm = fitz.Pixmap
class ParsablePDFImageExtractor(ImageExtractor):
    """Extract embedded images plus their layout metadata from a PDF, stitching nearby images together."""

    def __init__(self, verbose=False, tolerance=0):
        """
        Args:
            verbose: Whether to show progressbar.
                NOTE(review): stored on the instance but not read anywhere in this class.
            tolerance: The tolerance in pixels for the distance between images, beyond which they will not be
                stitched together.
        """
        self.doc: Union[Doc, None] = None
        self.verbose = verbose
        self.tolerance = tolerance

    def extract(self, pdf: bytes, page_range: range = None):
        """Yield stitched image/metadata pairs for the requested pages of ``pdf``.

        Args:
            pdf: Raw bytes of the PDF file.
            page_range: Pages to process, forwarded to ``extract_pages``. ``None`` processes every page;
                an empty range processes none.

        Yields:
            The pairs produced by ``stitch_pairs``, page by page.
        """
        doc = fitz.Document(stream=pdf)
        # Still exposed on the instance, but processing uses the local ``doc`` so pages are always paired with
        # the document they came from, even if ``extract`` is called again before this generator is exhausted.
        self.doc = doc
        # Compare with None explicitly: an empty ``range`` is falsy and must not fall back to "all pages".
        pages = doc if page_range is None else extract_pages(doc, page_range)
        for page in pages:
            yield from self.__process_images_on_page(page, doc)

    def __process_images_on_page(self, page: Pge, doc: Doc):
        """Extract, validate and stitch the images on one ``page`` of ``doc``.

        Images whose extraction/validation ends in a ``Left`` are logged via ``take_good_log_bad`` and dropped.
        The module-level caches are cleared once the page has been processed.
        """
        metadata = extract_valid_metadata(doc, page)
        either_pairs = map(partial(metadatum_to_image_metadata_pair, doc), metadata)
        valid_pairs = lkeep(rpartial(take_good_log_bad, format_context), either_pairs)
        stitched_pairs = stitch_pairs(valid_pairs, tolerance=self.tolerance)
        clear_caches()
        yield from stitched_pairs
def extract_pages(doc: Doc, page_range: range) -> Iterable[Pge]:
    """Lazily load the pages of ``doc`` selected by ``page_range``.

    Only ``start`` and ``stop`` of the range are used (any ``step`` is ignored). Both are shifted up by one
    before the numbers are handed to ``doc.load_page``.

    NOTE(review): PyMuPDF documents ``Document.load_page`` as 0-based. If callers pass a 0-based range, this
    shift skips the first requested page and requests one page past the end. Confirm the intended convention.
    """
    first, past_last = page_range.start + 1, page_range.stop + 1
    for page_number in range(first, past_last):
        yield doc.load_page(page_number)
def validate_image(image: Img) -> Either:
    """Return ``Right(image)`` if ``image`` can be processed, else ``Left("Invalid image.")``.

    The check forces decoding by resizing to 100x100 and converting to RGB. Any exception raised while doing
    so marks the image as invalid; the cause is logged at debug level.
    """
    try:
        # TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148)
        image.resize((100, 100)).convert("RGB")
    except Exception as err:  # deliberately broad: any decode failure means "skip this image"
        logger.warning("Invalid image encountered.")
        logger.debug(f"Image validation failed: {err!r}")
        return Left("Invalid image.")
    return Right(image)
def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]:
    """Collect metadata for the images on ``page`` that pass validation, each tagged with ``Info.ALPHA``."""
    page_metadata = get_metadata_for_images_on_page(page)
    valid_metadata = filter_valid_metadata(page_metadata)
    with_alpha = add_alpha_channel_info(doc, valid_metadata)
    return list(with_alpha)
def take_good_log_bad(item: Either, log_formatter=identity) -> Any:
    """Unwrap a ``Right`` value; log a ``Left`` payload (rendered by ``log_formatter``) and return ``None``."""

    def on_failure(context):
        return log_error_context(context, log_formatter)

    return item.either(on_failure, identity)
def format_context(context: dict) -> str:
    """Render a ``{"reason": ..., "metadata": ...}`` error context as one log line."""
    reason = context["reason"].rstrip(".")
    rendered_metadata = EnumFormatter()(context["metadata"])
    return f"Reason: {reason}. Metadata: {rendered_metadata}"
def log_error_context(context: dict, formatter=identity) -> None:
    """Emit a warning that an image is being skipped, describing it with ``formatter(context)``."""
    message = f"Skipping bad image. {formatter(context)}"
    logger.warning(message)
def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either:
    """Extract and validate the image at ``metadatum[Info.XREF]`` and pair it with ``metadatum``.

    Returns ``Right(ImageMetadataPair)`` on success, or the ``Left`` produced by ``eith_make_image_metadata_pair``.
    """
    extracted = eith_extract_image(doc, metadatum[Info.XREF])
    validated = extracted.bind(validate_image)
    return eith_make_image_metadata_pair(validated, Right(metadatum))
def add_alpha_channel_info(doc: Doc, metadata: Iterable[dict]) -> Iterable[dict]:
    """Yield a copy of each metadatum with ``Info.ALPHA`` set to ``has_alpha_channel(doc, xref)``."""
    for metadatum in metadata:
        alpha = has_alpha_channel(doc, metadatum[Info.XREF])
        yield {**metadatum, Info.ALPHA: alpha}
def filter_valid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
    """Drop metadata whose box fails validation, then drop the ones describing tiny images."""
    with_valid_boxes = filter_out_invalid_metadata(metadata)
    yield from filter_out_tiny_images(with_valid_boxes)
def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]:
    """Yield one metadata dict per image reported by ``page``, merged with the page-level metadata."""
    image_infos = get_image_infos(page)
    per_image_metadata = lift(get_image_metadata)(image_infos)
    yield from add_page_metadata(page, per_image_metadata)
@lru_cache(maxsize=None)
@curry(2)
def eith_extract_image(doc: Doc, xref: int) -> Either:
    """Return ``Right(image)`` for the image at ``xref``, or ``Left("Bad xref.")`` if ``BadXref`` is raised.

    Curried (arity 2) and memoised per ``(doc, xref)``; the cache is emptied by ``clear_caches``.
    Other exceptions are not caught here, e.g. the ``ValueError`` ``normalize_channels`` raises.
    """
    try:
        image = extract_image(doc, xref)
    except BadXref:
        return Left("Bad xref.")
    return Right(image)
def eith_make_image_metadata_pair(image: Either, metadata: Either) -> Either:
    """Combine an ``Either`` image and an ``Either`` metadatum into an ``Either`` ``ImageMetadataPair``.

    When both arguments are ``Right``, ``make_image_metadata_pair`` is applied to their values.
    On failure, the ``Left`` branch goes through ``left(context)``. Assuming ``left(f)`` re-wraps ``f``'s
    result in ``Left`` (TODO confirm), the result is ``Left({"reason": <reason>, "metadata": <metadatum>})``.
    That is the shape ``format_context`` reads.

    Reference: haskell.org/tutorial/monads.html
    """

    def context(value):
        # Attach the unwrapped metadatum to the failure reason so the log line can identify the image.
        # NOTE(review): ``bottom`` is presumably the "cannot happen" branch for a Left ``metadata`` — confirm.
        return {"reason": value, "metadata": metadata.either(bottom, identity)}

    # Explicitly we are doing the following. (1) and (2) are equivalent.
    # a := Image
    # b := Metadata
    # c := ImageMetadataPair
    # m := Either monad
    # fmt: off
    # 1)
    # pair: Either = (
    #     Right(make_image_metadata_pair)  # m (a -> b -> c)
    #     .amap(image)                     # m (a -> b -> c) <*> m a = m (b -> c)
    #     .amap(metadata)                  # m (b -> c) <*> m b = m c
    # )
    # 2)
    # pair: Either = (
    #     image.bind(right(make_image_metadata_pair))  # m a >>= (a -> m (b -> c)) = m (b -> c)
    #     .amap(metadata)                              # m (b -> c) <*> m b = m c
    # )
    # fmt: on
    # Syntactic sugar variant with details hidden
    pair: Either = Either.apply(make_image_metadata_pair).to_arguments(image, metadata)
    return pair.either(left(context), right(identity))
@curry(2)
def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair:
    """Build an ``ImageMetadataPair``; curried (arity 2) so it can be applied inside ``Either``."""
    pair = ImageMetadataPair(image, metadatum)
    return pair
def extract_image(doc: Doc, xref: int) -> Any:
    """Decode the image at ``xref`` into a Pillow image via ``extract_pixmap`` and ``pixmap_to_image``.

    Raises:
        BadXref: propagated from ``extract_pixmap``.
    """
    pixmap = extract_pixmap(doc, xref)
    return pixmap_to_image(pixmap)
def pixmap_to_image(pixmap: Pxm) -> Img:
    """Convert ``pixmap`` into a 3-channel Pillow image.

    The raw ``samples`` are read as uint8 and reshaped to ``(h, w, n)``. ``normalize_channels`` then
    brings the channel count to three.
    """
    layout = (pixmap.h, pixmap.w, pixmap.n)
    pixels = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape(layout)
    return Image.fromarray(normalize_channels(pixels))
def extract_pixmap(doc: Doc, xref: int) -> Pxm:
    """Load the pixmap for ``xref`` from ``doc``.

    Raises:
        BadXref: if ``fitz.Pixmap`` raises ``ValueError``; the original error is logged and chained.
    """
    try:
        pixmap = fitz.Pixmap(doc, xref)
    except ValueError as err:
        message = f"Cross reference {xref} is invalid, skipping extraction."
        logger.error(err)
        logger.debug(message)
        raise BadXref(message) from err
    return pixmap
def has_alpha_channel(doc: Doc, xref: int) -> bool:
    """Return whether the image at cross reference ``xref`` in ``doc`` appears to carry transparency.

    Three checks are evaluated lazily in this order, and ``any`` stops at the first truthy result:

    1. ``mask_exists`` on the image's ``"smask"`` reference;
    2. ``image_has_alpha_channel`` on that soft mask reference;
    3. ``image_has_alpha_channel`` on ``xref`` itself.

    Each check produces an ``Either``. ``take_good_log_bad`` logs a ``Left`` as a warning and turns it into
    ``None``, so a failing check counts as ``False`` instead of raising.

    NOTE(review): if ``get_soft_mask_reference`` returns a ``Left``, checks 1 and 2 both bind that same
    ``Left`` and each log it. One missing soft mask therefore yields two "Skipping bad image." warnings,
    although the image itself is kept.
    NOTE(review): PyMuPDF documents ``Document.extract_image(...)["smask"]`` as ``0`` when no soft mask
    exists. Only ``None`` is rejected below, so ``0`` appears to be forwarded as a reference. Confirm intent.
    """
    # Presumably wrap_right(f, ...)(doc) yields a one-argument callable returning Right(f(doc, x)) on success
    # and a Left otherwise — TODO confirm against image_prediction.utils.generic.
    _get_image_handle = wrap_right(get_image_handle, success_condition=notnone)(doc)
    _extract_pixmap = wrap_right(extract_pixmap)(doc)

    def get_soft_mask_reference(cross_reference: int) -> Either:
        """Return the ``"smask"`` entry of the handle for ``cross_reference``, or a ``Left`` with a reason."""

        def error(value) -> str:
            """Build the reason used when the ``"smask"`` entry is ``None``."""
            return f"Invalid soft mask {value} for cross reference {cross_reference}."

        logger.debug(f"Getting soft mask handle for cross reference {cross_reference}.")
        # None -> left(error)(value); anything else -> right(identity)(value).
        pass_on_if_not_none = iffy(notnone, right(identity), left(error))
        # A failed handle lookup keeps its Left reason; a found "smask" entry goes through pass_on_if_not_none.
        return _get_image_handle(cross_reference).then(itemgetter("smask")).either(left(identity), pass_on_if_not_none)

    def mask_exists(soft_mask_reference: int) -> Either:
        """Wrap in ``Either`` whether an image handle is found for ``soft_mask_reference``."""
        logger.debug(f"Checking if soft mask exists for soft mask reference {soft_mask_reference}.")
        return _get_image_handle(soft_mask_reference).then(notnone)

    def image_has_alpha_channel(reference: int) -> Either:
        """Wrap in ``Either`` the truthiness of the ``alpha`` attribute of the pixmap at ``reference``."""
        logger.debug(f"Checking if image with reference {reference} has alpha channel.")
        return _extract_pixmap(reference).then(attrgetter("alpha")).then(bool)

    logger.debug(f"Checking if image with cross reference {xref} has alpha channel.")
    cross_reference = Right(xref)
    # Computed once, eagerly; both soft-mask checks below reuse this Either.
    soft_mask_reference = cross_reference.bind(get_soft_mask_reference)
    return any(
        take_good_log_bad(reference.bind(check))
        for reference, check in [
            (soft_mask_reference, mask_exists),
            (soft_mask_reference, image_has_alpha_channel),
            (cross_reference, image_has_alpha_channel),
        ]
    )
def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]:
    """Yield only the metadata for which ``tiny`` is false."""
    for metadatum in metadata:
        if not tiny(metadatum):
            yield metadatum
def filter_out_invalid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
    """Yield the truthy results of ``validate_box`` for each metadatum.

    Metadata rejected with ``InvalidBox`` are dropped and logged at debug level.
    """
    for metadatum in metadata:
        try:
            validated = validate_box(metadatum)
        except InvalidBox as err:
            logger.debug(f"Dropping invalid metadatum, reason: {err}")
            continue
        if validated:
            yield validated
def get_image_metadata(image_info: dict) -> dict:
    """Convert one ``Page.get_image_info`` entry into this module's metadata dict.

    The four ``"bbox"`` coordinates are rounded with ``rounder``. Width and height are the absolute extents
    of the rounded box.
    """
    xref = image_info["xref"]
    bbox = image_info["bbox"]
    x1, y1, x2, y2 = (rounder(coordinate) for coordinate in bbox)
    return {
        Info.WIDTH: abs(x2 - x1),
        Info.HEIGHT: abs(y2 - y1),
        Info.X1: x1,
        Info.X2: x2,
        Info.Y1: y1,
        Info.Y2: y2,
        Info.XREF: xref,
    }
@lru_cache(maxsize=None)
def get_image_infos(page: Pge) -> List[dict]:
    """Return ``page.get_image_info(xrefs=True)``, memoised per page until ``clear_caches`` runs."""
    image_infos = page.get_image_info(xrefs=True)
    return image_infos
def add_page_metadata(page: Pge, metadata: Iterable[dict]) -> Iterable[dict]:
    """Yield each metadatum merged on top of the page-level metadata of ``page``.

    ``funcy.merge`` gives later mappings precedence, so per-image keys win on a clash.
    """
    page_metadata = get_page_metadata(page)
    for metadatum in metadata:
        yield merge(page_metadata, metadatum)
def normalize_channels(array: np.ndarray):
    """Return ``array`` with exactly three channels on its last axis.

    * An array that is not 3-D gets a trailing channel axis appended (e.g. a 2-D grayscale array).
    * 1 channel (gray) and 2 channels (PyMuPDF's ``n == 2`` layout, gray + alpha) become three copies of the
      first channel.
    * 4 channels are cut to the first three (drops alpha from RGBA).
      NOTE(review): a 4-channel CMYK pixmap also lands here, and its C/M/Y planes are read as R/G/B.
    * 3 channels are returned unchanged (same object).

    Raises:
        ValueError: for any other channel count (logged as a warning first).
    """
    if array.ndim != 3:
        array = np.expand_dims(array, axis=-1)

    channels = array.shape[-1]
    if channels == 4:
        array = array[..., :3]
    elif channels in (1, 2):
        # Keep the gray plane (dropping alpha, if present) and replicate it. Raising instead would not be
        # caught by eith_extract_image, which only handles BadXref.
        array = np.repeat(array[..., :1], 3, axis=-1)
    elif channels != 3:
        logger.warning(f"Unexpected image format: {array.shape}.")
        raise ValueError(f"Unexpected image format: {array.shape}.")
    return array
@lru_cache(maxsize=None)
def get_image_handle(doc: Doc, xref: int) -> Union[dict, None]:
    """Return ``doc.extract_image(xref)``, memoised per ``(doc, xref)`` until ``clear_caches`` runs."""
    handle = doc.extract_image(xref)
    return handle
def get_page_metadata(page: Pge) -> dict:
    """Return the rounded media-box width/height of ``page`` and its ``page.number``."""
    raw_width, raw_height = page.mediabox_size
    return {
        Info.PAGE_WIDTH: rounder(raw_width),
        Info.PAGE_HEIGHT: rounder(raw_height),
        Info.PAGE_IDX: page.number,
    }
def rounder(value):
    """Round ``value`` with the built-in ``round`` (ties to even) and return the result as an ``int``."""
    return int(round(value))
def tiny(metadatum: dict) -> bool:
    """Return True when ``Info.WIDTH * Info.HEIGHT`` of the metadatum is at most 4."""
    area = metadatum[Info.WIDTH] * metadatum[Info.HEIGHT]
    return area <= 4
def clear_caches() -> None:
    """Empty the memoisation caches of the per-page / per-document helpers."""
    for cached_function in (get_image_infos, get_image_handle, eith_extract_image):
        cached_function.cache_clear()


# Also drop the caches at interpreter shutdown.
atexit.register(clear_caches)