Pull request #39: RED-6084 Improve image extraction speed
Merge in RR/image-prediction from RED-6084-adhoc-scanned-pages-filtering-refactoring to master
Squashed commit of the following:
commit bd6d83e7363b1c1993babcceb434110a6312c645
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Thu Feb 9 16:08:25 2023 +0100
Tweak logging
commit 55bdd48d2a3462a8b4a6b7194c4a46b21d74c455
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Thu Feb 9 15:47:31 2023 +0100
Update dependencies
commit 970275b25708c05e4fbe78b52aa70d791d5ff17a
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Thu Feb 9 15:35:37 2023 +0100
Refactoring
Make alpha channel check monadic to streamline error handling
commit e99e97e23fd8ce16f9a421d3e5442fccacf71ead
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Tue Feb 7 14:32:29 2023 +0100
Refactoring
- Rename
- Refactor image extraction functions
commit 76b1b0ca2401495ec03ba2b6483091b52732eb81
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Tue Feb 7 11:55:30 2023 +0100
Refactoring
commit cb1c461049d7c43ec340302f466447da9f95a499
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Tue Feb 7 11:44:01 2023 +0100
Refactoring
commit 092069221a85ac7ac19bf838dcbc7ab1fde1e12b
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Tue Feb 7 10:18:53 2023 +0100
Add to-do
commit 3cea4dad2d9703b8c79ddeb740b66a3b8255bb2a
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Tue Feb 7 10:11:35 2023 +0100
Refactoring
- Rename
- Add typehints everywhere
commit 865e0819a14c420bc2edff454d41092c11c019a4
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Mon Feb 6 19:38:57 2023 +0100
Add type explanation
commit 01d3d5d33f1ccb05aea1cec1d1577572b1a4deaa
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Mon Feb 6 19:37:49 2023 +0100
Formatting
commit dffe1c18fc3a322a6b08890d4438844e8122faaf
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Mon Feb 6 19:34:13 2023 +0100
[WIP] Either refactoring
Add alternative formulation for monadic chain
commit 066cf17add404a313520cd794c06e3264cf971c9
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Mon Feb 6 18:40:30 2023 +0100
[WIP] Either refactoring
commit f53f0fea298cdab88deb090af328b34d37e0198e
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Mon Feb 6 18:18:34 2023 +0100
[WIP] Either refactoring
Propagate error and metadata
commit 274a5f56d4fcb9c67fac5cf43e9412ec1ab5179e
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Mon Feb 6 17:51:35 2023 +0100
[WIP] Either refactoring
Fix test assertion
commit 3235a857f6e418e50484cbfff152b0f63efb2f53
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Mon Feb 6 16:57:31 2023 +0100
[WIP] Either-refactoring
Replace Maybe with Either to allow passing on error information or
metadata which otherwise get sucked up by Nothing.
commit 89989543d87490f8b20a0a76055605d34345e8f4
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Mon Feb 6 16:12:40 2023 +0100
[WIP] Monadic refactoring
Integrate image validation step into monadic chain.
At the moment we lost the error information through this. Refactoring to
Either monad can bring it back.
commit 022bd4856a51aa085df5fe983fd77b99b53d594c
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Mon Feb 6 15:16:41 2023 +0100
[WIP] Monadic refactoring
commit ca3898cb539607c8c3dd01c57e60211a5fea8a7d
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Mon Feb 6 15:10:34 2023 +0100
[WIP] Monadic refactoring
commit d8f37bed5cbd6bdd2a0b52bae46fcdbb50f9dff2
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Mon Feb 6 15:09:51 2023 +0100
[WIP] Monadic refactoring
commit 906fee0e5df051f38076aa1d2725e52a182ade13
Author: Matthias Bisping <matthias.bisping@axbit.com>
Date: Mon Feb 6 15:03:35 2023 +0100
[WIP] Monadic refactoring
... and 35 more commits
This commit is contained in:
parent
25fc7d84b9
commit
5cdf93b923
@ -1,6 +1,6 @@
|
||||
webserver:
|
||||
host: $SERVER_HOST|"127.0.0.1" # webserver address
|
||||
port: $SERVER_PORT|5000 # webserver port
|
||||
host: $SERVER_HOST|"127.0.0.1" # Webserver address
|
||||
port: $SERVER_PORT|5000 # Webserver port
|
||||
|
||||
service:
|
||||
logging_level: $LOGGING_LEVEL_ROOT|INFO # Logging level for service logger
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Iterable
|
||||
|
||||
from funcy import juxt
|
||||
|
||||
from image_prediction.classifier.classifier import Classifier
|
||||
@ -7,7 +5,6 @@ from image_prediction.classifier.image_classifier import ImageClassifier
|
||||
from image_prediction.compositor.compositor import TransformerCompositor
|
||||
from image_prediction.encoder.encoders.hash_encoder import HashEncoder
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
from image_prediction.formatter.formatter import format_image_plus
|
||||
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
|
||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||
@ -17,7 +14,6 @@ from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
||||
from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer
|
||||
from image_prediction.transformer.transformers.response import ResponseTransformer
|
||||
from pdf2img.extraction import extract_images_via_metadata
|
||||
|
||||
|
||||
def get_mlflow_model_loader(mlruns_dir):
|
||||
@ -30,17 +26,10 @@ def get_image_classifier(model_loader, model_identifier):
|
||||
return ImageClassifier(Classifier(EstimatorAdapter(model), ProbabilityMapper(classes)))
|
||||
|
||||
|
||||
def get_dispatched_extract(**kwargs):
|
||||
def get_extractor(**kwargs):
|
||||
image_extractor = ParsablePDFImageExtractor(**kwargs)
|
||||
|
||||
def extract(pdf: bytes, page_range: range = None, metadata_per_image: Iterable[dict] = None):
|
||||
if metadata_per_image:
|
||||
image_pluses = extract_images_via_metadata(pdf, metadata_per_image)
|
||||
yield from map(format_image_plus, image_pluses)
|
||||
else:
|
||||
yield from image_extractor.extract(pdf, page_range)
|
||||
|
||||
return extract
|
||||
return image_extractor
|
||||
|
||||
|
||||
def get_formatter():
|
||||
|
||||
@ -36,3 +36,7 @@ class InvalidBox(Exception):
|
||||
|
||||
class ParsingError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BadXref(ValueError):
|
||||
pass
|
||||
|
||||
@ -1,10 +1,6 @@
|
||||
import abc
|
||||
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.info import Info
|
||||
|
||||
from image_prediction.transformer.transformer import Transformer
|
||||
from pdf2img.default_objects.image import ImagePlus
|
||||
|
||||
|
||||
class Formatter(Transformer):
|
||||
@ -17,19 +13,3 @@ class Formatter(Transformer):
|
||||
|
||||
def __call__(self, obj):
|
||||
return self.format(obj)
|
||||
|
||||
|
||||
def format_image_plus(image: ImagePlus) -> ImageMetadataPair:
|
||||
enum_metadata = {
|
||||
Info.PAGE_WIDTH: image.info.pageInfo.width,
|
||||
Info.PAGE_HEIGHT: image.info.pageInfo.height,
|
||||
Info.PAGE_IDX: image.info.pageInfo.number,
|
||||
Info.ALPHA: image.info.alpha,
|
||||
Info.WIDTH: image.info.boundingBox.width,
|
||||
Info.HEIGHT: image.info.boundingBox.height,
|
||||
Info.X1: image.info.boundingBox.x0,
|
||||
Info.X2: image.info.boundingBox.x1,
|
||||
Info.Y1: image.info.boundingBox.y0,
|
||||
Info.Y2: image.info.boundingBox.y1,
|
||||
}
|
||||
return ImageMetadataPair(image.aspil(), enum_metadata)
|
||||
|
||||
@ -1,26 +1,32 @@
|
||||
import atexit
|
||||
import io
|
||||
import json
|
||||
import traceback
|
||||
from functools import partial, lru_cache
|
||||
from itertools import chain, starmap, filterfalse
|
||||
from operator import itemgetter, truth
|
||||
from typing import List, Iterable, Iterator
|
||||
from itertools import chain, filterfalse
|
||||
from operator import itemgetter, attrgetter
|
||||
from typing import List, Union, Iterable, Any
|
||||
|
||||
import fitz
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from funcy import rcompose, merge, pluck, curry, compose
|
||||
from funcy import merge, compose, rcompose, keep, lkeep, notnone, iffy, rpartial
|
||||
from pymonad.either import Right, Left, Either
|
||||
from pymonad.tools import curry, identity
|
||||
|
||||
from image_prediction.exceptions import InvalidBox, BadXref
|
||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.stitching.stitching import stitch_pairs
|
||||
from image_prediction.stitching.utils import validate_box_coords, validate_box_size
|
||||
from image_prediction.stitching.utils import validate_box
|
||||
from image_prediction.utils import get_logger
|
||||
from image_prediction.utils.generic import lift
|
||||
from image_prediction.utils.generic import bottom, left, right, lift, wrap_right
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
Doc = fitz.fitz.Document
|
||||
Pge = fitz.fitz.Page
|
||||
Img = Image.Image
|
||||
Pxm = fitz.fitz.Pixmap
|
||||
|
||||
|
||||
class ParsablePDFImageExtractor(ImageExtractor):
|
||||
def __init__(self, verbose=False, tolerance=0):
|
||||
@ -31,7 +37,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
tolerance: The tolerance in pixels for the distance between images, beyond which they will not be stitched
|
||||
together
|
||||
"""
|
||||
self.doc: fitz.fitz.Document = None
|
||||
self.doc: Union[Doc, None] = None
|
||||
self.verbose = verbose
|
||||
self.tolerance = tolerance
|
||||
|
||||
@ -44,80 +50,215 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
|
||||
yield from image_metadata_pairs
|
||||
|
||||
def __process_images_on_page(self, page: fitz.fitz.Page):
|
||||
images = get_images_on_page(self.doc, page)
|
||||
metadata = get_metadata_for_images_on_page(self.doc, page)
|
||||
def __process_images_on_page(self, page: Pge):
|
||||
metadata = extract_valid_metadata(self.doc, page)
|
||||
|
||||
either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata)
|
||||
valid_image_metadata_pairs = lkeep(
|
||||
rpartial(take_good_log_bad, format_context), either_image_metadata_pair_or_error_per_image
|
||||
)
|
||||
valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance)
|
||||
|
||||
clear_caches()
|
||||
|
||||
image_metadata_pairs = starmap(ImageMetadataPair, filter(all, zip(images, metadata)))
|
||||
# TODO: In the future, consider to introduce an image validator as a pipeline component rather than doing the
|
||||
# validation here. Invalid images can then be split into a different stream and joined with the intact images
|
||||
# again for the formatting step.
|
||||
image_metadata_pairs = self.__filter_valid_images(image_metadata_pairs)
|
||||
image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance)
|
||||
yield from valid_image_metadata_pairs_stitched
|
||||
|
||||
yield from image_metadata_pairs
|
||||
|
||||
@staticmethod
|
||||
def __filter_valid_images(image_metadata_pairs: Iterable[ImageMetadataPair]) -> Iterator[ImageMetadataPair]:
|
||||
def validate(image: Image.Image, metadata: dict):
|
||||
try:
|
||||
# TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148)
|
||||
image.resize((100, 100)).convert("RGB")
|
||||
return ImageMetadataPair(image, metadata)
|
||||
except (OSError, Exception) as err:
|
||||
metadata = json.dumps(EnumFormatter()(metadata), indent=2)
|
||||
logger.warning(f"Invalid image encountered. Image metadata:\n{metadata}\n\n{traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
return filter(truth, starmap(validate, image_metadata_pairs))
|
||||
def __metadatum_to_image_metadata_pair(self, metadatum: dict) -> Either:
|
||||
return metadatum_to_image_metadata_pair(self.doc, metadatum)
|
||||
|
||||
|
||||
def extract_pages(doc, page_range):
|
||||
def extract_pages(doc: Doc, page_range: range) -> Iterable[Pge]:
|
||||
page_range = range(page_range.start + 1, page_range.stop + 1)
|
||||
pages = map(doc.load_page, page_range)
|
||||
|
||||
yield from pages
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_images_on_page(doc, page: fitz.Page):
|
||||
image_infos = get_image_infos(page)
|
||||
xrefs = map(itemgetter("xref"), image_infos)
|
||||
images = map(partial(xref_to_image, doc), xrefs)
|
||||
|
||||
yield from images
|
||||
def validate_image(image: Img) -> Either:
|
||||
try:
|
||||
# TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148)
|
||||
image.resize((100, 100)).convert("RGB")
|
||||
return Right(image)
|
||||
except (OSError, Exception):
|
||||
logger.warning(f"Invalid image encountered.")
|
||||
return Left("Invalid image.")
|
||||
|
||||
|
||||
def get_metadata_for_images_on_page(doc, page: fitz.Page):
|
||||
def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]:
|
||||
return compose(
|
||||
list,
|
||||
partial(add_alpha_channel_info, doc),
|
||||
filter_valid_metadata,
|
||||
get_metadata_for_images_on_page,
|
||||
)(page)
|
||||
|
||||
metadata = map(get_image_metadata, get_image_infos(page))
|
||||
metadata = validate_coords_and_passthrough(metadata)
|
||||
|
||||
metadata = filter_out_tiny_images(metadata)
|
||||
metadata = validate_size_and_passthrough(metadata)
|
||||
def take_good_log_bad(item: Either, log_formatter=identity) -> Any:
|
||||
return item.either(rpartial(log_error_context, log_formatter), identity)
|
||||
|
||||
metadata = add_page_metadata(page, metadata)
|
||||
|
||||
metadata = add_alpha_channel_info(doc, page, metadata)
|
||||
def format_context(context: dict) -> str:
|
||||
return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}"
|
||||
|
||||
|
||||
def log_error_context(context: dict, formatter=identity) -> None:
|
||||
logger.warning(f"Skipping bad image. {formatter(context)}")
|
||||
return None
|
||||
|
||||
|
||||
def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either:
|
||||
image: Either = eith_extract_image(doc, metadatum[Info.XREF]).bind(validate_image)
|
||||
image_metadata_pair: Either = eith_make_image_metadata_pair(image, Right(metadatum))
|
||||
return image_metadata_pair
|
||||
|
||||
|
||||
def add_alpha_channel_info(doc: Doc, metadata: Iterable[dict]) -> Iterable[dict]:
|
||||
def add_alpha_value_to_metadatum(metadatum: dict) -> dict:
|
||||
alpha = metadatum_to_alpha_value(metadatum)
|
||||
return {**metadatum, Info.ALPHA: alpha}
|
||||
|
||||
xref_to_alpha = partial(has_alpha_channel, doc)
|
||||
metadatum_to_alpha_value = compose(xref_to_alpha, itemgetter(Info.XREF))
|
||||
|
||||
yield from map(add_alpha_value_to_metadatum, metadata)
|
||||
|
||||
|
||||
def filter_valid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
|
||||
yield from compose(filter_out_tiny_images, filter_out_invalid_metadata)(metadata)
|
||||
|
||||
|
||||
def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]:
|
||||
metadata = compose(
|
||||
partial(add_page_metadata, page),
|
||||
lift(get_image_metadata),
|
||||
get_image_infos,
|
||||
)(page)
|
||||
|
||||
yield from metadata
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_image_infos(page: fitz.Page) -> List[dict]:
|
||||
return page.get_image_info(xrefs=True)
|
||||
@curry(2)
|
||||
def eith_extract_image(doc: Doc, xref: int) -> Either:
|
||||
try:
|
||||
return Right(extract_image(doc, xref))
|
||||
except BadXref:
|
||||
return Left("Bad xref.")
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def xref_to_image(doc, xref) -> Image:
|
||||
maybe_image = load_image_handle_from_xref(doc, xref)
|
||||
return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None
|
||||
def eith_make_image_metadata_pair(image: Either, metadata: Either) -> Either:
|
||||
"""Reference: haskell.org/tutorial/monads.html"""
|
||||
|
||||
def context(value):
|
||||
return {"reason": value, "metadata": metadata.either(bottom, identity)}
|
||||
|
||||
# Explicitly we are doing the following. (1) and (2) are equivalent.
|
||||
|
||||
# a := Image
|
||||
# b := Metadata
|
||||
# c := ImageMetadataPair
|
||||
# m := Either monad
|
||||
|
||||
# fmt: off
|
||||
# 1)
|
||||
# pair: Either = (
|
||||
# Right(make_image_metadata_pair) # m (a -> b -> c)
|
||||
# .amap(image) # m (a -> b -> c) <*> m a = m (b -> c)
|
||||
# .amap(metadata) # m (b -> c) <*> m b = m c
|
||||
# )
|
||||
|
||||
# 2)
|
||||
# pair: Either = (
|
||||
# image.bind(right(make_image_metadata_pair)) # m a >>= m (a -> b -> c) = m (b -> c)
|
||||
# .amap(metadata) # m (b -> c) <*> m b = m c
|
||||
# )
|
||||
# fmt: on
|
||||
|
||||
# Syntactic sugar variant with details hidden
|
||||
pair: Either = Either.apply(make_image_metadata_pair).to_arguments(image, metadata)
|
||||
|
||||
return pair.either(left(context), right(identity))
|
||||
|
||||
|
||||
def get_image_metadata(image_info):
|
||||
@curry(2)
|
||||
def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair:
|
||||
return ImageMetadataPair(image, metadatum)
|
||||
|
||||
x1, y1, x2, y2 = map(rounder, image_info["bbox"])
|
||||
|
||||
def extract_image(doc: Doc, xref: int) -> Any:
|
||||
return compose(pixmap_to_image, extract_pixmap)(doc, xref)
|
||||
|
||||
|
||||
def pixmap_to_image(pixmap: Pxm) -> Img:
|
||||
array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape((pixmap.h, pixmap.w, pixmap.n))
|
||||
array = normalize_channels(array)
|
||||
return Image.fromarray(array)
|
||||
|
||||
|
||||
def extract_pixmap(doc: Doc, xref: int) -> Pxm:
|
||||
try:
|
||||
return fitz.Pixmap(doc, xref)
|
||||
except ValueError as err:
|
||||
msg = f"Cross reference {xref} is invalid, skipping extraction."
|
||||
logger.error(err)
|
||||
logger.trace(msg)
|
||||
raise BadXref(msg) from err
|
||||
|
||||
|
||||
def has_alpha_channel(doc: Doc, xref: int) -> bool:
|
||||
|
||||
_get_image_handle = wrap_right(get_image_handle, success_condition=notnone)(doc)
|
||||
_extract_pixmap = wrap_right(extract_pixmap)(doc)
|
||||
|
||||
def get_soft_mask_reference(cross_reference: int) -> Either:
|
||||
def error(value) -> str:
|
||||
return f"Invalid soft mask {value} for cross reference {cross_reference}."
|
||||
|
||||
logger.trace(f"Getting soft mask handle for cross reference {cross_reference}.")
|
||||
pass_on_if_not_none = iffy(notnone, right(identity), left(error))
|
||||
return _get_image_handle(cross_reference).then(itemgetter("smask")).either(left(identity), pass_on_if_not_none)
|
||||
|
||||
def mask_exists(soft_mask_reference: int) -> Either:
|
||||
logger.trace(f"Checking if soft mask exists for soft mask reference {soft_mask_reference}.")
|
||||
return _get_image_handle(soft_mask_reference).then(notnone)
|
||||
|
||||
def image_has_alpha_channel(reference: int) -> Either:
|
||||
logger.trace(f"Checking if image with reference {reference} has alpha channel.")
|
||||
return _extract_pixmap(reference).then(attrgetter("alpha")).then(bool)
|
||||
|
||||
logger.debug(f"Checking if image with cross reference {xref} has alpha channel.")
|
||||
|
||||
cross_reference = Right(xref)
|
||||
soft_mask_reference = cross_reference.bind(get_soft_mask_reference)
|
||||
|
||||
return any(
|
||||
take_good_log_bad(reference.bind(check))
|
||||
for reference, check in [
|
||||
(soft_mask_reference, mask_exists),
|
||||
(soft_mask_reference, image_has_alpha_channel),
|
||||
(cross_reference, image_has_alpha_channel),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]:
|
||||
yield from filterfalse(tiny, metadata)
|
||||
|
||||
|
||||
def filter_out_invalid_metadata(metadata: Iterable[dict]) -> Iterable[dict]:
|
||||
def __validate_box(box):
|
||||
try:
|
||||
return validate_box(box)
|
||||
except InvalidBox as err:
|
||||
logger.debug(f"Dropping invalid metadatum, reason: {err}")
|
||||
|
||||
yield from keep(__validate_box, metadata)
|
||||
|
||||
|
||||
def get_image_metadata(image_info: dict) -> dict:
|
||||
|
||||
xref, coords = itemgetter("xref", "bbox")(image_info)
|
||||
x1, y1, x2, y2 = map(rounder, coords)
|
||||
|
||||
width = abs(x2 - x1)
|
||||
height = abs(y2 - y1)
|
||||
@ -129,47 +270,40 @@ def get_image_metadata(image_info):
|
||||
Info.X2: x2,
|
||||
Info.Y1: y1,
|
||||
Info.Y2: y2,
|
||||
Info.XREF: xref,
|
||||
}
|
||||
|
||||
|
||||
def validate_coords_and_passthrough(metadata):
|
||||
yield from map(validate_box_coords, metadata)
|
||||
@lru_cache(maxsize=None)
|
||||
def get_image_infos(page: Pge) -> List[dict]:
|
||||
return page.get_image_info(xrefs=True)
|
||||
|
||||
|
||||
def filter_out_tiny_images(metadata):
|
||||
yield from filterfalse(tiny, metadata)
|
||||
|
||||
|
||||
def validate_size_and_passthrough(metadata):
|
||||
yield from map(validate_box_size, metadata)
|
||||
|
||||
|
||||
def add_page_metadata(page, metadata):
|
||||
def add_page_metadata(page: Pge, metadata: Iterable[dict]) -> Iterable[dict]:
|
||||
yield from map(partial(merge, get_page_metadata(page)), metadata)
|
||||
|
||||
|
||||
def add_alpha_channel_info(doc, page, metadata):
|
||||
def normalize_channels(array: np.ndarray):
|
||||
if not array.ndim == 3:
|
||||
array = np.expand_dims(array, axis=-1)
|
||||
|
||||
page_to_xrefs = compose(curry(pluck)("xref"), get_image_infos)
|
||||
xref_to_alpha = partial(has_alpha_channel, doc)
|
||||
page_to_alpha_value_per_image = compose(lift(xref_to_alpha), page_to_xrefs)
|
||||
alpha_to_dict = compose(dict, lambda a: [(Info.ALPHA, a)])
|
||||
page_to_alpha_mapping_per_image = compose(lift(alpha_to_dict), page_to_alpha_value_per_image)
|
||||
if array.shape[-1] == 4:
|
||||
array = array[..., :3]
|
||||
elif array.shape[-1] == 1:
|
||||
array = np.concatenate([array, array, array], axis=-1)
|
||||
elif array.shape[-1] != 3:
|
||||
logger.warning(f"Unexpected image format: {array.shape}.")
|
||||
raise ValueError(f"Unexpected image format: {array.shape}.")
|
||||
|
||||
metadata = starmap(merge, zip(page_to_alpha_mapping_per_image(page), metadata))
|
||||
|
||||
yield from metadata
|
||||
return array
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def load_image_handle_from_xref(doc, xref):
|
||||
def get_image_handle(doc: Doc, xref: int) -> Union[dict, None]:
|
||||
return doc.extract_image(xref)
|
||||
|
||||
|
||||
rounder = rcompose(round, int)
|
||||
|
||||
|
||||
def get_page_metadata(page):
|
||||
def get_page_metadata(page: Pge) -> dict:
|
||||
page_width, page_height = map(rounder, page.mediabox_size)
|
||||
|
||||
return {
|
||||
@ -179,30 +313,17 @@ def get_page_metadata(page):
|
||||
}
|
||||
|
||||
|
||||
def has_alpha_channel(doc, xref):
|
||||
|
||||
maybe_image = load_image_handle_from_xref(doc, xref)
|
||||
maybe_smask = maybe_image["smask"] if maybe_image else None
|
||||
|
||||
if maybe_smask:
|
||||
return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)])
|
||||
else:
|
||||
try:
|
||||
return bool(fitz.Pixmap(doc, xref).alpha)
|
||||
except ValueError:
|
||||
logger.debug(f"Encountered invalid xref `{xref}` in {doc.metadata.get('title', '<no title>')}.")
|
||||
return False
|
||||
rounder = rcompose(round, int)
|
||||
|
||||
|
||||
def tiny(metadata):
|
||||
return metadata[Info.WIDTH] * metadata[Info.HEIGHT] <= 4
|
||||
def tiny(metadatum: dict) -> bool:
|
||||
return metadatum[Info.WIDTH] * metadatum[Info.HEIGHT] <= 4
|
||||
|
||||
|
||||
def clear_caches():
|
||||
def clear_caches() -> None:
|
||||
get_image_infos.cache_clear()
|
||||
load_image_handle_from_xref.cache_clear()
|
||||
get_images_on_page.cache_clear()
|
||||
xref_to_image.cache_clear()
|
||||
get_image_handle.cache_clear()
|
||||
eith_extract_image.cache_clear()
|
||||
|
||||
|
||||
atexit.register(clear_caches)
|
||||
|
||||
@ -12,3 +12,4 @@ class Info(Enum):
|
||||
Y1 = "y1"
|
||||
Y2 = "y2"
|
||||
ALPHA = "alpha"
|
||||
XREF = "xref"
|
||||
|
||||
@ -3,15 +3,14 @@
|
||||
from pathlib import Path
|
||||
|
||||
MODULE_DIR = Path(__file__).resolve().parents[0]
|
||||
|
||||
PACKAGE_ROOT_DIR = MODULE_DIR.parents[0]
|
||||
|
||||
CONFIG_FILE = PACKAGE_ROOT_DIR / "config.yaml"
|
||||
|
||||
BANNER_FILE = PACKAGE_ROOT_DIR / "banner.txt"
|
||||
|
||||
DATA_DIR = PACKAGE_ROOT_DIR / "data"
|
||||
|
||||
MLRUNS_DIR = str(DATA_DIR / "mlruns")
|
||||
|
||||
TEST_DATA_DIR = PACKAGE_ROOT_DIR / "test" / "data"
|
||||
TEST_DIR = PACKAGE_ROOT_DIR / "test"
|
||||
TEST_DATA_DIR = TEST_DIR / "data"
|
||||
TEST_DATA_DIR_DVC = TEST_DIR / "data.dvc"
|
||||
|
||||
@ -11,8 +11,8 @@ from image_prediction.default_objects import (
|
||||
get_formatter,
|
||||
get_mlflow_model_loader,
|
||||
get_image_classifier,
|
||||
get_extractor,
|
||||
get_encoder,
|
||||
get_dispatched_extract,
|
||||
)
|
||||
from image_prediction.locations import MLRUNS_DIR
|
||||
from image_prediction.utils.generic import lift, starlift
|
||||
@ -41,7 +41,7 @@ class Pipeline:
|
||||
def __init__(self, model_loader, model_identifier, batch_size=16, verbose=True, **kwargs):
|
||||
self.verbose = verbose
|
||||
|
||||
extract = get_dispatched_extract(**kwargs)
|
||||
extract = get_extractor(**kwargs)
|
||||
classifier = get_image_classifier(model_loader, model_identifier)
|
||||
reformat = get_formatter()
|
||||
represent = get_encoder()
|
||||
@ -63,9 +63,9 @@ class Pipeline:
|
||||
reformat, # ... the items
|
||||
)
|
||||
|
||||
def __call__(self, pdf: bytes, page_range: range = None, metadata_per_image: Iterable[dict] = None):
|
||||
def __call__(self, pdf: bytes, page_range: range = None):
|
||||
yield from tqdm(
|
||||
self.pipe(pdf, page_range=page_range, metadata_per_image=metadata_per_image),
|
||||
self.pipe(pdf, page_range=page_range),
|
||||
desc="Processing images from document",
|
||||
unit=" images",
|
||||
disable=not self.verbose,
|
||||
|
||||
@ -21,11 +21,6 @@ class ResponseTransformer(Transformer):
|
||||
|
||||
|
||||
def build_image_info(data: dict) -> dict:
|
||||
def compute_geometric_quotient():
|
||||
page_area_sqrt = math.sqrt(abs(page_width * page_height))
|
||||
image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1))
|
||||
return image_area_sqrt / page_area_sqrt
|
||||
|
||||
page_width, page_height, x1, x2, y1, y2, width, height, alpha = itemgetter(
|
||||
"page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height", "alpha"
|
||||
)(data)
|
||||
@ -34,7 +29,7 @@ def build_image_info(data: dict) -> dict:
|
||||
label = classification["label"]
|
||||
representation = data["representation"]
|
||||
|
||||
geometric_quotient = round(compute_geometric_quotient(), 4)
|
||||
geometric_quotient = round(compute_geometric_quotient(page_width, page_height, x2, x1, y2, y1), 4)
|
||||
|
||||
min_image_to_page_quotient_breached = bool(
|
||||
geometric_quotient < get_class_specific_min_image_to_page_quotient(label)
|
||||
@ -89,6 +84,12 @@ def build_image_info(data: dict) -> dict:
|
||||
return image_info
|
||||
|
||||
|
||||
def compute_geometric_quotient(page_width, page_height, x2, x1, y2, y1):
|
||||
page_area_sqrt = math.sqrt(abs(page_width * page_height))
|
||||
image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1))
|
||||
return image_area_sqrt / page_area_sqrt
|
||||
|
||||
|
||||
def get_class_specific_min_image_to_page_quotient(label, table=None):
|
||||
return get_class_specific_value(
|
||||
"REL_IMAGE_SIZE", label, "min", CONFIG.filters.image_to_page_quotient.min, table=table
|
||||
|
||||
@ -1,6 +1,15 @@
|
||||
from functools import wraps
|
||||
from inspect import signature
|
||||
from itertools import starmap
|
||||
from typing import Callable
|
||||
|
||||
from funcy import iterate, first, curry, map
|
||||
from pymonad.either import Left, Right, Either
|
||||
from pymonad.tools import curry as pmcurry
|
||||
|
||||
from image_prediction.utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def until(cond, func, *args, **kwargs):
|
||||
@ -13,3 +22,63 @@ def lift(fn):
|
||||
|
||||
def starlift(fn):
|
||||
return curry(starmap)(fn)
|
||||
|
||||
|
||||
def bottom(*args, **kwargs):
|
||||
return False
|
||||
|
||||
|
||||
def top(*args, **kwargs):
|
||||
return True
|
||||
|
||||
|
||||
def left(fn):
|
||||
@wraps(fn)
|
||||
def inner(x):
|
||||
return Left(fn(x))
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def right(fn):
|
||||
@wraps(fn)
|
||||
def inner(x):
|
||||
return Right(fn(x))
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def wrap_left(fn, success_condition=top, error_message=None) -> Callable:
|
||||
return wrap_either(Left, Right, success_condition=success_condition, error_message=error_message)(fn)
|
||||
|
||||
|
||||
def wrap_right(fn, success_condition=top, error_message=None) -> Callable:
|
||||
return wrap_either(Right, Left, success_condition=success_condition, error_message=error_message)(fn)
|
||||
|
||||
|
||||
def wrap_either(success_type, failure_type, success_condition=top, error_message=None) -> Callable:
|
||||
@wraps(wrap_either)
|
||||
def wrapper(fn) -> Callable:
|
||||
|
||||
n_params = len(signature(fn).parameters)
|
||||
|
||||
@pmcurry(n_params)
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs) -> Either:
|
||||
try:
|
||||
result = fn(*args, **kwargs)
|
||||
if success_condition(result):
|
||||
return success_type(result)
|
||||
else:
|
||||
return failure_type({"error": error_message, "result": result})
|
||||
except Exception as err:
|
||||
logger.error(err)
|
||||
return failure_type({"error": error_message or err, "result": Void})
|
||||
|
||||
return wrapper
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class Void:
|
||||
pass
|
||||
|
||||
@ -23,3 +23,5 @@ pdf2image==1.16.0
|
||||
frozendict==2.3.0
|
||||
protobuf<=3.20.*
|
||||
prometheus-client==0.13.1
|
||||
fsspec==2022.11.0
|
||||
PyMonad==2.4.0
|
||||
|
||||
@ -2,7 +2,6 @@ import argparse
|
||||
import json
|
||||
import os
|
||||
from glob import glob
|
||||
from operator import truth
|
||||
|
||||
from image_prediction.pipeline import load_pipeline
|
||||
from image_prediction.utils import get_logger
|
||||
@ -15,7 +14,6 @@ def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("input", help="pdf file or directory")
|
||||
parser.add_argument("--metadata", help="optional figure detection metadata")
|
||||
parser.add_argument("--print", "-p", help="print output to terminal", action="store_true", default=False)
|
||||
parser.add_argument("--page_interval", "-i", help="page interval [i, j), min index = 0", nargs=2, type=int)
|
||||
|
||||
@ -24,17 +22,13 @@ def parse_args():
|
||||
return args
|
||||
|
||||
|
||||
def process_pdf(pipeline, pdf_path, metadata=None, page_range=None):
|
||||
if metadata:
|
||||
with open(metadata) as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
def process_pdf(pipeline, pdf_path, page_range=None):
|
||||
with open(pdf_path, "rb") as f:
|
||||
logger.info(f"Processing {pdf_path}")
|
||||
predictions = list(pipeline(f.read(), page_range=page_range, metadata_per_image=metadata))
|
||||
predictions = list(pipeline(f.read(), page_range=page_range))
|
||||
|
||||
annotate_pdf(
|
||||
pdf_path, predictions, os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", f"_{truth(metadata)}_annotated.pdf")))
|
||||
pdf_path, predictions, os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", "_annotated.pdf")))
|
||||
)
|
||||
|
||||
return predictions
|
||||
@ -48,10 +42,9 @@ def main(args):
|
||||
else:
|
||||
pdf_paths = glob(os.path.join(args.input, "*.pdf"))
|
||||
page_range = range(*args.page_interval) if args.page_interval else None
|
||||
metadata = args.metadata if args.metadata else None
|
||||
|
||||
for pdf_path in pdf_paths:
|
||||
predictions = process_pdf(pipeline, pdf_path, metadata, page_range=page_range)
|
||||
predictions = process_pdf(pipeline, pdf_path, page_range=page_range)
|
||||
if args.print:
|
||||
print(pdf_path)
|
||||
print(json.dumps(predictions, indent=2))
|
||||
|
||||
23
src/serve.py
23
src/serve.py
@ -1,5 +1,4 @@
|
||||
import gzip
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
|
||||
@ -29,36 +28,28 @@ logger.setLevel(PYINFRA_CONFIG.logging_level_root)
|
||||
def process_request(request_message):
|
||||
dossier_id = request_message["dossierId"]
|
||||
file_id = request_message["fileId"]
|
||||
logger.info(f"Processing {dossier_id=} {file_id=} ...")
|
||||
target_file_name = f"{dossier_id}/{file_id}.{request_message['targetFileExtension']}"
|
||||
response_file_name = f"{dossier_id}/{file_id}.{request_message['responseFileExtension']}"
|
||||
figure_data_file_name = f"{dossier_id}/{file_id}.FIGURE.json.gz"
|
||||
|
||||
bucket = PYINFRA_CONFIG.storage_bucket
|
||||
storage = get_storage(PYINFRA_CONFIG)
|
||||
|
||||
pipeline = load_pipeline(verbose=IMAGE_CONFIG.service.verbose, batch_size=IMAGE_CONFIG.service.batch_size)
|
||||
|
||||
if storage.exists(bucket, target_file_name):
|
||||
should_publish_result = True
|
||||
if not storage.exists(bucket, target_file_name):
|
||||
publish_result = False
|
||||
else:
|
||||
publish_result = True
|
||||
object_bytes = storage.get_object(bucket, target_file_name)
|
||||
object_bytes = gzip.decompress(object_bytes)
|
||||
classifications = list(pipeline(pdf=object_bytes))
|
||||
|
||||
if storage.exists(bucket, figure_data_file_name):
|
||||
metadata_bytes = storage.get_object(bucket, figure_data_file_name)
|
||||
metadata_bytes = gzip.decompress(metadata_bytes)
|
||||
metadata_per_image = json.load(io.BytesIO(metadata_bytes))["data"]
|
||||
classifications_cv = list(pipeline(pdf=object_bytes, metadata_per_image=metadata_per_image))
|
||||
else:
|
||||
classifications_cv = []
|
||||
|
||||
result = {**request_message, "data": classifications, "dataCV": classifications_cv}
|
||||
result = {**request_message, "data": classifications}
|
||||
storage_bytes = gzip.compress(json.dumps(result).encode("utf-8"))
|
||||
storage.put_object(bucket, response_file_name, storage_bytes)
|
||||
else:
|
||||
should_publish_result = False
|
||||
|
||||
return should_publish_result, {"dossierId": dossier_id, "fileId": file_id}
|
||||
return publish_result, {"dossierId": dossier_id, "fileId": file_id}
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
1
test/.gitignore
vendored
Normal file
1
test/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
/data
|
||||
5
test/data.dvc
Normal file
5
test/data.dvc
Normal file
@ -0,0 +1,5 @@
|
||||
outs:
|
||||
- md5: 4b0fec291ce0661b3efbbd8b80f4f514.dir
|
||||
size: 107332
|
||||
nfiles: 4
|
||||
path: data
|
||||
Binary file not shown.
@ -1,44 +0,0 @@
|
||||
[
|
||||
{
|
||||
"classification": {
|
||||
"label": "formula",
|
||||
"probabilities": {
|
||||
"formula": 1.0,
|
||||
"logo": 0.0,
|
||||
"other": 0.0,
|
||||
"signature": 0.0
|
||||
}
|
||||
},
|
||||
"representation": "FFFEF0C7033648170F3EFFFFF",
|
||||
"position": {
|
||||
"x1": 321,
|
||||
"x2": 515,
|
||||
"y1": 348,
|
||||
"y2": 542,
|
||||
"pageNumber": 2
|
||||
},
|
||||
"geometry": {
|
||||
"width": 194,
|
||||
"height": 194
|
||||
},
|
||||
"alpha": false,
|
||||
"filters": {
|
||||
"geometry": {
|
||||
"imageSize": {
|
||||
"quotient": 0.2741,
|
||||
"tooLarge": false,
|
||||
"tooSmall": false
|
||||
},
|
||||
"imageFormat": {
|
||||
"quotient": 1.0,
|
||||
"tooTall": false,
|
||||
"tooWide": false
|
||||
}
|
||||
},
|
||||
"probability": {
|
||||
"unconfident": false
|
||||
},
|
||||
"allPassed": true
|
||||
}
|
||||
}
|
||||
]
|
||||
@ -1,92 +0,0 @@
|
||||
{
|
||||
"input": [
|
||||
{
|
||||
"width": 100,
|
||||
"height": 8,
|
||||
"page_idx": 0,
|
||||
"page_width": 100,
|
||||
"page_height": 100,
|
||||
"x1": 0,
|
||||
"y1": 0,
|
||||
"x2": 100,
|
||||
"y2": 8
|
||||
},
|
||||
{
|
||||
"width": 100,
|
||||
"height": 9,
|
||||
"page_idx": 0,
|
||||
"page_width": 100,
|
||||
"page_height": 100,
|
||||
"x1": 0,
|
||||
"y1": 9,
|
||||
"x2": 100,
|
||||
"y2": 18
|
||||
},
|
||||
{
|
||||
"width": 100,
|
||||
"height": 35,
|
||||
"page_idx": 0,
|
||||
"page_width": 100,
|
||||
"page_height": 100,
|
||||
"x1": 0,
|
||||
"y1": 18,
|
||||
"x2": 100,
|
||||
"y2": 53
|
||||
},
|
||||
{
|
||||
"width": 47,
|
||||
"height": 46,
|
||||
"page_idx": 0,
|
||||
"page_width": 100,
|
||||
"page_height": 100,
|
||||
"x1": 0,
|
||||
"y1": 54,
|
||||
"x2": 47,
|
||||
"y2": 100
|
||||
},
|
||||
{
|
||||
"width": 31,
|
||||
"height": 46,
|
||||
"page_idx": 0,
|
||||
"page_width": 100,
|
||||
"page_height": 100,
|
||||
"x1": 48,
|
||||
"y1": 54,
|
||||
"x2": 79,
|
||||
"y2": 100
|
||||
},
|
||||
{
|
||||
"width": 20,
|
||||
"height": 19,
|
||||
"page_idx": 0,
|
||||
"page_width": 100,
|
||||
"page_height": 100,
|
||||
"x1": 80,
|
||||
"y1": 54,
|
||||
"x2": 100,
|
||||
"y2": 73
|
||||
},
|
||||
{
|
||||
"width": 20,
|
||||
"height": 27,
|
||||
"page_idx": 0,
|
||||
"page_width": 100,
|
||||
"page_height": 100,
|
||||
"x1": 80,
|
||||
"y1": 73,
|
||||
"x2": 100,
|
||||
"y2": 100
|
||||
}
|
||||
],
|
||||
"target": {
|
||||
"width": 100,
|
||||
"height": 100,
|
||||
"page_idx": 0,
|
||||
"page_width": 100,
|
||||
"page_height": 100,
|
||||
"x1": 0,
|
||||
"y1": 0,
|
||||
"x2": 100,
|
||||
"y2": 100
|
||||
}
|
||||
}
|
||||
14
test/fixtures/input.py
vendored
14
test/fixtures/input.py
vendored
@ -1,7 +1,21 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from dvc.repo import Repo
|
||||
|
||||
from image_prediction.locations import PACKAGE_ROOT_DIR, TEST_DATA_DIR_DVC
|
||||
from image_prediction.utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def input_batch(batch_size, input_size):
|
||||
return np.random.random_sample(size=(batch_size, *input_size))
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def dvc_test_data():
|
||||
logger.info("Pulling data with DVC...")
|
||||
# noinspection PyCallingNonCallable
|
||||
Repo(PACKAGE_ROOT_DIR).pull(targets=[str(TEST_DATA_DIR_DVC)])
|
||||
logger.info("Finished pulling data.")
|
||||
|
||||
12
test/fixtures/pdf.py
vendored
12
test/fixtures/pdf.py
vendored
@ -4,7 +4,7 @@ import fpdf
|
||||
import pytest
|
||||
|
||||
from image_prediction.locations import TEST_DATA_DIR
|
||||
from test.utils.generation.pdf import add_image, pdf_stream
|
||||
from test.utils.generation.pdf import add_image, pdf_stream, stream_pdf_bytes
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -18,6 +18,10 @@ def pdf(image_metadata_pairs):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_pdf():
|
||||
with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf"), "rb") as f:
|
||||
yield f.read()
|
||||
def real_pdf(dvc_test_data):
|
||||
yield from stream_pdf_bytes(TEST_DATA_DIR / "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bad_xref_pdf(dvc_test_data):
|
||||
yield from stream_pdf_bytes(TEST_DATA_DIR / "bad_xref.pdf")
|
||||
|
||||
2
test/fixtures/target.py
vendored
2
test/fixtures/target.py
vendored
@ -87,7 +87,7 @@ def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_expected_service_response():
|
||||
def real_expected_service_response(dvc_test_data):
|
||||
with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json"), "r") as f:
|
||||
yield json.load(f)
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import random
|
||||
from operator import itemgetter
|
||||
|
||||
@ -7,12 +8,23 @@ import pytest
|
||||
from PIL import Image
|
||||
from funcy import first, rest
|
||||
|
||||
from image_prediction.exceptions import BadXref
|
||||
from image_prediction.extraction import extract_images_from_pdf
|
||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.image_extractor.extractors.parsable import extract_pages, get_image_infos, has_alpha_channel
|
||||
from image_prediction.image_extractor.extractors.parsable import (
|
||||
extract_pages,
|
||||
has_alpha_channel,
|
||||
get_image_infos,
|
||||
ParsablePDFImageExtractor,
|
||||
extract_valid_metadata,
|
||||
eith_extract_image,
|
||||
extract_image,
|
||||
)
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.locations import TEST_DATA_DIR
|
||||
from test.utils.comparison import metadata_equal, image_sets_equal
|
||||
from test.utils.generation.pdf import add_image, pdf_stream
|
||||
from test.utils.generation.pdf import add_image, pdf_stream, stream_pdf_bytes
|
||||
|
||||
|
||||
@pytest.mark.parametrize("extractor_type", ["mock"])
|
||||
@ -75,3 +87,15 @@ def test_has_alpha_channel(base_patch_metadata, suffix, mode):
|
||||
assert not list(rest(xrefs))
|
||||
|
||||
doc.close()
|
||||
|
||||
|
||||
def test_bad_xref_handling(bad_xref_pdf, dvc_test_data):
|
||||
|
||||
doc = fitz.Document(stream=bad_xref_pdf)
|
||||
metadata = extract_valid_metadata(doc, first(doc))
|
||||
xref = first(metadata)[Info.XREF]
|
||||
|
||||
with pytest.raises(BadXref):
|
||||
extract_image(doc, xref)
|
||||
|
||||
assert eith_extract_image(doc, xref).is_left()
|
||||
|
||||
@ -60,10 +60,10 @@ def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_pa
|
||||
assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4)
|
||||
|
||||
|
||||
def test_image_stitcher_with_gaps_must_succeed():
|
||||
def test_image_stitcher_with_gaps_must_succeed(dvc_test_data):
|
||||
from image_prediction.locations import TEST_DATA_DIR
|
||||
|
||||
with open(os.path.join(TEST_DATA_DIR, "stitching_with_tolerance.json")) as f:
|
||||
with open(TEST_DATA_DIR / "stitching_with_tolerance.json") as f:
|
||||
patches_metadata, base_patch_metadata = itemgetter("input", "target")(ReverseEnumFormatter(Info)(json.load(f)))
|
||||
|
||||
images = map(gray_image_from_metadata, patches_metadata)
|
||||
|
||||
@ -1,12 +1,15 @@
|
||||
from functools import partial
|
||||
from itertools import starmap, product, repeat
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
from PIL.Image import Image
|
||||
from frozendict import frozendict
|
||||
from funcy import ilen
|
||||
from funcy import ilen, compose, omit
|
||||
|
||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.utils.generic import lift
|
||||
|
||||
|
||||
def transform_equal(a, b):
|
||||
@ -18,7 +21,8 @@ def images_equal(im1: Image, im2: Image, **kwargs):
|
||||
|
||||
|
||||
def metadata_equal(mdat1: Iterable[dict], mdat2: Iterable[dict]):
|
||||
return set(map(frozendict, mdat1)) == set(map(frozendict, mdat2))
|
||||
f = compose(set, lift(compose(frozendict, partial(omit, keys=[Info.XREF]))))
|
||||
return f(mdat1) == f(mdat2)
|
||||
|
||||
|
||||
def image_sets_equal(ims1: Iterable[Image], ims2: Iterable[Image]):
|
||||
|
||||
@ -28,3 +28,8 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair, suffix):
|
||||
with tempfile.NamedTemporaryFile(suffix=f".{suffix}") as temp_image:
|
||||
image.save(temp_image.name)
|
||||
pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type=suffix)
|
||||
|
||||
|
||||
def stream_pdf_bytes(path: str):
|
||||
with open(path, "rb") as f:
|
||||
yield f.read()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user