Summary: fold metadata-driven image extraction into the main Pipeline call.

* image_prediction/default_objects.py: get_extractor() is replaced by
  get_dispatched_extract(), which returns an
  extract(pdf, page_range, metadata_per_image) generator function. When
  metadata_per_image is truthy it yields format_image_plus(...) applied to
  extract_images_via_metadata(pdf, metadata_per_image); otherwise it
  delegates to ParsablePDFImageExtractor.extract(pdf, page_range).
* image_prediction/formatter/formatter.py: the ImagePlus -> ImageMetadataPair
  conversion moves here as format_image_plus().
* image_prediction/pipeline.py: Pipeline.pipe2 and
  Pipeline.extract_via_metadata() are removed; Pipeline.__call__ gains a
  metadata_per_image keyword that is forwarded to the single pipe.
* src/serve.py: process_request() always calls pipeline(...) and passes the
  "data" entry of the FIGURE metadata, or None when loading it fails.

Review notes (not applied to the hunks below):

1. "if metadata_per_image:" is a truthiness test, so an empty metadata list
   takes the full-extraction branch. The removed serve.py code tested the
   loaded JSON object and forwarded metadata["data"] even when it was empty.
   Confirm that this fallback is intended; otherwise test "is not None".
2. The commented-out copy of get_extractor() added in default_objects.py is
   dead code and can be dropped.
3. src/serve.py keeps a bare "except:". The ["data"] lookup now sits inside
   that try block, so a missing key is silently mapped to None as well.
   Prefer "except Exception:" plus a log entry.

NOTE(review): this patch had been collapsed onto a few physical lines. The
line structure below was restored from the hunk headers; indentation and the
position of blank context lines are reconstructed, so verify with
"git apply --check" against the target tree. Text before the first diff
header is expected to be skipped by git apply / patch.

diff --git a/image_prediction/default_objects.py b/image_prediction/default_objects.py
index a05e4f7..3ad5579 100644
--- a/image_prediction/default_objects.py
+++ b/image_prediction/default_objects.py
@@ -7,18 +7,16 @@ from image_prediction.classifier.image_classifier import ImageClassifier
 from image_prediction.compositor.compositor import TransformerCompositor
 from image_prediction.encoder.encoders.hash_encoder import HashEncoder
 from image_prediction.estimator.adapter.adapter import EstimatorAdapter
+from image_prediction.formatter.formatter import format_image_plus
 from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
 from image_prediction.formatter.formatters.enum import EnumFormatter
-from image_prediction.image_extractor.extractor import ImageMetadataPair
 from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
-from image_prediction.info import Info
 from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
 from image_prediction.model_loader.loader import ModelLoader
 from image_prediction.model_loader.loaders.mlflow import MlflowConnector
 from image_prediction.redai_adapter.mlflow import MlflowModelReader
 from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer
 from image_prediction.transformer.transformers.response import ResponseTransformer
-from pdf2img.default_objects.image import ImagePlus
 from pdf2img.extraction import extract_images_via_metadata
 
 
@@ -32,10 +30,23 @@ def get_image_classifier(model_loader, model_identifier):
     return ImageClassifier(Classifier(EstimatorAdapter(model), ProbabilityMapper(classes)))
 
 
-def get_extractor(**kwargs):
+# def get_extractor(**kwargs):
+#     image_extractor = ParsablePDFImageExtractor(**kwargs)
+#
+#     return image_extractor
+
+
+def get_dispatched_extract(**kwargs):
     image_extractor = ParsablePDFImageExtractor(**kwargs)
 
-    return image_extractor
+    def extract(pdf: bytes, page_range: range = None, metadata_per_image: Iterable[dict] = None):
+        if metadata_per_image:
+            image_pluses = extract_images_via_metadata(pdf, metadata_per_image)
+            yield from map(format_image_plus, image_pluses)
+        else:
+            yield from image_extractor.extract(pdf, page_range)
+
+    return extract
 
 
 def get_formatter():
@@ -47,24 +58,3 @@ def get_formatter():
 
 def get_encoder():
     return HashEncoder()
-
-
-def extract_images_via_metadata_and_format_to_image_metadata_pair(pdf: bytes, metadata_per_image: Iterable[dict]):
-    image_pluses = extract_images_via_metadata(pdf, metadata_per_image)
-
-    def reformat(image: ImagePlus):
-        enum_metadata = {
-            Info.PAGE_WIDTH: image.info.pageInfo.width,
-            Info.PAGE_HEIGHT: image.info.pageInfo.height,
-            Info.PAGE_IDX: image.info.pageInfo.number,
-            Info.ALPHA: image.info.alpha,
-            Info.WIDTH: image.info.boundingBox.width,
-            Info.HEIGHT: image.info.boundingBox.height,
-            Info.X1: image.info.boundingBox.x0,
-            Info.X2: image.info.boundingBox.x1,
-            Info.Y1: image.info.boundingBox.y0,
-            Info.Y2: image.info.boundingBox.y1,
-        }
-        return ImageMetadataPair(image.aspil(), enum_metadata)
-
-    yield from map(reformat, image_pluses)
diff --git a/image_prediction/formatter/formatter.py b/image_prediction/formatter/formatter.py
index 3f3a1f8..53306a9 100644
--- a/image_prediction/formatter/formatter.py
+++ b/image_prediction/formatter/formatter.py
@@ -1,6 +1,10 @@
 import abc
 
