Replace Maybe with Either so that error information or metadata, which would otherwise be swallowed by Nothing, can be passed on.
253 lines
7.4 KiB
Python
253 lines
7.4 KiB
Python
import atexit
|
|
from _operator import itemgetter
|
|
from functools import partial, lru_cache
|
|
from itertools import chain, filterfalse
|
|
from operator import itemgetter
|
|
from typing import List, Union
|
|
|
|
import fitz
|
|
import numpy as np
|
|
from PIL import Image
|
|
from funcy import merge, compose, rcompose, keep, lfilter
|
|
from pymonad.either import Right, Left, Either
|
|
from pymonad.tools import curry, identity
|
|
|
|
from image_prediction.exceptions import InvalidBox, BadXref
|
|
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
|
from image_prediction.info import Info
|
|
from image_prediction.stitching.stitching import stitch_pairs
|
|
from image_prediction.stitching.utils import validate_box
|
|
from image_prediction.utils import get_logger
|
|
|
|
# Module-level logger used by the functions in this file for debug/warning messages.
logger = get_logger()
|
|
|
|
|
|
class ParsablePDFImageExtractor(ImageExtractor):
    """Extract images together with their metadata from a parsable PDF (via PyMuPDF).

    On every page, image metadata is collected and filtered. Each image is then
    decoded and validated. Failures are logged and skipped. Finally, neighbouring
    images are stitched together according to ``tolerance``.
    """

    def __init__(self, verbose=False, tolerance=0):
        """
        Args:
            verbose: Whether to show progressbar
            tolerance: The tolerance in pixels for the distance between images, beyond which they will not be
                stitched together
        """
        self.doc: Union[fitz.fitz.Document, None] = None
        # NOTE(review): ``verbose`` is stored but not read anywhere in this file.
        self.verbose = verbose
        self.tolerance = tolerance

    def extract(self, pdf: bytes, page_range: range = None):
        """Yield image/metadata pairs for the pages of ``pdf``.

        Args:
            pdf: Raw bytes of the PDF document; opened with ``fitz.Document(stream=...)``.
            page_range: Optional range of pages to process (translated by
                ``extract_pages``). ``None`` processes every page. An empty range
                processes no pages.

        Yields:
            Whatever ``stitch_pairs`` produces for each page (presumably
            ``ImageMetadataPair`` objects — confirm in ``stitching``).
        """
        self.doc = fitz.Document(stream=pdf)

        # Compare against None explicitly: an empty ``range`` is falsy and would
        # otherwise silently select the whole document.
        pages = self.doc if page_range is None else extract_pages(self.doc, page_range)

        yield from chain.from_iterable(map(self.__process_images_on_page, pages))

    def __process_images_on_page(self, page: fitz.fitz.Page):
        """Yield the stitched image/metadata pairs of a single page.

        Images whose extraction or validation produced a ``Left`` are logged and
        dropped by ``right``.
        """
        metadata = extract_valid_metadata(self.doc, page)

        maybe_pairs = map(partial(metadatum_to_image_metadata_pair, self.doc), metadata)
        # ``map``/``keep`` are lazy: materialise first so that every cached lookup
        # for this page has happened before the caches are dropped. Clearing
        # earlier would leave this page's decoded images cached until the next page.
        image_metadata_pairs = list(keep(right, maybe_pairs))
        clear_caches()

        yield from stitch_pairs(image_metadata_pairs, tolerance=self.tolerance)
|
|
|
|
|
|
def right(pair: Either):
    """Unwrap a ``Right``; log and discard a ``Left``.

    Intended as the function passed to ``keep``. A ``Right`` yields its wrapped
    value. A ``Left`` has its payload logged as a warning and yields ``None``,
    which ``keep`` then filters out.
    """
    payload = pair.either(identity, identity)
    if not pair.is_right():
        logger.warning(f"Skipping bad image. reason: {payload}")
        return None
    return payload
