Pull request #41: RED-6189 bugfix
Merge in RR/image-prediction from RED-6189-bugfix to master * commit '79455f0dd6da835ef2261393c5a57ba8ef2550ab': (25 commits) revert refactoring changes replace image extraction logic final introduce normalizing function for image extraction refactoring adjust behavior of filtering of invalid images add log in callback to diplay which file is processed add ad hoc logic for bad xref handling beautify beautify implement ad hoc channel count detection for new image extraction improve performance refactor scanned page filtering refactor scanned page filtering WIP refactor scanned page filtering WIP refactor scanned page filtering WIP refactor scanned page filtering WIP refactor scanned page filtering WIP refactor scanned page filtering WIP refactor scanned page filtering WIP refactor ...
This commit is contained in:
commit
463f4da92b
@ -1,32 +1,29 @@
|
||||
import atexit
|
||||
import json
|
||||
import traceback
|
||||
from _operator import itemgetter
|
||||
from functools import partial, lru_cache
|
||||
from itertools import chain, filterfalse
|
||||
from operator import itemgetter, attrgetter
|
||||
from typing import List, Union, Iterable, Any
|
||||
from itertools import chain, starmap, filterfalse
|
||||
from operator import itemgetter, truth
|
||||
from typing import Iterable, Iterator, List, Union
|
||||
|
||||
import fitz
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from funcy import merge, compose, rcompose, keep, lkeep, notnone, iffy, rpartial
|
||||
from pymonad.either import Right, Left, Either
|
||||
from pymonad.tools import curry, identity
|
||||
from funcy import merge, pluck, compose, rcompose, remove, keep
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.exceptions import InvalidBox, BadXref
|
||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.stitching.stitching import stitch_pairs
|
||||
from image_prediction.stitching.utils import validate_box
|
||||
from image_prediction.transformer.transformers.response import compute_geometric_quotient
|
||||
from image_prediction.utils import get_logger
|
||||
from image_prediction.utils.generic import bottom, left, right, lift, wrap_right
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
Doc = fitz.fitz.Document
|
||||
Pge = fitz.fitz.Page
|
||||
Img = Image.Image
|
||||
Pxm = fitz.fitz.Pixmap
|
||||
|
||||
|
||||
class ParsablePDFImageExtractor(ImageExtractor):
|
||||
def __init__(self, verbose=False, tolerance=0):
|
||||
@ -37,7 +34,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
tolerance: The tolerance in pixels for the distance between images, beyond which they will not be stitched
|
||||
together
|
||||
"""
|
||||
self.doc: Union[Doc, None] = None
|
||||
self.doc: fitz.fitz.Document = None
|
||||
self.verbose = verbose
|
||||
self.tolerance = tolerance
|
||||
|
||||
@ -50,41 +47,51 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
|
||||
yield from image_metadata_pairs
|
||||
|
||||
def __process_images_on_page(self, page: Pge):
|
||||
def __process_images_on_page(self, page: fitz.fitz.Page):
|
||||
metadata = extract_valid_metadata(self.doc, page)
|
||||
|
||||
either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata)
|
||||
valid_image_metadata_pairs = lkeep(
|
||||
rpartial(take_good_log_bad, format_context), either_image_metadata_pair_or_error_per_image
|
||||
)
|
||||
valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance)
|
||||
images = get_images_on_page(self.doc, metadata)
|
||||
|
||||
clear_caches()
|
||||
|
||||
yield from valid_image_metadata_pairs_stitched
|
||||
image_metadata_pairs = starmap(ImageMetadataPair, filter(all, zip(images, metadata)))
|
||||
# TODO: In the future, consider to introduce an image validator as a pipeline component rather than doing the
|
||||
# validation here. Invalid images can then be split into a different stream and joined with the intact images
|
||||
# again for the formatting step.
|
||||
image_metadata_pairs = self.__filter_valid_images(image_metadata_pairs)
|
||||
image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance)
|
||||
|
||||
def __metadatum_to_image_metadata_pair(self, metadatum: dict) -> Either:
|
||||
return metadatum_to_image_metadata_pair(self.doc, metadatum)
|
||||
yield from image_metadata_pairs
|
||||
|
||||
@staticmethod
def __filter_valid_images(image_metadata_pairs: Iterable[ImageMetadataPair]) -> Iterator[ImageMetadataPair]:
    """Drop pairs whose image fails a cheap resize/convert probe.

    Invalid images are logged (with their enum-formatted metadata and the
    traceback) and removed from the stream; valid pairs pass through unchanged.
    """

    def validate(image: Image.Image, metadata: dict):
        try:
            # TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148)
            image.resize((100, 100)).convert("RGB")
            return ImageMetadataPair(image, metadata)
        except Exception:
            # `Exception` already subsumes OSError, so the former
            # `(OSError, Exception)` tuple was redundant. Bind the formatted
            # representation to a new name instead of shadowing `metadata`.
            formatted = json.dumps(EnumFormatter()(metadata), indent=2)
            logger.warning(f"Invalid image encountered. Image metadata:\n{formatted}\n\n{traceback.format_exc()}")
            return None

    # `validate` returns None for bad pairs; `truth` filters those out.
    return filter(truth, starmap(validate, image_metadata_pairs))
|
||||
|
||||
|
||||
def extract_pages(doc: Doc, page_range: range) -> Iterable[Pge]:
|
||||
def extract_pages(doc, page_range):
    """Lazily yield the pages of *doc* selected by *page_range*.

    NOTE(review): the +1 shift suggests doc.load_page is being handed
    one-based page numbers here — confirm against the caller's convention.
    """
    for number in range(page_range.start + 1, page_range.stop + 1):
        yield doc.load_page(number)
|
||||
|
||||
|
||||
def validate_image(image: Img) -> Either:
    """Return Right(image) if the image survives a cheap resize/convert probe, else Left.

    A failed probe is logged at warning level.
    """
    try:
        # TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148)
        image.resize((100, 100)).convert("RGB")
        return Right(image)
    except Exception:
        # `Exception` already subsumes OSError; the old `(OSError, Exception)`
        # tuple was redundant. Plain string: the old f-string had no placeholders.
        logger.warning("Invalid image encountered.")
        return Left("Invalid image.")
|
||||
def get_images_on_page(doc, metadata):
    """Lazily yield one image per metadata entry, resolved via its xref.

    Entries whose xref cannot be extracted yield None (see xref_to_image).
    """
    for metadatum in metadata:
        yield xref_to_image(doc, metadatum[Info.XREF])
|
||||
|
||||
|
||||
def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]:
|
||||
def extract_valid_metadata(doc: fitz.fitz.Document, page: fitz.fitz.Page):
|
||||
return compose(
|
||||
list,
|
||||
partial(add_alpha_channel_info, doc),
|
||||
@ -93,159 +100,25 @@ def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]:
|
||||
)(page)
|
||||
|
||||
|
||||
def take_good_log_bad(item: Either, log_formatter=identity) -> Any:
    """Unwrap an Either: pass Right values through; log Left contexts and yield None."""

    def on_error(context):
        return log_error_context(context, log_formatter)

    return item.either(on_error, identity)
|
||||
|
||||
|
||||
def format_context(context: dict) -> str:
    """Render an error-context dict (keys `reason`, `metadata`) as one log-friendly line."""
    # rstrip('.') avoids a double period when the reason already ends with one.
    return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}"
|
||||
|
||||
|
||||
def log_error_context(context: dict, formatter=identity) -> None:
    """Log a skipped-image context at warning level; always returns None."""
    message = f"Skipping bad image. {formatter(context)}"
    logger.warning(message)
    return None
|
||||
|
||||
|
||||
def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either:
    """Extract and validate the image for *metadatum*, pairing it with the metadata.

    Returns Right(ImageMetadataPair) on success, Left(context) if extraction
    or validation failed anywhere in the chain.
    """
    image: Either = eith_extract_image(doc, metadatum[Info.XREF]).bind(validate_image)
    image_metadata_pair: Either = eith_make_image_metadata_pair(image, Right(metadatum))
    return image_metadata_pair
|
||||
|
||||
|
||||
def add_alpha_channel_info(doc: Doc, metadata: Iterable[dict]) -> Iterable[dict]:
    """Yield each metadatum extended with Info.ALPHA (whether its image has an alpha channel)."""

    def add_alpha_value_to_metadatum(metadatum: dict) -> dict:
        alpha = metadatum_to_alpha_value(metadatum)
        return {**metadatum, Info.ALPHA: alpha}

    # compose is right-to-left: metadatum -> xref -> has_alpha_channel(doc, xref)
    xref_to_alpha = partial(has_alpha_channel, doc)
    metadatum_to_alpha_value = compose(xref_to_alpha, itemgetter(Info.XREF))

    yield from map(add_alpha_value_to_metadatum, metadata)
|
||||
|
||||
|
||||
def filter_valid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
    """Yield metadata that passes box validation and is not tiny."""
    # Same pipeline as compose(filter_out_tiny_images, filter_out_invalid_metadata),
    # spelled out explicitly: invalid boxes are dropped first, then tiny images.
    yield from filter_out_tiny_images(filter_out_invalid_metadata(metadata))
|
||||
|
||||
|
||||
def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]:
|
||||
metadata = compose(
|
||||
partial(add_page_metadata, page),
|
||||
lift(get_image_metadata),
|
||||
get_image_infos,
|
||||
)(page)
|
||||
def get_metadata_for_images_on_page(page: fitz.Page):
    """Yield one metadata dict per image on *page*, merged with page-level metadata."""
    # Per-image info comes from the (cached) page.get_image_info call.
    metadata = map(get_image_metadata, get_image_infos(page))
    metadata = add_page_metadata(page, metadata)

    yield from metadata
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
@curry(2)
def eith_extract_image(doc: Doc, xref: int) -> Either:
    """Extract the image behind *xref* as Right(image), or Left on a bad xref.

    Curried so it can be partially applied to *doc*; results are cached per
    (doc, xref) and the cache is flushed by clear_caches().
    """
    try:
        return Right(extract_image(doc, xref))
    except BadXref:
        return Left("Bad xref.")
|
||||
def filter_valid_metadata(metadata):
    """Yield metadata entries with a valid box that are not tiny."""
    # compose is right-to-left: invalid boxes are dropped first, then tiny images.
    yield from compose(
        # TODO: Disabled for now, since the backend currently needs the metadata and the hash of every image, even
        # scanned pages. In the future, this should be resolved differently, e.g. by filtering all page-sized images
        # and giving the user the ability to reclassify false positives with a separate call.
        # filter_out_page_sized_images,
        filter_out_tiny_images,
        filter_out_invalid_metadata,
    )(metadata)
|
||||
|
||||
|
||||
def eith_make_image_metadata_pair(image: Either, metadata: Either) -> Either:
    """Applicatively combine Right(image) and Right(metadata) into Right(pair).

    If either input is Left, the result is Left of an error context carrying
    the reason and (when available) the metadata.

    Reference: haskell.org/tutorial/monads.html
    """

    def context(value):
        # `bottom` raises if metadata itself is Left; by construction the
        # caller passes Right(metadatum) here.
        return {"reason": value, "metadata": metadata.either(bottom, identity)}

    # Explicitly we are doing the following. (1) and (2) are equivalent.

    # a := Image
    # b := Metadata
    # c := ImageMetadataPair
    # m := Either monad

    # fmt: off
    # 1)
    # pair: Either = (
    #     Right(make_image_metadata_pair)  # m (a -> b -> c)
    #     .amap(image)                     # m (a -> b -> c) <*> m a = m (b -> c)
    #     .amap(metadata)                  # m (b -> c) <*> m b = m c
    # )

    # 2)
    # pair: Either = (
    #     image.bind(right(make_image_metadata_pair))  # m a >>= m (a -> b -> c) = m (b -> c)
    #     .amap(metadata)                              # m (b -> c) <*> m b = m c
    # )
    # fmt: on

    # Syntactic sugar variant with details hidden
    pair: Either = Either.apply(make_image_metadata_pair).to_arguments(image, metadata)

    # Re-wrap: Left values become error contexts, Right values stay as-is.
    return pair.either(left(context), right(identity))
|
||||
|
||||
|
||||
@curry(2)
def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair:
    """Curried ImageMetadataPair constructor (enables applicative use over Either)."""
    return ImageMetadataPair(image, metadatum)
|
||||
|
||||
|
||||
def extract_image(doc: Doc, xref: int) -> Any:
    """Extract the pixmap behind *xref* and convert it to a PIL image.

    Raises whatever extract_pixmap raises for invalid cross references.
    """
    pixmap = extract_pixmap(doc, xref)
    return pixmap_to_image(pixmap)
|
||||
|
||||
|
||||
def pixmap_to_image(pixmap: Pxm) -> Img:
    """Convert a fitz pixmap to a PIL image via a numpy view of its raw samples."""
    # samples is a flat byte buffer laid out (height, width, channels).
    array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape((pixmap.h, pixmap.w, pixmap.n))
    array = normalize_channels(array)
    return Image.fromarray(array)
|
||||
|
||||
|
||||
def extract_pixmap(doc: Doc, xref: int) -> Pxm:
    """Load the pixmap for *xref*, translating fitz's ValueError into BadXref.

    Raises:
        BadXref: if *xref* is not a valid cross reference in *doc*.
    """
    try:
        return fitz.Pixmap(doc, xref)
    except ValueError as err:
        msg = f"Cross reference {xref} is invalid, skipping extraction."
        # The raw fitz error goes to the error log; the friendly message to debug.
        logger.error(err)
        logger.debug(msg)
        raise BadXref(msg) from err
|
||||
|
||||
|
||||
def has_alpha_channel(doc: Doc, xref: int) -> bool:
    """Decide whether the image behind *xref* carries transparency.

    Three Either-wrapped probes are tried; the image counts as transparent if
    any of them yields True: the soft mask handle exists, the soft mask's
    pixmap reports alpha, or the image's own pixmap reports alpha. Failed
    probes (Left values) are logged and treated as False.
    """

    # Either-returning wrappers partially applied to *doc*:
    # Right(result) on success, Left(error) otherwise.
    _get_image_handle = wrap_right(get_image_handle, success_condition=notnone)(doc)
    _extract_pixmap = wrap_right(extract_pixmap)(doc)

    def get_soft_mask_reference(cross_reference: int) -> Either:
        def error(value) -> str:
            return f"Invalid soft mask {value} for cross reference {cross_reference}."

        logger.debug(f"Getting soft mask handle for cross reference {cross_reference}.")
        # A None smask entry is converted into a Left with a descriptive error.
        pass_on_if_not_none = iffy(notnone, right(identity), left(error))
        return _get_image_handle(cross_reference).then(itemgetter("smask")).either(left(identity), pass_on_if_not_none)

    def mask_exists(soft_mask_reference: int) -> Either:
        logger.debug(f"Checking if soft mask exists for soft mask reference {soft_mask_reference}.")
        return _get_image_handle(soft_mask_reference).then(notnone)

    def image_has_alpha_channel(reference: int) -> Either:
        logger.debug(f"Checking if image with reference {reference} has alpha channel.")
        return _extract_pixmap(reference).then(attrgetter("alpha")).then(bool)

    logger.debug(f"Checking if image with cross reference {xref} has alpha channel.")

    cross_reference = Right(xref)
    soft_mask_reference = cross_reference.bind(get_soft_mask_reference)

    # any() short-circuits: later (more expensive) probes only run when needed.
    return any(
        take_good_log_bad(reference.bind(check))
        for reference, check in [
            (soft_mask_reference, mask_exists),
            (soft_mask_reference, image_has_alpha_channel),
            (cross_reference, image_has_alpha_channel),
        ]
    )
|
||||
|
||||
|
||||
def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]:
    """Yield only metadata whose image area is above the tiny-image threshold."""
    for metadatum in metadata:
        if not tiny(metadatum):
            yield metadatum
|
||||
|
||||
|
||||
def filter_out_invalid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
|
||||
def filter_out_invalid_metadata(metadata):
|
||||
def __validate_box(box):
|
||||
try:
|
||||
return validate_box(box)
|
||||
@ -255,7 +128,50 @@ def filter_out_invalid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
|
||||
yield from keep(__validate_box, metadata)
|
||||
|
||||
|
||||
def get_image_metadata(image_info: dict) -> dict:
|
||||
def filter_out_page_sized_images(metadata):
    """Yield metadata whose image does not breach the image-to-page area quotient.

    Intended to drop page-sized images such as full-page scans (see the
    disabled usage in filter_valid_metadata).
    """
    yield from remove(breaches_image_to_page_quotient, metadata)
|
||||
|
||||
|
||||
def filter_out_tiny_images(metadata):
    """Yield only metadata whose image area is above the tiny-image threshold (see tiny)."""
    yield from filterfalse(tiny, metadata)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
def get_image_infos(page: fitz.Page) -> List[dict]:
    """Cached wrapper around Page.get_image_info(xrefs=True); flushed by clear_caches()."""
    return page.get_image_info(xrefs=True)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
def xref_to_image(doc, xref) -> Union[Image.Image, None]:
    """Extract the image behind *xref* from *doc*, or None if the xref is invalid.

    Cached per (doc, xref); the cache is flushed by clear_caches().
    """
    # NOTE: image extraction is done via pixmap to array, as this method is twice as fast as extraction via bytestream
    try:
        pixmap = fitz.Pixmap(doc, xref)
        array = convert_pixmap_to_array(pixmap)
        return Image.fromarray(array)
    except ValueError:
        logger.debug(f"Xref {xref} is invalid, skipping extraction ...")
        # Explicit None: the bare `return` hid the annotated None result.
        return None
|
||||
|
||||
|
||||
def convert_pixmap_to_array(pixmap: fitz.fitz.Pixmap):
    """Turn a fitz pixmap into a channel-normalized uint8 numpy array."""
    # samples is a flat byte buffer laid out (height, width, channels).
    array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape(pixmap.h, pixmap.w, pixmap.n)
    array = _normalize_channels(array)
    return array
|
||||
|
||||
|
||||
def _normalize_channels(array: np.ndarray):
|
||||
if array.shape[-1] == 1:
|
||||
array = array[:, :, 0]
|
||||
elif array.shape[-1] == 4:
|
||||
array = array[..., :3]
|
||||
elif array.shape[-1] != 3:
|
||||
logger.warning(f"Unexpected image format: {array.shape}.")
|
||||
raise ValueError(f"Unexpected image format: {array.shape}.")
|
||||
|
||||
return array
|
||||
|
||||
|
||||
def get_image_metadata(image_info):
|
||||
|
||||
xref, coords = itemgetter("xref", "bbox")(image_info)
|
||||
x1, y1, x2, y2 = map(rounder, coords)
|
||||
@ -274,36 +190,30 @@ def get_image_metadata(image_info: dict) -> dict:
|
||||
}
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_image_infos(page: Pge) -> List[dict]:
|
||||
return page.get_image_info(xrefs=True)
|
||||
|
||||
|
||||
def add_page_metadata(page: Pge, metadata: Iterable[dict]) -> Iterable[dict]:
|
||||
def add_page_metadata(page, metadata):
    """Yield each metadatum merged on top of the page-level metadata.

    Per funcy.merge semantics, keys from the individual metadatum win over
    the page-level keys.
    """
    page_metadata = get_page_metadata(page)
    for metadatum in metadata:
        yield merge(page_metadata, metadatum)
|
||||
|
||||
|
||||
def normalize_channels(array: np.ndarray):
|
||||
if not array.ndim == 3:
|
||||
array = np.expand_dims(array, axis=-1)
|
||||
def add_alpha_channel_info(doc, metadata):
|
||||
def add_alpha_value_to_metadatum(metadatum):
|
||||
alpha = metadatum_to_alpha_value(metadatum)
|
||||
return {**metadatum, Info.ALPHA: alpha}
|
||||
|
||||
if array.shape[-1] == 4:
|
||||
array = array[..., :3]
|
||||
elif array.shape[-1] == 1:
|
||||
array = np.concatenate([array, array, array], axis=-1)
|
||||
elif array.shape[-1] != 3:
|
||||
logger.warning(f"Unexpected image format: {array.shape}.")
|
||||
raise ValueError(f"Unexpected image format: {array.shape}.")
|
||||
xref_to_alpha = partial(has_alpha_channel, doc)
|
||||
metadatum_to_alpha_value = compose(xref_to_alpha, itemgetter(Info.XREF))
|
||||
|
||||
return array
|
||||
yield from map(add_alpha_value_to_metadatum, metadata)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_image_handle(doc: Doc, xref: int) -> Union[dict, None]:
|
||||
def load_image_handle_from_xref(doc, xref):
    """Return fitz's raw image-handle dict for *xref* (includes the `smask` reference)."""
    return doc.extract_image(xref)
