diff --git a/config.yaml b/config.yaml index 6a6111a..9bfcaf1 100644 --- a/config.yaml +++ b/config.yaml @@ -1,6 +1,6 @@ webserver: - host: $SERVER_HOST|"127.0.0.1" # webserver address - port: $SERVER_PORT|5000 # webserver port + host: $SERVER_HOST|"127.0.0.1" # Webserver address + port: $SERVER_PORT|5000 # Webserver port service: logging_level: $LOGGING_LEVEL_ROOT|INFO # Logging level for service logger diff --git a/image_prediction/default_objects.py b/image_prediction/default_objects.py index d66d477..1c40d56 100644 --- a/image_prediction/default_objects.py +++ b/image_prediction/default_objects.py @@ -1,5 +1,3 @@ -from typing import Iterable - from funcy import juxt from image_prediction.classifier.classifier import Classifier @@ -7,7 +5,6 @@ from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.compositor.compositor import TransformerCompositor from image_prediction.encoder.encoders.hash_encoder import HashEncoder from image_prediction.estimator.adapter.adapter import EstimatorAdapter -from image_prediction.formatter.formatter import format_image_plus from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter from image_prediction.formatter.formatters.enum import EnumFormatter from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor @@ -17,7 +14,6 @@ from image_prediction.model_loader.loaders.mlflow import MlflowConnector from image_prediction.redai_adapter.mlflow import MlflowModelReader from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer from image_prediction.transformer.transformers.response import ResponseTransformer -from pdf2img.extraction import extract_images_via_metadata def get_mlflow_model_loader(mlruns_dir): @@ -30,17 +26,10 @@ def get_image_classifier(model_loader, model_identifier): return ImageClassifier(Classifier(EstimatorAdapter(model), ProbabilityMapper(classes))) -def get_dispatched_extract(**kwargs): +def get_extractor(**kwargs): image_extractor = ParsablePDFImageExtractor(**kwargs) - def extract(pdf: bytes, page_range: range = None, metadata_per_image: Iterable[dict] = None): - if metadata_per_image: - image_pluses = extract_images_via_metadata(pdf, metadata_per_image) - yield from map(format_image_plus, image_pluses) - else: - yield from image_extractor.extract(pdf, page_range) - - return extract + return image_extractor def get_formatter(): diff --git a/image_prediction/exceptions.py b/image_prediction/exceptions.py index f03b42a..9c9ca49 100644 --- a/image_prediction/exceptions.py +++ b/image_prediction/exceptions.py @@ -36,3 +36,7 @@ class InvalidBox(Exception): class ParsingError(Exception): pass + + +class BadXref(ValueError): + pass diff --git a/image_prediction/formatter/formatter.py b/image_prediction/formatter/formatter.py index 53306a9..3f3a1f8 100644 --- a/image_prediction/formatter/formatter.py +++ b/image_prediction/formatter/formatter.py @@ -1,10 +1,6 @@ import abc -from image_prediction.image_extractor.extractor import ImageMetadataPair -from image_prediction.info import Info - from image_prediction.transformer.transformer import Transformer -from pdf2img.default_objects.image import ImagePlus class Formatter(Transformer): @@ -17,19 +13,3 @@ class Formatter(Transformer): def __call__(self, obj): return self.format(obj) - - -def format_image_plus(image: ImagePlus) -> ImageMetadataPair: - enum_metadata = { - Info.PAGE_WIDTH: image.info.pageInfo.width, - Info.PAGE_HEIGHT: image.info.pageInfo.height, - Info.PAGE_IDX: image.info.pageInfo.number, - Info.ALPHA: image.info.alpha, - Info.WIDTH: image.info.boundingBox.width, - Info.HEIGHT: image.info.boundingBox.height, - Info.X1: image.info.boundingBox.x0, - Info.X2: image.info.boundingBox.x1, - Info.Y1: image.info.boundingBox.y0, - Info.Y2: image.info.boundingBox.y1, - } - return ImageMetadataPair(image.aspil(), enum_metadata) diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index eac09e1..a911951 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -1,26 +1,32 @@ import atexit -import io -import json -import traceback from functools import partial, lru_cache -from itertools import chain, starmap, filterfalse -from operator import itemgetter, truth -from typing import List, Iterable, Iterator +from itertools import chain, filterfalse +from operator import itemgetter, attrgetter +from typing import List, Union, Iterable, Any import fitz +import numpy as np from PIL import Image -from funcy import rcompose, merge, pluck, curry, compose +from funcy import merge, compose, rcompose, keep, lkeep, notnone, iffy, rpartial +from pymonad.either import Right, Left, Either +from pymonad.tools import curry, identity +from image_prediction.exceptions import InvalidBox, BadXref from image_prediction.formatter.formatters.enum import EnumFormatter from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair from image_prediction.info import Info from image_prediction.stitching.stitching import stitch_pairs -from image_prediction.stitching.utils import validate_box_coords, validate_box_size +from image_prediction.stitching.utils import validate_box from image_prediction.utils import get_logger -from image_prediction.utils.generic import lift +from image_prediction.utils.generic import bottom, left, right, lift, wrap_right logger = get_logger() +Doc = fitz.fitz.Document +Pge = fitz.fitz.Page +Img = Image.Image +Pxm = fitz.fitz.Pixmap + class ParsablePDFImageExtractor(ImageExtractor): def __init__(self, verbose=False, tolerance=0): @@ -31,7 +37,7 @@ class ParsablePDFImageExtractor(ImageExtractor): tolerance: The tolerance in pixels for the distance between images, beyond which they will not be stitched together """ - self.doc: fitz.fitz.Document = None + self.doc: Union[Doc, None] = None self.verbose = verbose self.tolerance = tolerance @@ -44,80 +50,215 @@ class ParsablePDFImageExtractor(ImageExtractor): yield from image_metadata_pairs - def __process_images_on_page(self, page: fitz.fitz.Page): - images = get_images_on_page(self.doc, page) - metadata = get_metadata_for_images_on_page(self.doc, page) + def __process_images_on_page(self, page: Pge): + metadata = extract_valid_metadata(self.doc, page) + + either_image_metadata_pair_or_error_per_image = map(self.__metadatum_to_image_metadata_pair, metadata) + valid_image_metadata_pairs = lkeep( + rpartial(take_good_log_bad, format_context), either_image_metadata_pair_or_error_per_image + ) + valid_image_metadata_pairs_stitched = stitch_pairs(valid_image_metadata_pairs, tolerance=self.tolerance) + clear_caches() - image_metadata_pairs = starmap(ImageMetadataPair, filter(all, zip(images, metadata))) - # TODO: In the future, consider to introduce an image validator as a pipeline component rather than doing the - # validation here. Invalid images can then be split into a different stream and joined with the intact images - # again for the formatting step. - image_metadata_pairs = self.__filter_valid_images(image_metadata_pairs) - image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance) + yield from valid_image_metadata_pairs_stitched - yield from image_metadata_pairs - - @staticmethod - def __filter_valid_images(image_metadata_pairs: Iterable[ImageMetadataPair]) -> Iterator[ImageMetadataPair]: - def validate(image: Image.Image, metadata: dict): - try: - # TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148) - image.resize((100, 100)).convert("RGB") - return ImageMetadataPair(image, metadata) - except (OSError, Exception) as err: - metadata = json.dumps(EnumFormatter()(metadata), indent=2) - logger.warning(f"Invalid image encountered. Image metadata:\n{metadata}\n\n{traceback.format_exc()}") - return None - - return filter(truth, starmap(validate, image_metadata_pairs)) + def __metadatum_to_image_metadata_pair(self, metadatum: dict) -> Either: + return metadatum_to_image_metadata_pair(self.doc, metadatum) -def extract_pages(doc, page_range): +def extract_pages(doc: Doc, page_range: range) -> Iterable[Pge]: page_range = range(page_range.start + 1, page_range.stop + 1) pages = map(doc.load_page, page_range) yield from pages -@lru_cache(maxsize=None) -def get_images_on_page(doc, page: fitz.Page): - image_infos = get_image_infos(page) - xrefs = map(itemgetter("xref"), image_infos) - images = map(partial(xref_to_image, doc), xrefs) - - yield from images +def validate_image(image: Img) -> Either: + try: + # TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148) + image.resize((100, 100)).convert("RGB") + return Right(image) + except (OSError, Exception): + logger.warning(f"Invalid image encountered.") + return Left("Invalid image.") -def get_metadata_for_images_on_page(doc, page: fitz.Page): +def extract_valid_metadata(doc: Doc, page: Pge) -> List[dict]: + return compose( + list, + partial(add_alpha_channel_info, doc), + filter_valid_metadata, + get_metadata_for_images_on_page, + )(page) - metadata = map(get_image_metadata, get_image_infos(page)) - metadata = validate_coords_and_passthrough(metadata) - metadata = filter_out_tiny_images(metadata) - metadata = validate_size_and_passthrough(metadata) +def take_good_log_bad(item: Either, log_formatter=identity) -> Any: + return item.either(rpartial(log_error_context, log_formatter), identity) - metadata = add_page_metadata(page, metadata) - metadata = add_alpha_channel_info(doc, page, metadata) +def format_context(context: dict) -> str: + return f"Reason: {context['reason'].rstrip('.')}. Metadata: {EnumFormatter()(context['metadata'])}" + + +def log_error_context(context: dict, formatter=identity) -> None: + logger.warning(f"Skipping bad image. {formatter(context)}") + return None + + +def metadatum_to_image_metadata_pair(doc: Doc, metadatum: dict) -> Either: + image: Either = eith_extract_image(doc, metadatum[Info.XREF]).bind(validate_image) + image_metadata_pair: Either = eith_make_image_metadata_pair(image, Right(metadatum)) + return image_metadata_pair + + +def add_alpha_channel_info(doc: Doc, metadata: Iterable[dict]) -> Iterable[dict]: + def add_alpha_value_to_metadatum(metadatum: dict) -> dict: + alpha = metadatum_to_alpha_value(metadatum) + return {**metadatum, Info.ALPHA: alpha} + + xref_to_alpha = partial(has_alpha_channel, doc) + metadatum_to_alpha_value = compose(xref_to_alpha, itemgetter(Info.XREF)) + + yield from map(add_alpha_value_to_metadatum, metadata) + + +def filter_valid_metadata(metadata: Iterable[dict]) -> Iterable[dict]: + yield from compose(filter_out_tiny_images, filter_out_invalid_metadata)(metadata) + + +def get_metadata_for_images_on_page(page: fitz.Page) -> Iterable[dict]: + metadata = compose( + partial(add_page_metadata, page), + lift(get_image_metadata), + get_image_infos, + )(page) yield from metadata @lru_cache(maxsize=None) -def get_image_infos(page: fitz.Page) -> List[dict]: - return page.get_image_info(xrefs=True) +@curry(2) +def eith_extract_image(doc: Doc, xref: int) -> Either: + try: + return Right(extract_image(doc, xref)) + except BadXref: + return Left("Bad xref.") -@lru_cache(maxsize=None) -def xref_to_image(doc, xref) -> Image: - maybe_image = load_image_handle_from_xref(doc, xref) - return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None +def eith_make_image_metadata_pair(image: Either, metadata: Either) -> Either: + """Reference: haskell.org/tutorial/monads.html""" + + def context(value): + return {"reason": value, "metadata": metadata.either(bottom, identity)} + + # Explicitly we are doing the following. (1) and (2) are equivalent. + + # a := Image + # b := Metadata + # c := ImageMetadataPair + # m := Either monad + + # fmt: off + # 1) + # pair: Either = ( + # Right(make_image_metadata_pair) # m (a -> b -> c) + # .amap(image) # m (a -> b -> c) <*> m a = m (b -> c) + # .amap(metadata) # m (b -> c) <*> m b = m c + # ) + + # 2) + # pair: Either = ( + # image.bind(right(make_image_metadata_pair)) # m a >>= m (a -> b -> c) = m (b -> c) + # .amap(metadata) # m (b -> c) <*> m b = m c + # ) + # fmt: on + + # Syntactic sugar variant with details hidden + pair: Either = Either.apply(make_image_metadata_pair).to_arguments(image, metadata) + + return pair.either(left(context), right(identity)) -def get_image_metadata(image_info): +@curry(2) +def make_image_metadata_pair(image: Img, metadatum: dict) -> ImageMetadataPair: + return ImageMetadataPair(image, metadatum) - x1, y1, x2, y2 = map(rounder, image_info["bbox"]) + +def extract_image(doc: Doc, xref: int) -> Any: + return compose(pixmap_to_image, extract_pixmap)(doc, xref) + + +def pixmap_to_image(pixmap: Pxm) -> Img: + array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape((pixmap.h, pixmap.w, pixmap.n)) + array = normalize_channels(array) + return Image.fromarray(array) + + +def extract_pixmap(doc: Doc, xref: int) -> Pxm: + try: + return fitz.Pixmap(doc, xref) + except ValueError as err: + msg = f"Cross reference {xref} is invalid, skipping extraction." + logger.error(err) + logger.trace(msg) + raise BadXref(msg) from err + + +def has_alpha_channel(doc: Doc, xref: int) -> bool: + + _get_image_handle = wrap_right(get_image_handle, success_condition=notnone)(doc) + _extract_pixmap = wrap_right(extract_pixmap)(doc) + + def get_soft_mask_reference(cross_reference: int) -> Either: + def error(value) -> str: + return f"Invalid soft mask {value} for cross reference {cross_reference}." + + logger.trace(f"Getting soft mask handle for cross reference {cross_reference}.") + pass_on_if_not_none = iffy(notnone, right(identity), left(error)) + return _get_image_handle(cross_reference).then(itemgetter("smask")).either(left(identity), pass_on_if_not_none) + + def mask_exists(soft_mask_reference: int) -> Either: + logger.trace(f"Checking if soft mask exists for soft mask reference {soft_mask_reference}.") + return _get_image_handle(soft_mask_reference).then(notnone) + + def image_has_alpha_channel(reference: int) -> Either: + logger.trace(f"Checking if image with reference {reference} has alpha channel.") + return _extract_pixmap(reference).then(attrgetter("alpha")).then(bool) + + logger.debug(f"Checking if image with cross reference {xref} has alpha channel.") + + cross_reference = Right(xref) + soft_mask_reference = cross_reference.bind(get_soft_mask_reference) + + return any( + take_good_log_bad(reference.bind(check)) + for reference, check in [ + (soft_mask_reference, mask_exists), + (soft_mask_reference, image_has_alpha_channel), + (cross_reference, image_has_alpha_channel), + ] + ) + + +def filter_out_tiny_images(metadata: Iterable[dict]) -> Iterable[dict]: + yield from filterfalse(tiny, metadata) + + +def filter_out_invalid_metadata(metadata: Iterable[dict]) -> Iterable[dict]: + def __validate_box(box): + try: + return validate_box(box) + except InvalidBox as err: + logger.debug(f"Dropping invalid metadatum, reason: {err}") + + yield from keep(__validate_box, metadata) + + +def get_image_metadata(image_info: dict) -> dict: + + xref, coords = itemgetter("xref", "bbox")(image_info) + x1, y1, x2, y2 = map(rounder, coords) width = abs(x2 - x1) height = abs(y2 - y1) @@ -129,47 +270,40 @@ def get_image_metadata(image_info): Info.X2: x2, Info.Y1: y1, Info.Y2: y2, + Info.XREF: xref, } -def validate_coords_and_passthrough(metadata): - yield from map(validate_box_coords, metadata) +@lru_cache(maxsize=None) +def get_image_infos(page: Pge) -> List[dict]: + return page.get_image_info(xrefs=True) -def filter_out_tiny_images(metadata): - yield from filterfalse(tiny, metadata) - - -def validate_size_and_passthrough(metadata): - yield from map(validate_box_size, metadata) - - -def add_page_metadata(page, metadata): +def add_page_metadata(page: Pge, metadata: Iterable[dict]) -> Iterable[dict]: yield from map(partial(merge, get_page_metadata(page)), metadata) -def add_alpha_channel_info(doc, page, metadata): +def normalize_channels(array: np.ndarray): + if not array.ndim == 3: + array = np.expand_dims(array, axis=-1) - page_to_xrefs = compose(curry(pluck)("xref"), get_image_infos) - xref_to_alpha = partial(has_alpha_channel, doc) - page_to_alpha_value_per_image = compose(lift(xref_to_alpha), page_to_xrefs) - alpha_to_dict = compose(dict, lambda a: [(Info.ALPHA, a)]) - page_to_alpha_mapping_per_image = compose(lift(alpha_to_dict), page_to_alpha_value_per_image) + if array.shape[-1] == 4: + array = array[..., :3] + elif array.shape[-1] == 1: + array = np.concatenate([array, array, array], axis=-1) + elif array.shape[-1] != 3: + logger.warning(f"Unexpected image format: {array.shape}.") + raise ValueError(f"Unexpected image format: {array.shape}.") - metadata = starmap(merge, zip(page_to_alpha_mapping_per_image(page), metadata)) - - yield from metadata + return array @lru_cache(maxsize=None) -def load_image_handle_from_xref(doc, xref): +def get_image_handle(doc: Doc, xref: int) -> Union[dict, None]: return doc.extract_image(xref) -rounder = rcompose(round, int) - - -def get_page_metadata(page): +def get_page_metadata(page: Pge) -> dict: page_width, page_height = map(rounder, page.mediabox_size) return { @@ -179,30 +313,17 @@ def get_page_metadata(page): } -def has_alpha_channel(doc, xref): - - maybe_image = load_image_handle_from_xref(doc, xref) - maybe_smask = maybe_image["smask"] if maybe_image else None - - if maybe_smask: - return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)]) - else: - try: - return bool(fitz.Pixmap(doc, xref).alpha) - except ValueError: - logger.debug(f"Encountered invalid xref `{xref}` in {doc.metadata.get('title', '')}.") - return False +rounder = rcompose(round, int) -def tiny(metadata): - return metadata[Info.WIDTH] * metadata[Info.HEIGHT] <= 4 +def tiny(metadatum: dict) -> bool: + return metadatum[Info.WIDTH] * metadatum[Info.HEIGHT] <= 4 -def clear_caches(): +def clear_caches() -> None: get_image_infos.cache_clear() - load_image_handle_from_xref.cache_clear() - get_images_on_page.cache_clear() - xref_to_image.cache_clear() + get_image_handle.cache_clear() + eith_extract_image.cache_clear() atexit.register(clear_caches) diff --git a/image_prediction/info.py b/image_prediction/info.py index 344274a..987779e 100644 --- a/image_prediction/info.py +++ b/image_prediction/info.py @@ -12,3 +12,4 @@ class Info(Enum): Y1 = "y1" Y2 = "y2" ALPHA = "alpha" + XREF = "xref" diff --git a/image_prediction/locations.py b/image_prediction/locations.py index 1f14c1a..9374ace 100644 --- a/image_prediction/locations.py +++ b/image_prediction/locations.py @@ -3,15 +3,14 @@ from pathlib import Path MODULE_DIR = Path(__file__).resolve().parents[0] - PACKAGE_ROOT_DIR = MODULE_DIR.parents[0] CONFIG_FILE = PACKAGE_ROOT_DIR / "config.yaml" - BANNER_FILE = PACKAGE_ROOT_DIR / "banner.txt" DATA_DIR = PACKAGE_ROOT_DIR / "data" - MLRUNS_DIR = str(DATA_DIR / "mlruns") -TEST_DATA_DIR = PACKAGE_ROOT_DIR / "test" / "data" +TEST_DIR = PACKAGE_ROOT_DIR / "test" +TEST_DATA_DIR = TEST_DIR / "data" +TEST_DATA_DIR_DVC = TEST_DIR / "data.dvc" diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index f9383a1..704a88f 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -11,8 +11,8 @@ from image_prediction.default_objects import ( get_formatter, get_mlflow_model_loader, get_image_classifier, + get_extractor, get_encoder, - get_dispatched_extract, ) from image_prediction.locations import MLRUNS_DIR from image_prediction.utils.generic import lift, starlift @@ -41,7 +41,7 @@ class Pipeline: def __init__(self, model_loader, model_identifier, batch_size=16, verbose=True, **kwargs): self.verbose = verbose - extract = get_dispatched_extract(**kwargs) + extract = get_extractor(**kwargs) classifier = get_image_classifier(model_loader, model_identifier) reformat = get_formatter() represent = get_encoder() @@ -63,9 +63,9 @@ class Pipeline: reformat, # ... the items ) - def __call__(self, pdf: bytes, page_range: range = None, metadata_per_image: Iterable[dict] = None): + def __call__(self, pdf: bytes, page_range: range = None): yield from tqdm( - self.pipe(pdf, page_range=page_range, metadata_per_image=metadata_per_image), + self.pipe(pdf, page_range=page_range), desc="Processing images from document", unit=" images", disable=not self.verbose, diff --git a/image_prediction/transformer/transformers/response.py b/image_prediction/transformer/transformers/response.py index 378fe7b..288c510 100644 --- a/image_prediction/transformer/transformers/response.py +++ b/image_prediction/transformer/transformers/response.py @@ -21,11 +21,6 @@ class ResponseTransformer(Transformer): def build_image_info(data: dict) -> dict: - def compute_geometric_quotient(): - page_area_sqrt = math.sqrt(abs(page_width * page_height)) - image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1)) - return image_area_sqrt / page_area_sqrt - page_width, page_height, x1, x2, y1, y2, width, height, alpha = itemgetter( "page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height", "alpha" )(data) @@ -34,7 +29,7 @@ def build_image_info(data: dict) -> dict: label = classification["label"] representation = data["representation"] - geometric_quotient = round(compute_geometric_quotient(), 4) + geometric_quotient = round(compute_geometric_quotient(page_width, page_height, x2, x1, y2, y1), 4) min_image_to_page_quotient_breached = bool( geometric_quotient < get_class_specific_min_image_to_page_quotient(label) @@ -89,6 +84,12 @@ def build_image_info(data: dict) -> dict: return image_info +def compute_geometric_quotient(page_width, page_height, x2, x1, y2, y1): + page_area_sqrt = math.sqrt(abs(page_width * page_height)) + image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1)) + return image_area_sqrt / page_area_sqrt + + def get_class_specific_min_image_to_page_quotient(label, table=None): return get_class_specific_value( "REL_IMAGE_SIZE", label, "min", CONFIG.filters.image_to_page_quotient.min, table=table diff --git a/image_prediction/utils/generic.py b/image_prediction/utils/generic.py index de71a5c..ffdf7b7 100644 --- a/image_prediction/utils/generic.py +++ b/image_prediction/utils/generic.py @@ -1,6 +1,15 @@ +from functools import wraps +from inspect import signature from itertools import starmap +from typing import Callable from funcy import iterate, first, curry, map +from pymonad.either import Left, Right, Either +from pymonad.tools import curry as pmcurry + +from image_prediction.utils import get_logger + +logger = get_logger() def until(cond, func, *args, **kwargs): @@ -13,3 +22,63 @@ def lift(fn): def starlift(fn): return curry(starmap)(fn) + + +def bottom(*args, **kwargs): + return False + + +def top(*args, **kwargs): + return True + + +def left(fn): + @wraps(fn) + def inner(x): + return Left(fn(x)) + + return inner + + +def right(fn): + @wraps(fn) + def inner(x): + return Right(fn(x)) + + return inner + + +def wrap_left(fn, success_condition=top, error_message=None) -> Callable: + return wrap_either(Left, Right, success_condition=success_condition, error_message=error_message)(fn) + + +def wrap_right(fn, success_condition=top, error_message=None) -> Callable: + return wrap_either(Right, Left, success_condition=success_condition, error_message=error_message)(fn) + + +def wrap_either(success_type, failure_type, success_condition=top, error_message=None) -> Callable: + @wraps(wrap_either) + def wrapper(fn) -> Callable: + + n_params = len(signature(fn).parameters) + + @pmcurry(n_params) + @wraps(fn) + def wrapper(*args, **kwargs) -> Either: + try: + result = fn(*args, **kwargs) + if success_condition(result): + return success_type(result) + else: + return failure_type({"error": error_message, "result": result}) + except Exception as err: + logger.error(err) + return failure_type({"error": error_message or err, "result": Void}) + + return wrapper + + return wrapper + + +class Void: + pass diff --git a/requirements.txt b/requirements.txt index da99202..3559e63 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,5 @@ pdf2image==1.16.0 frozendict==2.3.0 protobuf<=3.20.* prometheus-client==0.13.1 +fsspec==2022.11.0 +PyMonad==2.4.0 diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 29d3199..c2b4bb0 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -2,7 +2,6 @@ import argparse import json import os from glob import glob -from operator import truth from image_prediction.pipeline import load_pipeline from image_prediction.utils import get_logger @@ -15,7 +14,6 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("input", help="pdf file or directory") - parser.add_argument("--metadata", help="optional figure detection metadata") parser.add_argument("--print", "-p", help="print output to terminal", action="store_true", default=False) parser.add_argument("--page_interval", "-i", help="page interval [i, j), min index = 0", nargs=2, type=int) @@ -24,17 +22,13 @@ def parse_args(): return args -def process_pdf(pipeline, pdf_path, metadata=None, page_range=None): - if metadata: - with open(metadata) as f: - metadata = json.load(f) - +def process_pdf(pipeline, pdf_path, page_range=None): with open(pdf_path, "rb") as f: logger.info(f"Processing {pdf_path}") - predictions = list(pipeline(f.read(), page_range=page_range, metadata_per_image=metadata)) + predictions = list(pipeline(f.read(), page_range=page_range)) annotate_pdf( - pdf_path, predictions, os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", f"_{truth(metadata)}_annotated.pdf"))) + pdf_path, predictions, os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", "_annotated.pdf"))) ) return predictions @@ -48,10 +42,9 @@ def main(args): else: pdf_paths = glob(os.path.join(args.input, "*.pdf")) page_range = range(*args.page_interval) if args.page_interval else None - metadata = args.metadata if args.metadata else None for pdf_path in pdf_paths: - predictions = process_pdf(pipeline, pdf_path, metadata, page_range=page_range) + predictions = process_pdf(pipeline, pdf_path, page_range=page_range) if args.print: print(pdf_path) print(json.dumps(predictions, indent=2)) diff --git a/src/serve.py b/src/serve.py index ece6a0b..a865c7d 100644 --- a/src/serve.py +++ b/src/serve.py @@ -1,5 +1,4 @@ import gzip -import io import json import logging @@ -29,36 +28,28 @@ logger.setLevel(PYINFRA_CONFIG.logging_level_root) def process_request(request_message): dossier_id = request_message["dossierId"] file_id = request_message["fileId"] + logger.info(f"Processing {dossier_id=} {file_id=} ...") target_file_name = f"{dossier_id}/{file_id}.{request_message['targetFileExtension']}" response_file_name = f"{dossier_id}/{file_id}.{request_message['responseFileExtension']}" - figure_data_file_name = f"{dossier_id}/{file_id}.FIGURE.json.gz" bucket = PYINFRA_CONFIG.storage_bucket storage = get_storage(PYINFRA_CONFIG) pipeline = load_pipeline(verbose=IMAGE_CONFIG.service.verbose, batch_size=IMAGE_CONFIG.service.batch_size) - if storage.exists(bucket, target_file_name): - should_publish_result = True + if not storage.exists(bucket, target_file_name): + publish_result = False + else: + publish_result = True object_bytes = storage.get_object(bucket, target_file_name) object_bytes = gzip.decompress(object_bytes) classifications = list(pipeline(pdf=object_bytes)) - if storage.exists(bucket, figure_data_file_name): - metadata_bytes = storage.get_object(bucket, figure_data_file_name) - metadata_bytes = gzip.decompress(metadata_bytes) - metadata_per_image = json.load(io.BytesIO(metadata_bytes))["data"] - classifications_cv = list(pipeline(pdf=object_bytes, metadata_per_image=metadata_per_image)) - else: - classifications_cv = [] - - result = {**request_message, "data": classifications, "dataCV": classifications_cv} + result = {**request_message, "data": classifications} storage_bytes = gzip.compress(json.dumps(result).encode("utf-8")) storage.put_object(bucket, response_file_name, storage_bytes) - else: - should_publish_result = False - return should_publish_result, {"dossierId": dossier_id, "fileId": file_id} + return publish_result, {"dossierId": dossier_id, "fileId": file_id} def main(): diff --git a/test/.gitignore b/test/.gitignore new file mode 100644 index 0000000..3af0ccb --- /dev/null +++ b/test/.gitignore @@ -0,0 +1 @@ +/data diff --git a/test/data.dvc b/test/data.dvc new file mode 100644 index 0000000..c7040fe --- /dev/null +++ b/test/data.dvc @@ -0,0 +1,5 @@ +outs: +- md5: 4b0fec291ce0661b3efbbd8b80f4f514.dir + size: 107332 + nfiles: 4 + path: data diff --git a/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf b/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf deleted file mode 100644 index 41f0d70..0000000 Binary files a/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf and /dev/null differ diff --git a/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json b/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json deleted file mode 100644 index 1a1b3f5..0000000 --- a/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json +++ /dev/null @@ -1,44 +0,0 @@ -[ - { - "classification": { - "label": "formula", - "probabilities": { - "formula": 1.0, - "logo": 0.0, - "other": 0.0, - "signature": 0.0 - } - }, - "representation": "FFFEF0C7033648170F3EFFFFF", - "position": { - "x1": 321, - "x2": 515, - "y1": 348, - "y2": 542, - "pageNumber": 2 - }, - "geometry": { - "width": 194, - "height": 194 - }, - "alpha": false, - "filters": { - "geometry": { - "imageSize": { - "quotient": 0.2741, - "tooLarge": false, - "tooSmall": false - }, - "imageFormat": { - "quotient": 1.0, - "tooTall": false, - "tooWide": false - } - }, - "probability": { - "unconfident": false - }, - "allPassed": true - } - } -] \ No newline at end of file diff --git a/test/data/stitching_with_tolerance.json b/test/data/stitching_with_tolerance.json deleted file mode 100644 index f7f1049..0000000 --- a/test/data/stitching_with_tolerance.json +++ /dev/null @@ -1,92 +0,0 @@ -{ - "input": [ - { - "width": 100, - "height": 8, - "page_idx": 0, - "page_width": 100, - "page_height": 100, - "x1": 0, - "y1": 0, - "x2": 100, - "y2": 8 - }, - { - "width": 100, - "height": 9, - "page_idx": 0, - "page_width": 100, - "page_height": 100, - "x1": 0, - "y1": 9, - "x2": 100, - "y2": 18 - }, - { - "width": 100, - "height": 35, - "page_idx": 0, - "page_width": 100, - "page_height": 100, - "x1": 0, - "y1": 18, - "x2": 100, - "y2": 53 - }, - { - "width": 47, - "height": 46, - "page_idx": 0, - "page_width": 100, - "page_height": 100, - "x1": 0, - "y1": 54, - "x2": 47, - "y2": 100 - }, - { - "width": 31, - "height": 46, - "page_idx": 0, - "page_width": 100, - "page_height": 100, - "x1": 48, - "y1": 54, - "x2": 79, - "y2": 100 - }, - { - "width": 20, - "height": 19, - "page_idx": 0, - "page_width": 100, - "page_height": 100, - "x1": 80, - "y1": 54, - "x2": 100, - "y2": 73 - }, - { - "width": 20, - "height": 27, - "page_idx": 0, - "page_width": 100, - "page_height": 100, - "x1": 80, - "y1": 73, - "x2": 100, - "y2": 100 - } - ], - "target": { - "width": 100, - "height": 100, - "page_idx": 0, - "page_width": 100, - "page_height": 100, - "x1": 0, - "y1": 0, - "x2": 100, - "y2": 100 - } -} diff --git a/test/fixtures/input.py b/test/fixtures/input.py index b02f414..2054df6 100644 --- a/test/fixtures/input.py +++ b/test/fixtures/input.py @@ -1,7 +1,21 @@ import numpy as np import pytest +from dvc.repo import Repo + +from image_prediction.locations import PACKAGE_ROOT_DIR, TEST_DATA_DIR_DVC +from image_prediction.utils import get_logger + +logger = get_logger() @pytest.fixture def input_batch(batch_size, input_size): return np.random.random_sample(size=(batch_size, *input_size)) + + +@pytest.fixture(scope="session") +def dvc_test_data(): + logger.info("Pulling data with DVC...") + # noinspection PyCallingNonCallable + Repo(PACKAGE_ROOT_DIR).pull(targets=[str(TEST_DATA_DIR_DVC)]) + logger.info("Finished pulling data.") diff --git a/test/fixtures/pdf.py b/test/fixtures/pdf.py index 7353917..0991bbe 100644 --- a/test/fixtures/pdf.py +++ b/test/fixtures/pdf.py @@ -4,7 +4,7 @@ import fpdf import pytest from image_prediction.locations import TEST_DATA_DIR -from test.utils.generation.pdf import add_image, pdf_stream +from test.utils.generation.pdf import add_image, pdf_stream, stream_pdf_bytes @pytest.fixture @@ -18,6 +18,10 @@ def pdf(image_metadata_pairs): @pytest.fixture -def real_pdf(): - with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf"), "rb") as f: - yield f.read() +def real_pdf(dvc_test_data): + yield from stream_pdf_bytes(TEST_DATA_DIR / "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf") + + +@pytest.fixture +def bad_xref_pdf(dvc_test_data): + yield from stream_pdf_bytes(TEST_DATA_DIR / "bad_xref.pdf") diff --git a/test/fixtures/target.py b/test/fixtures/target.py index 23f23bd..1f111fc 100644 --- a/test/fixtures/target.py +++ b/test/fixtures/target.py @@ -87,7 +87,7 @@ def expected_predictions_mapped_and_formatted(expected_predictions_mapped): @pytest.fixture -def real_expected_service_response(): +def real_expected_service_response(dvc_test_data): with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json"), "r") as f: yield json.load(f) diff --git a/test/unit_tests/image_extractor_test.py b/test/unit_tests/image_extractor_test.py index e52b2b5..1c34103 100644 --- a/test/unit_tests/image_extractor_test.py +++ b/test/unit_tests/image_extractor_test.py @@ -1,3 +1,4 @@ +import json import random from operator import itemgetter @@ -7,12 +8,23 @@ import pytest from PIL import Image from funcy import first, rest +from image_prediction.exceptions import BadXref from image_prediction.extraction import extract_images_from_pdf +from image_prediction.formatter.formatters.enum import EnumFormatter from image_prediction.image_extractor.extractor import ImageMetadataPair -from image_prediction.image_extractor.extractors.parsable import extract_pages, get_image_infos, has_alpha_channel +from image_prediction.image_extractor.extractors.parsable import ( + extract_pages, + has_alpha_channel, + get_image_infos, + ParsablePDFImageExtractor, + extract_valid_metadata, + eith_extract_image, + extract_image, +) from image_prediction.info import Info +from image_prediction.locations import TEST_DATA_DIR from test.utils.comparison import metadata_equal, image_sets_equal -from test.utils.generation.pdf import add_image, pdf_stream +from test.utils.generation.pdf import add_image, pdf_stream, stream_pdf_bytes @pytest.mark.parametrize("extractor_type", ["mock"]) @@ -75,3 +87,15 @@ def test_has_alpha_channel(base_patch_metadata, suffix, mode): assert not list(rest(xrefs)) doc.close() + + +def test_bad_xref_handling(bad_xref_pdf, dvc_test_data): + + doc = fitz.Document(stream=bad_xref_pdf) + metadata = extract_valid_metadata(doc, first(doc)) + xref = first(metadata)[Info.XREF] + + with pytest.raises(BadXref): + extract_image(doc, xref) + + assert eith_extract_image(doc, xref).is_left() diff --git a/test/unit_tests/image_stitching_test.py b/test/unit_tests/image_stitching_test.py index edf7923..3762036 100644 --- a/test/unit_tests/image_stitching_test.py +++ b/test/unit_tests/image_stitching_test.py @@ -60,10 +60,10 @@ def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_pa assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4) -def test_image_stitcher_with_gaps_must_succeed(): +def test_image_stitcher_with_gaps_must_succeed(dvc_test_data): from image_prediction.locations import TEST_DATA_DIR - with open(os.path.join(TEST_DATA_DIR, "stitching_with_tolerance.json")) as f: + with open(TEST_DATA_DIR / "stitching_with_tolerance.json") as f: patches_metadata, base_patch_metadata = itemgetter("input", "target")(ReverseEnumFormatter(Info)(json.load(f))) images = map(gray_image_from_metadata, patches_metadata) diff --git a/test/utils/comparison.py b/test/utils/comparison.py index f2677ce..b4c8d14 100644 --- a/test/utils/comparison.py +++ b/test/utils/comparison.py @@ -1,12 +1,15 @@ +from functools import partial from itertools import starmap, product, repeat from typing import Iterable import numpy as np from PIL.Image import Image from frozendict import frozendict -from funcy import ilen +from funcy import ilen, compose, omit from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor +from image_prediction.info import Info +from image_prediction.utils.generic import lift def transform_equal(a, b): @@ -18,7 +21,8 @@ def images_equal(im1: Image, im2: Image, **kwargs): def metadata_equal(mdat1: Iterable[dict], mdat2: Iterable[dict]): - return set(map(frozendict, mdat1)) == set(map(frozendict, mdat2)) + f = compose(set, lift(compose(frozendict, partial(omit, keys=[Info.XREF])))) + return f(mdat1) == f(mdat2) def image_sets_equal(ims1: Iterable[Image], ims2: Iterable[Image]): diff --git a/test/utils/generation/pdf.py b/test/utils/generation/pdf.py index 852647e..111a6d4 100644 --- a/test/utils/generation/pdf.py +++ b/test/utils/generation/pdf.py @@ -28,3 +28,8 @@ def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair, suffix): with tempfile.NamedTemporaryFile(suffix=f".{suffix}") as temp_image: image.save(temp_image.name) pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type=suffix) + + +def stream_pdf_bytes(path: str): + with open(path, "rb") as f: + yield f.read()