|
|
|
|
|
|
def extract_pages(doc, page_range):
    """Lazily load the pages of ``doc`` selected by ``page_range``.

    Every index of ``page_range`` is shifted up by one before being passed to
    ``doc.load_page``. The range's step (including negative steps) is preserved.

    Args:
        doc: Document exposing ``load_page``.
        page_range: ``range`` of page indices.

    Yields:
        The objects returned by ``doc.load_page`` for each shifted index.
    """
    # NOTE(review): PyMuPDF's ``load_page`` is zero-based, so this +1 shift only
    # makes sense if callers pass indices one below the desired pages. Confirm
    # against the callers of ``ParsablePDFImageExtractor.extract``.
    shifted = range(page_range.start + 1, page_range.stop + 1, page_range.step)
    yield from map(doc.load_page, shifted)
|
|
|
|
|
|
def validate_image(image: Image.Image) -> Either:
    """Check that ``image`` can actually be processed.

    The check resizes the image and converts it to RGB. Any exception raised in
    doing so marks the image as invalid.

    Returns:
        ``Right(image)`` on success, otherwise ``Left("Invalid image.")``.
    """
    try:
        # TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148)
        image.resize((100, 100)).convert("RGB")
    except Exception:  # OSError is a subclass of Exception; listing both was redundant
        logger.warning("Invalid image encountered.")
        return Left("Invalid image.")
    return Right(image)
|
|
|
|
|
|
def extract_valid_metadata(doc: fitz.fitz.Document, page: fitz.fitz.Page):
    """Collect metadata for the usable images on ``page``.

    Steps:
    1. Gather per-image metadata, including the page-level fields.
    2. Drop entries rejected by ``filter_valid_metadata``.
    3. Annotate each survivor with its alpha flag.
    4. Materialise the result.

    Returns:
        A list of metadata dicts.
    """
    page_metadata = get_metadata_for_images_on_page(page)
    valid_metadata = filter_valid_metadata(page_metadata)
    with_alpha = add_alpha_channel_info(doc, valid_metadata)
    return list(with_alpha)
|
|
|
|
|
|
def metadatum_to_image_metadata_pair(doc, metadatum: dict) -> Either:
    """Pair the image referenced by ``metadatum`` with the metadatum itself.

    The image is loaded with ``xref_to_maybe_image`` and then passed through
    ``validate_image`` via ``bind``. A ``Left`` from either step is expected to
    propagate into the result.

    Returns:
        An ``Either`` that wraps an ``ImageMetadataPair`` on success.
    """
    xref = metadatum[Info.XREF]
    checked_image = xref_to_maybe_image(doc, xref).bind(validate_image)
    wrapped_metadatum = Right(metadatum)
    return make_maybe_image_metadata_pair(checked_image, wrapped_metadatum)
|
|
|
|
|
|
def add_alpha_channel_info(doc, metadata):
    """Yield a copy of each metadatum extended with an ``Info.ALPHA`` flag.

    The flag is ``has_alpha_channel(doc, xref)`` for the metadatum's xref. The
    input dicts are not mutated.
    """
    for metadatum in metadata:
        has_alpha = has_alpha_channel(doc, metadatum[Info.XREF])
        yield {**metadatum, Info.ALPHA: has_alpha}
|
|
|
|
|
|
def filter_valid_metadata(metadata):
    """Lazily drop metadata with invalid boxes, then drop tiny images."""
    with_valid_boxes = filter_out_invalid_metadata(metadata)
    yield from filter_out_tiny_images(with_valid_boxes)
|
|
|
|
|
|
def get_metadata_for_images_on_page(page: fitz.Page):
    """Yield one metadata dict per image on ``page``, merged with the page-level metadata."""
    per_image = (get_image_metadata(info) for info in get_image_infos(page))
    yield from add_page_metadata(page, per_image)
|
|
|
|
|
|
@lru_cache(maxsize=None)
def xref_to_maybe_image(doc, xref) -> Either:
    """Decode the image behind ``xref`` into an ``Either``.

    Results are memoised per ``(doc, xref)`` until ``clear_caches`` runs.

    Returns:
        One of:
        - ``Right(image)`` on success;
        - ``Left("Bad xref.")`` when ``extract_image`` raises ``BadXref``;
        - ``Left(<reason>)`` when decoding raises ``ValueError``, e.g. the
          "Unexpected image format" raised by ``normalize_channels``.
    """
    try:
        return Right(extract_image(doc, xref))
    except BadXref:
        return Left("Bad xref.")
    except ValueError as err:
        # Previously this escaped the Either pipeline and aborted the whole
        # extraction. Now only this image is skipped, and the reason is kept
        # for the "Skipping bad image" log.
        return Left(f"Undecodable image at xref {xref}: {err}")