|
||||
|
||||
|
||||
def get_page_metadata(page: Pge) -> dict:
|
||||
rounder = rcompose(round, int)
|
||||
|
||||
|
||||
def get_page_metadata(page):
|
||||
page_width, page_height = map(rounder, page.mediabox_size)
|
||||
|
||||
return {
|
||||
@ -313,17 +223,38 @@ def get_page_metadata(page: Pge) -> dict:
|
||||
}
|
||||
|
||||
|
||||
rounder = rcompose(round, int)
|
||||
def has_alpha_channel(doc, xref):
    """Heuristically decide whether the image behind *xref* carries transparency."""

    maybe_image = load_image_handle_from_xref(doc, xref)
    # `smask` is the soft-mask xref; a falsy value means no mask is attached.
    maybe_smask = maybe_image["smask"] if maybe_image else None

    if maybe_smask:
        # NOTE(review): `extract_image(...) is not None` looks always-true for
        # a dict result, and fitz.Pixmap may raise ValueError here uncaught —
        # confirm whether the soft-mask branch needs the same guard as below.
        return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)])
    else:
        try:
            return bool(fitz.Pixmap(doc, xref).alpha)
        except ValueError:
            # Invalid xref: treat as opaque rather than failing the pipeline.
            logger.debug(f"Encountered invalid xref `{xref}` in {doc.metadata.get('title', '<no title>')}.")
            return False
