Pull request #39: RED-6084 Improve image extraction speed

Merge in RR/image-prediction from RED-6084-adhoc-scanned-pages-filtering-refactoring to master

Squashed commit of the following:

commit bd6d83e7363b1c1993babcceb434110a6312c645
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Thu Feb 9 16:08:25 2023 +0100

    Tweak logging

commit 55bdd48d2a3462a8b4a6b7194c4a46b21d74c455
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Thu Feb 9 15:47:31 2023 +0100

    Update dependencies

commit 970275b25708c05e4fbe78b52aa70d791d5ff17a
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Thu Feb 9 15:35:37 2023 +0100

    Refactoring

    Make alpha channel check monadic to streamline error handling
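
    A minimal sketch of the chaining this enables, assuming pymonad's
    Either as pinned in requirements (PyMonad==2.4.0); `Pixmap` is a
    hypothetical stand-in for a fitz pixmap:

        from operator import attrgetter
        from pymonad.either import Left, Right

        class Pixmap:
            alpha = 1  # non-zero alpha component

        # `then` maps a plain function over a Right and passes a Left
        # through untouched, so the failure reason survives the chain:
        ok = Right(Pixmap()).then(attrgetter("alpha")).then(bool)
        bad = Left("Bad xref.").then(attrgetter("alpha")).then(bool)
        print(ok.either(lambda err: err, lambda val: val))   # True
        print(bad.either(lambda err: err, lambda val: val))  # Bad xref.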

commit e99e97e23fd8ce16f9a421d3e5442fccacf71ead
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Tue Feb 7 14:32:29 2023 +0100

    Refactoring

    - Rename
    - Refactor image extraction functions

commit 76b1b0ca2401495ec03ba2b6483091b52732eb81
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Tue Feb 7 11:55:30 2023 +0100

    Refactoring

commit cb1c461049d7c43ec340302f466447da9f95a499
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Tue Feb 7 11:44:01 2023 +0100

    Refactoring

commit 092069221a85ac7ac19bf838dcbc7ab1fde1e12b
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Tue Feb 7 10:18:53 2023 +0100

    Add to-do

commit 3cea4dad2d9703b8c79ddeb740b66a3b8255bb2a
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Tue Feb 7 10:11:35 2023 +0100

    Refactoring

    - Rename
    - Add typehints everywhere

commit 865e0819a14c420bc2edff454d41092c11c019a4
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Mon Feb 6 19:38:57 2023 +0100

    Add type explanation

commit 01d3d5d33f1ccb05aea1cec1d1577572b1a4deaa
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Mon Feb 6 19:37:49 2023 +0100

    Formatting

commit dffe1c18fc3a322a6b08890d4438844e8122faaf
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Mon Feb 6 19:34:13 2023 +0100

    [WIP] Either refactoring

    Add alternative formulation for monadic chain

commit 066cf17add404a313520cd794c06e3264cf971c9
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Mon Feb 6 18:40:30 2023 +0100

    [WIP] Either refactoring

commit f53f0fea298cdab88deb090af328b34d37e0198e
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Mon Feb 6 18:18:34 2023 +0100

    [WIP] Either refactoring

    Propagate error and metadata

commit 274a5f56d4fcb9c67fac5cf43e9412ec1ab5179e
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Mon Feb 6 17:51:35 2023 +0100

    [WIP] Either refactoring

    Fix test assertion

commit 3235a857f6e418e50484cbfff152b0f63efb2f53
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Mon Feb 6 16:57:31 2023 +0100

    [WIP] Either refactoring

    Replace Maybe with Either to allow passing on error information or
    metadata that would otherwise be swallowed by Nothing.
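
    In a nutshell, a minimal sketch assuming pymonad's Either as pinned
    in requirements (PyMonad==2.4.0); `load` and `parse` are hypothetical
    stand-ins for the extraction steps:

        from pymonad.either import Either, Left, Right

        def load(xref: int) -> Either:
            # The failure reason travels inside the Left instead of
            # collapsing to Nothing.
            return Right(f"image-{xref}") if xref % 2 == 0 else Left(f"Bad xref {xref}.")

        def parse(image: str) -> Either:
            return Right(image.upper())

        ok = load(2).bind(parse)   # Right "IMAGE-2"
        bad = load(3).bind(parse)  # Left "Bad xref 3.", parse is skipped
        print(bad.either(lambda err: f"skipped: {err}", lambda img: f"kept: {img}"))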

commit 89989543d87490f8b20a0a76055605d34345e8f4
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Mon Feb 6 16:12:40 2023 +0100

    [WIP] Monadic refactoring

    Integrate image validation step into monadic chain.

    At the moment we lose the error information through this step.
    Refactoring to the Either monad can bring it back.

commit 022bd4856a51aa085df5fe983fd77b99b53d594c
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Mon Feb 6 15:16:41 2023 +0100

    [WIP] Monadic refactoring

commit ca3898cb539607c8c3dd01c57e60211a5fea8a7d
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Mon Feb 6 15:10:34 2023 +0100

    [WIP] Monadic refactoring

commit d8f37bed5cbd6bdd2a0b52bae46fcdbb50f9dff2
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Mon Feb 6 15:09:51 2023 +0100

    [WIP] Monadic refactoring

commit 906fee0e5df051f38076aa1d2725e52a182ade13
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date:   Mon Feb 6 15:03:35 2023 +0100

    [WIP] Monadic refactoring

... and 35 more commits
This commit is contained in:
Matthias Bisping 2023-02-10 08:33:13 +01:00 committed by Julius Unverfehrt
parent 25fc7d84b9
commit 5cdf93b923
25 changed files with 395 additions and 324 deletions

View File

@@ -1,6 +1,6 @@
webserver:
host: $SERVER_HOST|"127.0.0.1" # webserver address
port: $SERVER_PORT|5000 # webserver port
host: $SERVER_HOST|"127.0.0.1" # Webserver address
port: $SERVER_PORT|5000 # Webserver port
service:
logging_level: $LOGGING_LEVEL_ROOT|INFO # Logging level for service logger

View File

@@ -1,5 +1,3 @@
from typing import Iterable
from funcy import juxt
from image_prediction.classifier.classifier import Classifier
@@ -7,7 +5,6 @@ from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.compositor.compositor import TransformerCompositor
from image_prediction.encoder.encoders.hash_encoder import HashEncoder
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.formatter.formatter import format_image_plus
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
from image_prediction.formatter.formatters.enum import EnumFormatter
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
@@ -17,7 +14,6 @@ from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.redai_adapter.mlflow import MlflowModelReader
from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer
from image_prediction.transformer.transformers.response import ResponseTransformer
from pdf2img.extraction import extract_images_via_metadata
def get_mlflow_model_loader(mlruns_dir):
@@ -30,17 +26,10 @@ def get_image_classifier(model_loader, model_identifier):
return ImageClassifier(Classifier(EstimatorAdapter(model), ProbabilityMapper(classes)))
def get_dispatched_extract(**kwargs):
def get_extractor(**kwargs):
image_extractor = ParsablePDFImageExtractor(**kwargs)
def extract(pdf: bytes, page_range: range = None, metadata_per_image: Iterable[dict] = None):
if metadata_per_image:
image_pluses = extract_images_via_metadata(pdf, metadata_per_image)
yield from map(format_image_plus, image_pluses)
else:
yield from image_extractor.extract(pdf, page_range)
return extract
return image_extractor
def get_formatter():

View File

@@ -36,3 +36,7 @@ class InvalidBox(Exception):
class ParsingError(Exception):
pass
class BadXref(ValueError):
pass

View File

@@ -1,10 +1,6 @@
import abc
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.transformer.transformer import Transformer
from pdf2img.default_objects.image import ImagePlus
class Formatter(Transformer):
@@ -17,19 +13,3 @@ class Formatter(Transformer):
def __call__(self, obj):
return self.format(obj)
def format_image_plus(image: ImagePlus) -> ImageMetadataPair:
enum_metadata = {
Info.PAGE_WIDTH: image.info.pageInfo.width,
Info.PAGE_HEIGHT: image.info.pageInfo.height,
Info.PAGE_IDX: image.info.pageInfo.number,
Info.ALPHA: image.info.alpha,
Info.WIDTH: image.info.boundingBox.width,
Info.HEIGHT: image.info.boundingBox.height,
Info.X1: image.info.boundingBox.x0,
Info.X2: image.info.boundingBox.x1,
Info.Y1: image.info.boundingBox.y0,
Info.Y2: image.info.boundingBox.y1,
}
return ImageMetadataPair(image.aspil(), enum_metadata)

View File

@@ -1,26 +1,32 @@
import atexit
import io
import json
import traceback
from functools import partial, lru_cache
from itertools import chain, starmap, filterfalse
from operator import itemgetter, truth
from typing import List, Iterable, Iterator
from itertools import chain, filterfalse
from operator import itemgetter, attrgetter
from typing import List, Union, Iterable, Any
import fitz
import numpy as np
from PIL import Image
from funcy import rcompose, merge, pluck, curry, compose
from funcy import merge, compose, rcompose, keep, lkeep, notnone, iffy, rpartial
from pymonad.either import Right, Left, Either
from pymonad.tools import curry, identity
from image_prediction.exceptions import InvalidBox, BadXref
from image_prediction.formatter.formatters.enum import EnumFormatter
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.stitching import stitch_pairs
from image_prediction.stitching.utils import validate_box_coords, validate_box_size
from image_prediction.stitching.utils import validate_box
from image_prediction.utils import get_logger
from image_prediction.utils.generic import lift
from image_prediction.utils.generic import bottom, left, right, lift, wrap_right
logger = get_logger()
Doc = fitz.fitz.Document
Pge = fitz.fitz.Page
Img = Image.Image
Pxm = fitz.fitz.Pixmap
class ParsablePDFImageExtractor(ImageExtractor):
def __init__(self, verbose=False, tolerance=0):
@@ -31,7 +37,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
tolerance: The tolerance in pixels for the distance between images, beyond which they will not be stitched
together
"""
self.doc: fitz.fitz.Document = None
self.doc: Union[Doc, None] = None
self.verbose = verbose
self.tolerance = tolerance
@@ -44,80 +50,215 @@ class ParsablePDFImageExtractor(ImageExtractor):
yield from image_metadata_pairs
def __process_images_on_page(self, page: fitz.fitz.Page):
images = get_images_on_page(self.doc, page)
metadata = get_metadata_for_images_on_page(self.doc, page)
def __process_images_on_page(self, page: Pge):
metadata = extract_valid_metadata(self.doc, page)
either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata)
valid_image_metadata_pairs = lkeep(
rpartial(take_good_log_bad, format_context), either_image_metadata_pair_or_error_per_image
)
valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance)
clear_caches()
image_metadata_pairs = starmap(ImageMetadataPair, filter(all, zip(images, metadata)))
# TODO: In the future, consider to introduce an image validator as a pipeline component rather than doing the
# validation here. Invalid images can then be split into a different stream and joined with the intact images
# again for the formatting step.
image_metadata_pairs = self.__filter_valid_images(image_metadata_pairs)
image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance)
yield from valid_image_metadata_pairs_stitched
yield from image_metadata_pairs
@staticmethod
def __filter_valid_images(image_metadata_pairs: Iterable[ImageMetadataPair]) -> Iterator[ImageMetadataPair]:
def validate(image: Image.Image, metadata: dict):
try:
# TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148)
image.resize((100, 100)).convert("RGB")
return ImageMetadataPair(image, metadata)
except (OSError, Exception) as err:
metadata = json.dumps(EnumFormatter()(metadata), indent=2)
logger.warning(f"Invalid image encountered. Image metadata:\n{metadata}\n\n{traceback.format_exc()}")
return None
return filter(truth, starmap(validate, image_metadata_pairs))
def __metadatum_to_image_metadata_pair(self, metadatum: dict) -> Either:
return metadatum_to_image_metadata_pair(self.doc, metadatum)
def extract_pages(doc, page_range):
def extract_pages(doc: Doc, page_range: range) -> Iterable[Pge]:
page_range = range(page_range.start + 1, page_range.stop + 1)
pages = map(doc.load_page, page_range)
yield from pages
@lru_cache(maxsize=None)
def get_images_on_page(doc, page: fitz.Page):
image_infos = get_image_infos(page)
xrefs = map(itemgetter("xref"), image_infos)
images = map(partial(xref_to_image, doc), xrefs)
yield from images
def validate_image(image: Img) -> Either:
try:
# TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148)
image.resize((100, 100)).convert("RGB")
return Right(image)
except (OSError, Exception):
logger.warning(f"Invalid image encountered.")
return Left("Invalid image.")
def get_metadata_for_images_on_page(doc, page: fitz.Page):
def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]:
return compose(
list,
partial(add_alpha_channel_info, doc),
filter_valid_metadata,
get_metadata_for_images_on_page,
)(page)
metadata = map(get_image_metadata, get_image_infos(page))
metadata = validate_coords_and_passthrough(metadata)
metadata = filter_out_tiny_images(metadata)
metadata = validate_size_and_passthrough(metadata)
def take_good_log_bad(item: Either, log_formatter=identity) -> Any:
return item.either(rpartial(log_error_context, log_formatter), identity)
metadata = add_page_metadata(page, metadata)
metadata = add_alpha_channel_info(doc, page, metadata)
def format_context(context: dict) -> str:
return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}"
def log_error_context(context: dict, formatter=identity) -> None:
logger.warning(f"Skipping bad image. {formatter(context)}")
return None
def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either:
image: Either = eith_extract_image(doc, metadatum[Info.XREF]).bind(validate_image)
image_metadata_pair: Either = eith_make_image_metadata_pair(image, Right(metadatum))
return image_metadata_pair
def add_alpha_channel_info(doc: Doc, metadata: Iterable[dict]) -> Iterable[dict]:
def add_alpha_value_to_metadatum(metadatum: dict) -> dict:
alpha = metadatum_to_alpha_value(metadatum)
return {**metadatum, Info.ALPHA: alpha}
xref_to_alpha = partial(has_alpha_channel, doc)
metadatum_to_alpha_value = compose(xref_to_alpha, itemgetter(Info.XREF))
yield from map(add_alpha_value_to_metadatum, metadata)
def filter_valid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
yield from compose(filter_out_tiny_images, filter_out_invalid_metadata)(metadata)
def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]:
metadata = compose(
partial(add_page_metadata, page),
lift(get_image_metadata),
get_image_infos,
)(page)
yield from metadata
@lru_cache(maxsize=None)
def get_image_infos(page: fitz.Page) -> List[dict]:
return page.get_image_info(xrefs=True)
@curry(2)
def eith_extract_image(doc: Doc, xref: int) -> Either:
try:
return Right(extract_image(doc, xref))
except BadXref:
return Left("Bad xref.")
@lru_cache(maxsize=None)
def xref_to_image(doc, xref) -> Image:
maybe_image = load_image_handle_from_xref(doc, xref)
return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None
def eith_make_image_metadata_pair(image: Either, metadata: Either) -> Either:
"""Reference: haskell.org/tutorial/monads.html"""
def context(value):
return {"reason": value, "metadata": metadata.either(bottom, identity)}
# Explicitly we are doing the following. (1) and (2) are equivalent.
# a := Image
# b := Metadata
# c := ImageMetadataPair
# m := Either monad
# fmt: off
# 1)
# pair: Either = (
# Right(make_image_metadata_pair) # m (a -> b -> c)
# .amap(image) # m (a -> b -> c) <*> m a = m (b -> c)
# .amap(metadata) # m (b -> c) <*> m b = m c
# )
# 2)
# pair: Either = (
# image.bind(right(make_image_metadata_pair)) # m a >>= m (a -> b -> c) = m (b -> c)
# .amap(metadata) # m (b -> c) <*> m b = m c
# )
# fmt: on
# Syntactic sugar variant with details hidden
pair: Either = Either.apply(make_image_metadata_pair).to_arguments(image, metadata)
return pair.either(left(context), right(identity))
def get_image_metadata(image_info):
@curry(2)
def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair:
return ImageMetadataPair(image, metadatum)
x1, y1, x2, y2 = map(rounder, image_info["bbox"])
def extract_image(doc: Doc, xref: int) -> Any:
return compose(pixmap_to_image, extract_pixmap)(doc, xref)
def pixmap_to_image(pixmap: Pxm) -> Img:
array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape((pixmap.h, pixmap.w, pixmap.n))
array = normalize_channels(array)
return Image.fromarray(array)
def extract_pixmap(doc: Doc, xref: int) -> Pxm:
try:
return fitz.Pixmap(doc, xref)
except ValueError as err:
msg = f"Cross reference {xref} is invalid, skipping extraction."
logger.error(err)
logger.trace(msg)
raise BadXref(msg) from err
def has_alpha_channel(doc: Doc, xref: int) -> bool:
_get_image_handle = wrap_right(get_image_handle, success_condition=notnone)(doc)
_extract_pixmap = wrap_right(extract_pixmap)(doc)
def get_soft_mask_reference(cross_reference: int) -> Either:
def error(value) -> str:
return f"Invalid soft mask {value} for cross reference {cross_reference}."
logger.trace(f"Getting soft mask handle for cross reference {cross_reference}.")
pass_on_if_not_none = iffy(notnone, right(identity), left(error))
return _get_image_handle(cross_reference).then(itemgetter("smask")).either(left(identity), pass_on_if_not_none)
def mask_exists(soft_mask_reference: int) -> Either:
logger.trace(f"Checking if soft mask exists for soft mask reference {soft_mask_reference}.")
return _get_image_handle(soft_mask_reference).then(notnone)
def image_has_alpha_channel(reference: int) -> Either:
logger.trace(f"Checking if image with reference {reference} has alpha channel.")
return _extract_pixmap(reference).then(attrgetter("alpha")).then(bool)
logger.debug(f"Checking if image with cross reference {xref} has alpha channel.")
cross_reference = Right(xref)
soft_mask_reference = cross_reference.bind(get_soft_mask_reference)
return any(
take_good_log_bad(reference.bind(check))
for reference, check in [
(soft_mask_reference, mask_exists),
(soft_mask_reference, image_has_alpha_channel),
(cross_reference, image_has_alpha_channel),
]
)
def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]:
yield from filterfalse(tiny, metadata)
def filter_out_invalid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
def __validate_box(box):
try:
return validate_box(box)
except InvalidBox as err:
logger.debug(f"Dropping invalid metadatum, reason: {err}")
yield from keep(__validate_box, metadata)
def get_image_metadata(image_info: dict) -> dict:
xref, coords = itemgetter("xref", "bbox")(image_info)
x1, y1, x2, y2 = map(rounder, coords)
width = abs(x2 - x1)
height = abs(y2 - y1)
@@ -129,47 +270,40 @@ def get_image_metadata(image_info):
Info.X2: x2,
Info.Y1: y1,
Info.Y2: y2,
Info.XREF: xref,
}
def validate_coords_and_passthrough(metadata):
yield from map(validate_box_coords, metadata)
@lru_cache(maxsize=None)
def get_image_infos(page: Pge) -> List[dict]:
return page.get_image_info(xrefs=True)
def filter_out_tiny_images(metadata):
yield from filterfalse(tiny, metadata)
def validate_size_and_passthrough(metadata):
yield from map(validate_box_size, metadata)
def add_page_metadata(page, metadata):
def add_page_metadata(page: Pge, metadata: Iterable[dict]) -> Iterable[dict]:
yield from map(partial(merge, get_page_metadata(page)), metadata)
def add_alpha_channel_info(doc, page, metadata):
def normalize_channels(array: np.ndarray):
if not array.ndim == 3:
array = np.expand_dims(array, axis=-1)
page_to_xrefs = compose(curry(pluck)("xref"), get_image_infos)
xref_to_alpha = partial(has_alpha_channel, doc)
page_to_alpha_value_per_image = compose(lift(xref_to_alpha), page_to_xrefs)
alpha_to_dict = compose(dict, lambda a: [(Info.ALPHA, a)])
page_to_alpha_mapping_per_image = compose(lift(alpha_to_dict), page_to_alpha_value_per_image)
if array.shape[-1] == 4:
array = array[..., :3]
elif array.shape[-1] == 1:
array = np.concatenate([array, array, array], axis=-1)
elif array.shape[-1] != 3:
logger.warning(f"Unexpected image format: {array.shape}.")
raise ValueError(f"Unexpected image format: {array.shape}.")
metadata = starmap(merge, zip(page_to_alpha_mapping_per_image(page), metadata))
yield from metadata
return array
@lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref):
def get_image_handle(doc: Doc, xref: int) -> Union[dict, None]:
return doc.extract_image(xref)
rounder = rcompose(round, int)
def get_page_metadata(page):
def get_page_metadata(page: Pge) -> dict:
page_width, page_height = map(rounder, page.mediabox_size)
return {
@@ -179,30 +313,17 @@ def get_page_metadata(page):
}
def has_alpha_channel(doc, xref):
maybe_image = load_image_handle_from_xref(doc, xref)
maybe_smask = maybe_image["smask"] if maybe_image else None
if maybe_smask:
return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)])
else:
try:
return bool(fitz.Pixmap(doc, xref).alpha)
except ValueError:
logger.debug(f"Encountered invalid xref `{xref}` in {doc.metadata.get('title', '<no title>')}.")
return False
rounder = rcompose(round, int)
def tiny(metadata):
return metadata[Info.WIDTH] * metadata[Info.HEIGHT] <= 4
def tiny(metadatum: dict) -> bool:
return metadatum[Info.WIDTH] * metadatum[Info.HEIGHT] <= 4
def clear_caches():
def clear_caches() -> None:
get_image_infos.cache_clear()
load_image_handle_from_xref.cache_clear()
get_images_on_page.cache_clear()
xref_to_image.cache_clear()
get_image_handle.cache_clear()
eith_extract_image.cache_clear()
atexit.register(clear_caches)
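
A condensed sketch of how the pair construction and filtering above behave,
assuming pymonad's Either (PyMonad==2.4.0); `pair` and the sample values are
hypothetical stand-ins for real images and metadata:

    from pymonad.either import Either, Left, Right
    from pymonad.tools import curry

    @curry(2)
    def pair(image, metadatum):
        return (image, metadatum)

    # Applicative style, as in eith_make_image_metadata_pair: a Left image
    # short-circuits, so no pair is built from a failed extraction.
    good = Either.apply(pair).to_arguments(Right("img"), Right({"xref": 7}))
    bad = Either.apply(pair).to_arguments(Left("Bad xref."), Right({"xref": 8}))

    # take_good_log_bad keeps Rights and maps Lefts to None (after logging);
    # lkeep then drops the Nones from the stream:
    kept = [e.either(lambda err: None, lambda val: val) for e in (good, bad)]
    print([k for k in kept if k is not None])  # [('img', {'xref': 7})]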

View File

@@ -12,3 +12,4 @@ class Info(Enum):
Y1 = "y1"
Y2 = "y2"
ALPHA = "alpha"
XREF = "xref"

View File

@@ -3,15 +3,14 @@
from pathlib import Path
MODULE_DIR = Path(__file__).resolve().parents[0]
PACKAGE_ROOT_DIR = MODULE_DIR.parents[0]
CONFIG_FILE = PACKAGE_ROOT_DIR / "config.yaml"
BANNER_FILE = PACKAGE_ROOT_DIR / "banner.txt"
DATA_DIR = PACKAGE_ROOT_DIR / "data"
MLRUNS_DIR = str(DATA_DIR / "mlruns")
TEST_DATA_DIR = PACKAGE_ROOT_DIR / "test" / "data"
TEST_DIR = PACKAGE_ROOT_DIR / "test"
TEST_DATA_DIR = TEST_DIR / "data"
TEST_DATA_DIR_DVC = TEST_DIR / "data.dvc"

View File

@@ -11,8 +11,8 @@ from image_prediction.default_objects import (
get_formatter,
get_mlflow_model_loader,
get_image_classifier,
get_extractor,
get_encoder,
get_dispatched_extract,
)
from image_prediction.locations import MLRUNS_DIR
from image_prediction.utils.generic import lift, starlift
@@ -41,7 +41,7 @@ class Pipeline:
def __init__(self, model_loader, model_identifier, batch_size=16, verbose=True, **kwargs):
self.verbose = verbose
extract = get_dispatched_extract(**kwargs)
extract = get_extractor(**kwargs)
classifier = get_image_classifier(model_loader, model_identifier)
reformat = get_formatter()
represent = get_encoder()
@@ -63,9 +63,9 @@ class Pipeline:
reformat, # ... the items
)
def __call__(self, pdf: bytes, page_range: range = None, metadata_per_image: Iterable[dict] = None):
def __call__(self, pdf: bytes, page_range: range = None):
yield from tqdm(
self.pipe(pdf, page_range=page_range, metadata_per_image=metadata_per_image),
self.pipe(pdf, page_range=page_range),
desc="Processing images from document",
unit=" images",
disable=not self.verbose,

View File

@@ -21,11 +21,6 @@ class ResponseTransformer(Transformer):
def build_image_info(data: dict) -> dict:
def compute_geometric_quotient():
page_area_sqrt = math.sqrt(abs(page_width * page_height))
image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1))
return image_area_sqrt / page_area_sqrt
page_width, page_height, x1, x2, y1, y2, width, height, alpha = itemgetter(
"page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height", "alpha"
)(data)
@@ -34,7 +29,7 @@ def build_image_info(data: dict) -> dict:
label = classification["label"]
representation = data["representation"]
geometric_quotient = round(compute_geometric_quotient(), 4)
geometric_quotient = round(compute_geometric_quotient(page_width, page_height, x2, x1, y2, y1), 4)
min_image_to_page_quotient_breached = bool(
geometric_quotient < get_class_specific_min_image_to_page_quotient(label)
@@ -89,6 +84,12 @@ def build_image_info(data: dict) -> dict:
return image_info
def compute_geometric_quotient(page_width, page_height, x2, x1, y2, y1):
page_area_sqrt = math.sqrt(abs(page_width * page_height))
image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1))
return image_area_sqrt / page_area_sqrt
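
As a sanity check on this quotient's scale: the deleted fixture further down
reports 0.2741 for the 194 x 194 image at (321, 348)-(515, 542); that is
consistent with an A4 page of 595 x 842 points (an assumption, since the
fixture does not record page size):

    import math

    page_area_sqrt = math.sqrt(abs(595 * 842))                    # ~707.8
    image_area_sqrt = math.sqrt(abs(515 - 321) * abs(542 - 348))  # 194.0
    print(round(image_area_sqrt / page_area_sqrt, 4))             # 0.2741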
def get_class_specific_min_image_to_page_quotient(label, table=None):
return get_class_specific_value(
"REL_IMAGE_SIZE", label, "min", CONFIG.filters.image_to_page_quotient.min, table=table

View File

@@ -1,6 +1,15 @@
from functools import wraps
from inspect import signature
from itertools import starmap
from typing import Callable
from funcy import iterate, first, curry, map
from pymonad.either import Left, Right, Either
from pymonad.tools import curry as pmcurry
from image_prediction.utils import get_logger
logger = get_logger()
def until(cond, func, *args, **kwargs):
@@ -13,3 +22,63 @@ def lift(fn):
def starlift(fn):
return curry(starmap)(fn)
def bottom(*args, **kwargs):
return False
def top(*args, **kwargs):
return True
def left(fn):
@wraps(fn)
def inner(x):
return Left(fn(x))
return inner
def right(fn):
@wraps(fn)
def inner(x):
return Right(fn(x))
return inner
def wrap_left(fn, success_condition=top, error_message=None) -> Callable:
return wrap_either(Left, Right, success_condition=success_condition, error_message=error_message)(fn)
def wrap_right(fn, success_condition=top, error_message=None) -> Callable:
return wrap_either(Right, Left, success_condition=success_condition, error_message=error_message)(fn)
def wrap_either(success_type, failure_type, success_condition=top, error_message=None) -> Callable:
@wraps(wrap_either)
def wrapper(fn) -> Callable:
n_params = len(signature(fn).parameters)
@pmcurry(n_params)
@wraps(fn)
def wrapper(*args, **kwargs) -> Either:
try:
result = fn(*args, **kwargs)
if success_condition(result):
return success_type(result)
else:
return failure_type({"error": error_message, "result": result})
except Exception as err:
logger.error(err)
return failure_type({"error": error_message or err, "result": Void})
return wrapper
return wrapper
class Void:
pass
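
A usage sketch for wrap_right, assuming the definitions above; `lookup` is a
hypothetical example function:

    from funcy import notnone

    def lookup(table: dict, key: str):
        return table.get(key)  # None on a miss

    # wrap_right curries over lookup's two parameters and converts the result
    # into an Either via the success condition, just as has_alpha_channel
    # does with get_image_handle:
    safe_lookup = wrap_right(lookup, success_condition=notnone, error_message="missing key")
    hit = safe_lookup({"a": 1})("a")   # Right 1
    miss = safe_lookup({"a": 1})("b")  # Left {'error': 'missing key', 'result': None}
    print(hit.either(lambda err: err, lambda val: val))   # 1
    print(miss.either(lambda err: err, lambda val: val))  # {'error': 'missing key', 'result': None}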

View File

@@ -23,3 +23,5 @@ pdf2image==1.16.0
frozendict==2.3.0
protobuf<=3.20.*
prometheus-client==0.13.1
fsspec==2022.11.0
PyMonad==2.4.0

View File

@@ -2,7 +2,6 @@ import argparse
import json
import os
from glob import glob
from operator import truth
from image_prediction.pipeline import load_pipeline
from image_prediction.utils import get_logger
@@ -15,7 +14,6 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("input", help="pdf file or directory")
parser.add_argument("--metadata", help="optional figure detection metadata")
parser.add_argument("--print", "-p", help="print output to terminal", action="store_true", default=False)
parser.add_argument("--page_interval", "-i", help="page interval [i, j), min index = 0", nargs=2, type=int)
@@ -24,17 +22,13 @@ def parse_args():
return args
def process_pdf(pipeline, pdf_path, metadata=None, page_range=None):
if metadata:
with open(metadata) as f:
metadata = json.load(f)
def process_pdf(pipeline, pdf_path, page_range=None):
with open(pdf_path, "rb") as f:
logger.info(f"Processing {pdf_path}")
predictions = list(pipeline(f.read(), page_range=page_range, metadata_per_image=metadata))
predictions = list(pipeline(f.read(), page_range=page_range))
annotate_pdf(
pdf_path, predictions, os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", f"_{truth(metadata)}_annotated.pdf")))
pdf_path, predictions, os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", "_annotated.pdf")))
)
return predictions
@@ -48,10 +42,9 @@ def main(args):
else:
pdf_paths = glob(os.path.join(args.input, "*.pdf"))
page_range = range(*args.page_interval) if args.page_interval else None
metadata = args.metadata if args.metadata else None
for pdf_path in pdf_paths:
predictions = process_pdf(pipeline, pdf_path, metadata, page_range=page_range)
predictions = process_pdf(pipeline, pdf_path, page_range=page_range)
if args.print:
print(pdf_path)
print(json.dumps(predictions, indent=2))

View File

@@ -1,5 +1,4 @@
import gzip
import io
import json
import logging
@@ -29,36 +28,28 @@ logger.setLevel(PYINFRA_CONFIG.logging_level_root)
def process_request(request_message):
dossier_id = request_message["dossierId"]
file_id = request_message["fileId"]
logger.info(f"Processing {dossier_id=} {file_id=} ...")
target_file_name = f"{dossier_id}/{file_id}.{request_message['targetFileExtension']}"
response_file_name = f"{dossier_id}/{file_id}.{request_message['responseFileExtension']}"
figure_data_file_name = f"{dossier_id}/{file_id}.FIGURE.json.gz"
bucket = PYINFRA_CONFIG.storage_bucket
storage = get_storage(PYINFRA_CONFIG)
pipeline = load_pipeline(verbose=IMAGE_CONFIG.service.verbose, batch_size=IMAGE_CONFIG.service.batch_size)
if storage.exists(bucket, target_file_name):
should_publish_result = True
if not storage.exists(bucket, target_file_name):
publish_result = False
else:
publish_result = True
object_bytes = storage.get_object(bucket, target_file_name)
object_bytes = gzip.decompress(object_bytes)
classifications = list(pipeline(pdf=object_bytes))
if storage.exists(bucket, figure_data_file_name):
metadata_bytes = storage.get_object(bucket, figure_data_file_name)
metadata_bytes = gzip.decompress(metadata_bytes)
metadata_per_image = json.load(io.BytesIO(metadata_bytes))["data"]
classifications_cv = list(pipeline(pdf=object_bytes, metadata_per_image=metadata_per_image))
else:
classifications_cv = []
result = {**request_message, "data": classifications, "dataCV": classifications_cv}
result = {**request_message, "data": classifications}
storage_bytes = gzip.compress(json.dumps(result).encode("utf-8"))
storage.put_object(bucket, response_file_name, storage_bytes)
else:
should_publish_result = False
return should_publish_result, {"dossierId": dossier_id, "fileId": file_id}
return publish_result, {"dossierId": dossier_id, "fileId": file_id}
def main():

test/.gitignore (vendored, new file)
View File

@@ -0,0 +1 @@
/data

test/data.dvc (new file)
View File

@@ -0,0 +1,5 @@
outs:
- md5: 4b0fec291ce0661b3efbbd8b80f4f514.dir
size: 107332
nfiles: 4
path: data

View File

@@ -1,44 +0,0 @@
[
{
"classification": {
"label": "formula",
"probabilities": {
"formula": 1.0,
"logo": 0.0,
"other": 0.0,
"signature": 0.0
}
},
"representation": "FFFEF0C7033648170F3EFFFFF",
"position": {
"x1": 321,
"x2": 515,
"y1": 348,
"y2": 542,
"pageNumber": 2
},
"geometry": {
"width": 194,
"height": 194
},
"alpha": false,
"filters": {
"geometry": {
"imageSize": {
"quotient": 0.2741,
"tooLarge": false,
"tooSmall": false
},
"imageFormat": {
"quotient": 1.0,
"tooTall": false,
"tooWide": false
}
},
"probability": {
"unconfident": false
},
"allPassed": true
}
}
]

View File

@@ -1,92 +0,0 @@
{
"input": [
{
"width": 100,
"height": 8,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 0,
"y1": 0,
"x2": 100,
"y2": 8
},
{
"width": 100,
"height": 9,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 0,
"y1": 9,
"x2": 100,
"y2": 18
},
{
"width": 100,
"height": 35,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 0,
"y1": 18,
"x2": 100,
"y2": 53
},
{
"width": 47,
"height": 46,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 0,
"y1": 54,
"x2": 47,
"y2": 100
},
{
"width": 31,
"height": 46,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 48,
"y1": 54,
"x2": 79,
"y2": 100
},
{
"width": 20,
"height": 19,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 80,
"y1": 54,
"x2": 100,
"y2": 73
},
{
"width": 20,
"height": 27,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 80,
"y1": 73,
"x2": 100,
"y2": 100
}
],
"target": {
"width": 100,
"height": 100,
"page_idx": 0,
"page_width": 100,
"page_height": 100,
"x1": 0,
"y1": 0,
"x2": 100,
"y2": 100
}
}

View File

@@ -1,7 +1,21 @@
import numpy as np
import pytest
from dvc.repo import Repo
from image_prediction.locations import PACKAGE_ROOT_DIR, TEST_DATA_DIR_DVC
from image_prediction.utils import get_logger
logger = get_logger()
@pytest.fixture
def input_batch(batch_size, input_size):
return np.random.random_sample(size=(batch_size, *input_size))
@pytest.fixture(scope="session")
def dvc_test_data():
logger.info("Pulling data with DVC...")
# noinspection PyCallingNonCallable
Repo(PACKAGE_ROOT_DIR).pull(targets=[str(TEST_DATA_DIR_DVC)])
logger.info("Finished pulling data.")

test/fixtures/pdf.py (vendored)
View File

@@ -4,7 +4,7 @@ import fpdf
import pytest
from image_prediction.locations import TEST_DATA_DIR
from test.utils.generation.pdf import add_image, pdf_stream
from test.utils.generation.pdf import add_image, pdf_stream, stream_pdf_bytes
@pytest.fixture
@@ -18,6 +18,10 @@ def pdf(image_metadata_pairs):
@pytest.fixture
def real_pdf():
with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf"), "rb") as f:
yield f.read()
def real_pdf(dvc_test_data):
yield from stream_pdf_bytes(TEST_DATA_DIR / "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf")
@pytest.fixture
def bad_xref_pdf(dvc_test_data):
yield from stream_pdf_bytes(TEST_DATA_DIR / "bad_xref.pdf")

View File

@@ -87,7 +87,7 @@ def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
@pytest.fixture
def real_expected_service_response():
def real_expected_service_response(dvc_test_data):
with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json"), "r") as f:
yield json.load(f)

View File

@@ -1,3 +1,4 @@
import json
import random
from operator import itemgetter
@@ -7,12 +8,23 @@ import pytest
from PIL import Image
from funcy import first, rest
from image_prediction.exceptions import BadXref
from image_prediction.extraction import extract_images_from_pdf
from image_prediction.formatter.formatters.enum import EnumFormatter
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.parsable import extract_pages, get_image_infos, has_alpha_channel
from image_prediction.image_extractor.extractors.parsable import (
extract_pages,
has_alpha_channel,
get_image_infos,
ParsablePDFImageExtractor,
extract_valid_metadata,
eith_extract_image,
extract_image,
)
from image_prediction.info import Info
from image_prediction.locations import TEST_DATA_DIR
from test.utils.comparison import metadata_equal, image_sets_equal
from test.utils.generation.pdf import add_image, pdf_stream
from test.utils.generation.pdf import add_image, pdf_stream, stream_pdf_bytes
@pytest.mark.parametrize("extractor_type", ["mock"])
@@ -75,3 +87,15 @@ def test_has_alpha_channel(base_patch_metadata, suffix, mode):
assert not list(rest(xrefs))
doc.close()
def test_bad_xref_handling(bad_xref_pdf, dvc_test_data):
doc = fitz.Document(stream=bad_xref_pdf)
metadata = extract_valid_metadata(doc, first(doc))
xref = first(metadata)[Info.XREF]
with pytest.raises(BadXref):
extract_image(doc, xref)
assert eith_extract_image(doc, xref).is_left()

View File

@@ -60,10 +60,10 @@ def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_pa
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
def test_image_stitcher_with_gaps_must_succeed():
def test_image_stitcher_with_gaps_must_succeed(dvc_test_data):
from image_prediction.locations import TEST_DATA_DIR
with open(os.path.join(TEST_DATA_DIR, "stitching_with_tolerance.json")) as f:
with open(TEST_DATA_DIR / "stitching_with_tolerance.json") as f:
patches_metadata, base_patch_metadata = itemgetter("input", "target")(ReverseEnumFormatter(Info)(json.load(f)))
images = map(gray_image_from_metadata, patches_metadata)

View File

@@ -1,12 +1,15 @@
from functools import partial
from itertools import starmap, product, repeat
from typing import Iterable
import numpy as np
from PIL.Image import Image
from frozendict import frozendict
from funcy import ilen
from funcy import ilen, compose, omit
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
from image_prediction.info import Info
from image_prediction.utils.generic import lift
def transform_equal(a, b):
@@ -18,7 +21,8 @@ def images_equal(im1: Image, im2: Image, **kwargs):
def metadata_equal(mdat1: Iterable[dict], mdat2: Iterable[dict]):
return set(map(frozendict, mdat1)) == set(map(frozendict, mdat2))
f = compose(set, lift(compose(frozendict, partial(omit, keys=[Info.XREF]))))
return f(mdat1) == f(mdat2)
def image_sets_equal(ims1: Iterable[Image], ims2: Iterable[Image]):

View File

@@ -28,3 +28,8 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair, suffix):
with tempfile.NamedTemporaryFile(suffix=f".{suffix}") as temp_image:
image.save(temp_image.name)
pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type=suffix)
def stream_pdf_bytes(path: str):
with open(path, "rb") as f:
yield f.read()