|
|
|
|
|
|
def make_maybe_image_metadata_pair(image: Either, metadata: Either):
    """Combine an ``Either`` image and an ``Either`` metadatum into an ``Either`` pair.

    This is applicative application, not monadic sequencing. The curried
    ``make_image_metadata_pair`` is lifted into ``Right``, then ``amap`` applies
    it to each wrapped argument in turn. In Haskell terms:
    ``pure f <*> image <*> metadata`` with ``(<*>) :: f (a -> b) -> f a -> f b``.
    """
    # TODO: Somehow metadata needs to be added to Lefts for logging the reference to the invalid image
    lifted_constructor = Right(make_image_metadata_pair)
    partially_applied = lifted_constructor.amap(image)
    return partially_applied.amap(metadata)
|
|
|
|
|
|
@curry(2)
def make_image_metadata_pair(image: Image.Image, metadatum: dict) -> ImageMetadataPair:
    """Build an ``ImageMetadataPair``.

    Curried, so it can be lifted and applied argument-by-argument via ``amap``.
    """
    pair = ImageMetadataPair(image, metadatum)
    return pair
|
|
|
|
|
|
def extract_image(doc, xref) -> Image.Image:
    """Decode the image stored at ``xref`` into a three-channel PIL image.

    Raises:
        BadXref: if ``fitz.Pixmap`` rejects ``xref`` with ``ValueError``.
        ValueError: propagated from ``normalize_channels`` for channel counts it
            does not handle.
    """
    try:
        pixmap = fitz.Pixmap(doc, xref)
    except ValueError as err:
        message = f"Xref {xref} is invalid, skipping extraction."
        logger.debug(message)
        raise BadXref(message) from err

    # Interpret the sample buffer as h x w x n uint8 values.
    # NOTE(review): a 4-component pixmap may be CMYK rather than RGBA.
    # normalize_channels simply drops the fourth channel; confirm this is acceptable.
    pixel_shape = (pixmap.h, pixmap.w, pixmap.n)
    pixels = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape(pixel_shape)
    return Image.fromarray(normalize_channels(pixels))
|
|
|
|
|
|
def _pixmap_has_alpha(doc, xref):
    """Return the alpha flag of ``fitz.Pixmap(doc, xref)``, or ``False`` (debug-logged) if ``xref`` is rejected."""
    try:
        return bool(fitz.Pixmap(doc, xref).alpha)
    except ValueError:
        logger.debug(f"Encountered invalid xref `{xref}` in {doc.metadata.get('title', '<no title>')}.")
        return False


def has_alpha_channel(doc, xref):
    """Return whether the image at ``xref`` should be treated as having an alpha channel.

    If the image references a soft mask (``smask``), the image counts as having
    alpha when either:
    - ``doc.extract_image(smask)`` returns something other than ``None``; or
    - the mask's pixmap carries alpha.

    Otherwise the image's own pixmap alpha flag decides. Xrefs that
    ``fitz.Pixmap`` rejects with ``ValueError`` are logged at debug level and
    treated as "no alpha".
    """
    image_handle = load_image_handle_from_xref(doc, xref)
    smask = image_handle["smask"] if image_handle else None

    if smask:
        # Short-circuit: only build the mask pixmap when extraction did not
        # already decide. Previously both checks ran eagerly, and a broken mask
        # xref raised an uncaught ValueError.
        return doc.extract_image(smask) is not None or _pixmap_has_alpha(doc, smask)
    return _pixmap_has_alpha(doc, xref)
|
|
|
|
|
|
def filter_out_tiny_images(metadata):
    """Lazily yield only the metadata entries that ``tiny`` does not flag."""
    for metadatum in metadata:
        if not tiny(metadatum):
            yield metadatum
|
|
|
|
|
|
def filter_out_invalid_metadata(metadata):
    """Lazily yield the truthy results of ``validate_box`` for each metadatum.

    Entries for which ``validate_box`` raises ``InvalidBox`` are logged at debug
    level and skipped. Falsy results are skipped as well.
    """
    for metadatum in metadata:
        try:
            checked = validate_box(metadatum)
        except InvalidBox as err:
            logger.debug(f"Dropping invalid metadatum, reason: {err}")
            continue
        if checked:
            yield checked
