From 287b0ebc8a952e506185d13508eaa386d0420704 Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Wed, 10 Aug 2022 12:57:35 +0200 Subject: [PATCH] update server logic for new pyinfra, add extraction from scanned PDF with figure detection logic --- image_prediction/default_objects.py | 27 +++++++++++++++++++++++++++ image_prediction/pipeline.py | 17 ++++++++++++++++- incl/pdf2image | 2 +- src/serve.py | 22 ++++++++++++++++++---- 4 files changed, 62 insertions(+), 6 deletions(-) diff --git a/image_prediction/default_objects.py b/image_prediction/default_objects.py index 1c40d56..a05e4f7 100644 --- a/image_prediction/default_objects.py +++ b/image_prediction/default_objects.py @@ -1,3 +1,5 @@ +from typing import Iterable + from funcy import juxt from image_prediction.classifier.classifier import Classifier @@ -7,13 +9,17 @@ from image_prediction.encoder.encoders.hash_encoder import HashEncoder from image_prediction.estimator.adapter.adapter import EstimatorAdapter from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter from image_prediction.formatter.formatters.enum import EnumFormatter +from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor +from image_prediction.info import Info from image_prediction.label_mapper.mappers.probability import ProbabilityMapper from image_prediction.model_loader.loader import ModelLoader from image_prediction.model_loader.loaders.mlflow import MlflowConnector from image_prediction.redai_adapter.mlflow import MlflowModelReader from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer from image_prediction.transformer.transformers.response import ResponseTransformer +from pdf2img.default_objects.image import ImagePlus +from pdf2img.extraction import extract_images_via_metadata def get_mlflow_model_loader(mlruns_dir): @@ -41,3 +47,24 @@ def get_formatter(): def get_encoder(): return HashEncoder() + + +def extract_images_via_metadata_and_format_to_image_metadata_pair(pdf: bytes, metadata_per_image: Iterable[dict]): + image_pluses = extract_images_via_metadata(pdf, metadata_per_image) + + def reformat(image: ImagePlus): + enum_metadata = { + Info.PAGE_WIDTH: image.info.pageInfo.width, + Info.PAGE_HEIGHT: image.info.pageInfo.height, + Info.PAGE_IDX: image.info.pageInfo.number, + Info.ALPHA: image.info.alpha, + Info.WIDTH: image.info.boundingBox.width, + Info.HEIGHT: image.info.boundingBox.height, + Info.X1: image.info.boundingBox.x0, + Info.X2: image.info.boundingBox.x1, + Info.Y1: image.info.boundingBox.y0, + Info.Y2: image.info.boundingBox.y1, + } + return ImageMetadataPair(image.aspil(), enum_metadata) + + yield from map(reformat, image_pluses) diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index 126c549..8958189 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -13,10 +13,10 @@ from image_prediction.default_objects import ( get_image_classifier, get_extractor, get_encoder, + extract_images_via_metadata_and_format_to_image_metadata_pair, ) from image_prediction.locations import MLRUNS_DIR from image_prediction.utils.generic import lift, starlift -from pdf2img.extraction import extract_images_per_page os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" @@ -63,6 +63,13 @@ class Pipeline: join, # ... the streams by zipping reformat, # ... the items ) + self.pipe2 = rcompose( + extract_images_via_metadata_and_format_to_image_metadata_pair, + split, # ... into an image stream and a metadata stream + pairwise_apply(classify, represent, identity), # ... apply functions to the streams pairwise + join, # ... the streams by zipping + reformat, # ... the items + ) def __call__(self, pdf: bytes, page_range: range = None): yield from tqdm( @@ -71,3 +78,11 @@ class Pipeline: unit=" images", disable=not self.verbose, ) + + def extract_via_metadata(self, pdf: bytes, metadata_per_page: Iterable[dict]): + yield from tqdm( + self.pipe2(pdf, metadata_per_page), + desc="Processing images from document", + unit=" images", + disable=not self.verbose, + ) diff --git a/incl/pdf2image b/incl/pdf2image index 17965e4..6995688 160000 --- a/incl/pdf2image +++ b/incl/pdf2image @@ -1 +1 @@ -Subproject commit 17965e4578818b16cbd1638dfde1c58cbea55954 +Subproject commit 699568875683ba727ec9759c8bea85e0d3e1d369 diff --git a/src/serve.py b/src/serve.py index b541abf..0d972ec 100644 --- a/src/serve.py +++ b/src/serve.py @@ -1,11 +1,12 @@ import gzip +import io import json import logging from image_prediction.config import Config from image_prediction.locations import CONFIG_FILE from image_prediction.pipeline import load_pipeline -from image_prediction.utils.banner import show_banner, load_banner +from image_prediction.utils.banner import load_banner from pyinfra import config from pyinfra.queue.queue_manager import QueueManager from pyinfra.storage.storage import get_storage @@ -30,13 +31,26 @@ def process_request(request_message): object_bytes = storage.get_object(PYINFRA_CONFIG.storage_bucket, f"{dossier_id}/{file_id}.{target_file_extension}") object_bytes = gzip.decompress(object_bytes) - classifications = list(pipeline(object_bytes)) + try: # TODO: add figure detection file target to request message to avoid this + metadata_bytes = storage.get_object(PYINFRA_CONFIG.storage_bucket, f"{dossier_id}/{file_id}.FIGURE.json.gz") + metadata_bytes = gzip.decompress(metadata_bytes) + metadata = json.load(io.BytesIO(metadata_bytes)) + logger.info("Metadata aquired") + except: + metadata = None + + if metadata: + classifications = list(pipeline.extract_via_metadata(object_bytes, metadata_per_page=metadata["data"])) + else: + classifications = list(pipeline(object_bytes)) result = {**request_message, "data": classifications} response_file_extension = request_message["responseFileExtension"] storage_bytes = gzip.compress(json.dumps(result).encode("utf-8")) - storage.put_object(PYINFRA_CONFIG.storage_bucket, f"{dossier_id}/{file_id}.{response_file_extension}", storage_bytes) + storage.put_object( + PYINFRA_CONFIG.storage_bucket, f"{dossier_id}/{file_id}.{response_file_extension}", storage_bytes + ) return {"dossierId": dossier_id, "fileId": file_id} @@ -48,5 +62,5 @@ def main(): queue_manager.start_consuming(process_request) -if __name__ == '__main__': +if __name__ == "__main__": main()