+from image_prediction.image_extractor.extractor import ImageMetadataPair
+from image_prediction.info import Info
+
 from image_prediction.transformer.transformer import Transformer
+from pdf2img.default_objects.image import ImagePlus
 
 
 class Formatter(Transformer):
@@ -13,3 +17,19 @@ class Formatter(Transformer):
 
     def __call__(self, obj):
         return self.format(obj)
+
+
+def format_image_plus(image: ImagePlus) -> ImageMetadataPair:
+    enum_metadata = {
+        Info.PAGE_WIDTH: image.info.pageInfo.width,
+        Info.PAGE_HEIGHT: image.info.pageInfo.height,
+        Info.PAGE_IDX: image.info.pageInfo.number,
+        Info.ALPHA: image.info.alpha,
+        Info.WIDTH: image.info.boundingBox.width,
+        Info.HEIGHT: image.info.boundingBox.height,
+        Info.X1: image.info.boundingBox.x0,
+        Info.X2: image.info.boundingBox.x1,
+        Info.Y1: image.info.boundingBox.y0,
+        Info.Y2: image.info.boundingBox.y1,
+    }
+    return ImageMetadataPair(image.aspil(), enum_metadata)
diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py
index 8958189..f9383a1 100644
--- a/image_prediction/pipeline.py
+++ b/image_prediction/pipeline.py
@@ -11,9 +11,8 @@ from image_prediction.default_objects import (
     get_formatter,
     get_mlflow_model_loader,
     get_image_classifier,
-    get_extractor,
     get_encoder,
-    extract_images_via_metadata_and_format_to_image_metadata_pair,
+    get_dispatched_extract,
 )
 from image_prediction.locations import MLRUNS_DIR
 from image_prediction.utils.generic import lift, starlift
@@ -42,7 +41,7 @@ class Pipeline:
     def __init__(self, model_loader, model_identifier, batch_size=16, verbose=True, **kwargs):
         self.verbose = verbose
 
-        extract = get_extractor(**kwargs)
+        extract = get_dispatched_extract(**kwargs)
         classifier = get_image_classifier(model_loader, model_identifier)
         reformat = get_formatter()
         represent = get_encoder()
@@ -63,25 +62,10 @@ class Pipeline:
             join,  # ... the streams by zipping
             reformat,  # ... the items
         )
-        self.pipe2 = rcompose(
-            extract_images_via_metadata_and_format_to_image_metadata_pair,
-            split,  # ... into an image stream and a metadata stream
-            pairwise_apply(classify, represent, identity),  # ... apply functions to the streams pairwise
-            join,  # ... the streams by zipping
-            reformat,  # ... the items
-        )
 
-    def __call__(self, pdf: bytes, page_range: range = None):
+    def __call__(self, pdf: bytes, page_range: range = None, metadata_per_image: Iterable[dict] = None):
         yield from tqdm(
-            self.pipe(pdf, page_range=page_range),
-            desc="Processing images from document",
-            unit=" images",
-            disable=not self.verbose,
-        )
-
-    def extract_via_metadata(self, pdf: bytes, metadata_per_page: Iterable[dict]):
-        yield from tqdm(
-            self.pipe2(pdf, metadata_per_page),
+            self.pipe(pdf, page_range=page_range, metadata_per_image=metadata_per_image),
             desc="Processing images from document",
             unit=" images",
             disable=not self.verbose,
diff --git a/src/serve.py b/src/serve.py
index 0d972ec..866b37f 100644
--- a/src/serve.py
+++ b/src/serve.py
@@ -34,15 +34,12 @@ def process_request(request_message):
     try:  # TODO: add figure detection file target to request message to avoid this
         metadata_bytes = storage.get_object(PYINFRA_CONFIG.storage_bucket, f"{dossier_id}/{file_id}.FIGURE.json.gz")
         metadata_bytes = gzip.decompress(metadata_bytes)
-        metadata = json.load(io.BytesIO(metadata_bytes))
-        logger.info("Metadata aquired")
+        metadata_per_image = json.load(io.BytesIO(metadata_bytes))["data"]
+        logger.info("Metadata acquired")
     except:
-        metadata = None
+        metadata_per_image = None
 
-    if metadata:
-        classifications = list(pipeline.extract_via_metadata(object_bytes, metadata_per_page=metadata["data"]))
-    else:
-        classifications = list(pipeline(object_bytes))
+    classifications = list(pipeline(pdf=object_bytes, metadata_per_image=metadata_per_image))
 
     result = {**request_message, "data": classifications}
 