|
|
|
|
|
|
def get_image_metadata(image_info):
    """Convert one ``page.get_image_info(xrefs=True)`` entry into a metadata dict.

    The bounding-box coordinates are converted to ints with ``rounder``. Width
    and height are the absolute differences of the rounded coordinates.
    """
    xref = image_info["xref"]
    x1, y1, x2, y2 = (rounder(coordinate) for coordinate in image_info["bbox"])

    return {
        Info.WIDTH: abs(x2 - x1),
        Info.HEIGHT: abs(y2 - y1),
        Info.X1: x1,
        Info.X2: x2,
        Info.Y1: y1,
        Info.Y2: y2,
        Info.XREF: xref,
    }
|
|
|
|
|
|
@lru_cache(maxsize=None)
def get_image_infos(page: fitz.Page) -> List[dict]:
    """Return ``page.get_image_info(xrefs=True)``, memoised per page until ``clear_caches`` runs."""
    image_infos = page.get_image_info(xrefs=True)
    return image_infos
|
|
|
|
|
|
def add_page_metadata(page, metadata):
    """Lazily yield each metadatum merged with ``page``'s page-level metadata.

    The page-level metadata is computed once, when iteration starts.
    """
    page_metadata = get_page_metadata(page)
    for metadatum in metadata:
        yield merge(page_metadata, metadatum)
|
|
|
|
|
|
def normalize_channels(array: np.ndarray):
    """Coerce a pixel array to three channels along its last axis.

    Rules:
    - An array that is not 3-D first gets a trailing channel axis.
    - 4 channels: the fourth channel is dropped.
    - 2 channels: treated as (gray, alpha). The alpha channel is dropped and
      gray is replicated into three channels.
    - 1 channel: replicated into three channels.
    - 3 channels: the input array is returned unchanged (same object).

    Raises:
        ValueError: for any other channel count (also logged as a warning).
    """
    if array.ndim != 3:
        array = np.expand_dims(array, axis=-1)

    channels = array.shape[-1]
    if channels == 4:
        array = array[..., :3]
    elif channels == 2:
        # NOTE(review): assumes a gray + alpha pixmap (the only 2-component
        # layout PyMuPDF produces). This layout previously raised ValueError
        # and aborted extraction.
        array = np.repeat(array[..., :1], 3, axis=-1)
    elif channels == 1:
        array = np.repeat(array, 3, axis=-1)
    elif channels != 3:
        logger.warning(f"Unexpected image format: {array.shape}.")
        raise ValueError(f"Unexpected image format: {array.shape}.")

    return array
|
|
|
|
|
|
@lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref):
    """Return ``doc.extract_image(xref)``, memoised per ``(doc, xref)`` until ``clear_caches`` runs."""
    extracted = doc.extract_image(xref)
    return extracted
|
|
|
|
|
|
def get_page_metadata(page):
    """Return the page's rounded mediabox width/height and its index as a metadata dict."""
    width, height = (rounder(extent) for extent in page.mediabox_size)
    return {
        Info.PAGE_WIDTH: width,
        Info.PAGE_HEIGHT: height,
        Info.PAGE_IDX: page.number,
    }
|
|
|
|
|
|
def rounder(value, ndigits=None):
    """Round ``value`` with the builtin ``round`` and return the result as ``int``."""
    return int(round(value, ndigits))
|
|
|
|
|
|
def tiny(metadatum):
    """Return ``True`` when the image's width times height is at most 4."""
    area = metadatum[Info.WIDTH] * metadatum[Info.HEIGHT]
    return area <= 4
|
|
|
|
|
|
def clear_caches():
    """Empty the ``lru_cache`` of the three memoised helpers.

    The helpers are ``get_image_infos``, ``load_image_handle_from_xref`` and
    ``xref_to_maybe_image``, cleared in that order.
    """
    for memoised in (get_image_infos, load_image_handle_from_xref, xref_to_maybe_image):
        memoised.cache_clear()
|
|
|
|
|
|
# Clear the lru_caches, which hold references to documents, pages and decoded
# images, when the interpreter exits.
atexit.register(clear_caches)
|