|
||||
|
||||
|
||||
def tiny(metadatum: dict) -> bool:
|
||||
return metadatum[Info.WIDTH] * metadatum[Info.HEIGHT] <= 4
|
||||
def tiny(metadata):
    """True if the image area (width * height) is at most 4 pixels."""
    return metadata[Info.WIDTH] * metadata[Info.HEIGHT] <= 4
|
||||
|
||||
|
||||
def clear_caches() -> None:
|
||||
def clear_caches():
|
||||
get_image_infos.cache_clear()
|
||||
get_image_handle.cache_clear()
|
||||
eith_extract_image.cache_clear()
|
||||
load_image_handle_from_xref.cache_clear()
|
||||
xref_to_image.cache_clear()
|
||||
|
||||
|
||||
atexit.register(clear_caches)
|
||||
|
||||
|
||||
def breaches_image_to_page_quotient(metadatum):
    """True if the image's geometric quotient w.r.t. its page exceeds the configured maximum.

    Used by filter_out_page_sized_images to flag page-sized images.
    """
    # Info.WIDTH / Info.HEIGHT were unpacked but never used — removed.
    page_width, page_height, x1, x2, y1, y2 = itemgetter(
        Info.PAGE_WIDTH, Info.PAGE_HEIGHT, Info.X1, Info.X2, Info.Y1, Info.Y2
    )(metadatum)
    # NOTE(review): argument order (x2, x1, y2, y1) mirrors the helper's
    # signature — confirm against compute_geometric_quotient.
    geometric_quotient = compute_geometric_quotient(page_width, page_height, x2, x1, y2, y1)
    quotient_breached = bool(geometric_quotient > CONFIG.filters.image_to_page_quotient.max)
    return quotient_breached
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import random
|
||||
from operator import itemgetter
|
||||
|
||||
@ -8,23 +7,18 @@ import pytest
|
||||
from PIL import Image
|
||||
from funcy import first, rest
|
||||
|
||||
from image_prediction.exceptions import BadXref
|
||||
from image_prediction.extraction import extract_images_from_pdf
|
||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.image_extractor.extractors.parsable import (
|
||||
extract_pages,
|
||||
has_alpha_channel,
|
||||
get_image_infos,
|
||||
ParsablePDFImageExtractor,
|
||||
extract_valid_metadata,
|
||||
eith_extract_image,
|
||||
extract_image,
|
||||
xref_to_image,
|
||||
)
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.locations import TEST_DATA_DIR
|
||||
from test.utils.comparison import metadata_equal, image_sets_equal
|
||||
from test.utils.generation.pdf import add_image, pdf_stream, stream_pdf_bytes
|
||||
from test.utils.generation.pdf import add_image, pdf_stream
|
||||
|
||||
|
||||
@pytest.mark.parametrize("extractor_type", ["mock"])
|
||||
@ -95,7 +89,4 @@ def test_bad_xref_handling(bad_xref_pdf, dvc_test_data):
|
||||
metadata = extract_valid_metadata(doc, first(doc))
|
||||
xref = first(metadata)[Info.XREF]
|
||||
|
||||
with pytest.raises(BadXref):
|
||||
extract_image(doc, xref)
|
||||
|
||||
assert eith_extract_image(doc, xref).is_left()
|
||||
assert not xref_to_image(doc, xref)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user