This commit is contained in:
Julius Unverfehrt 2022-08-10 13:27:41 +02:00
parent 287b0ebc8a
commit bae25bedbd
4 changed files with 44 additions and 53 deletions

View File

@ -7,18 +7,16 @@ from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.compositor.compositor import TransformerCompositor from image_prediction.compositor.compositor import TransformerCompositor
from image_prediction.encoder.encoders.hash_encoder import HashEncoder from image_prediction.encoder.encoders.hash_encoder import HashEncoder
from image_prediction.estimator.adapter.adapter import EstimatorAdapter from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.formatter.formatter import format_image_plus
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
from image_prediction.formatter.formatters.enum import EnumFormatter from image_prediction.formatter.formatters.enum import EnumFormatter
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.info import Info
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
from image_prediction.model_loader.loader import ModelLoader from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.redai_adapter.mlflow import MlflowModelReader from image_prediction.redai_adapter.mlflow import MlflowModelReader
from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer
from image_prediction.transformer.transformers.response import ResponseTransformer from image_prediction.transformer.transformers.response import ResponseTransformer
from pdf2img.default_objects.image import ImagePlus
from pdf2img.extraction import extract_images_via_metadata from pdf2img.extraction import extract_images_via_metadata
@ -32,10 +30,23 @@ def get_image_classifier(model_loader, model_identifier):
return ImageClassifier(Classifier(EstimatorAdapter(model), ProbabilityMapper(classes))) return ImageClassifier(Classifier(EstimatorAdapter(model), ProbabilityMapper(classes)))
def get_extractor(**kwargs): # def get_extractor(**kwargs):
# image_extractor = ParsablePDFImageExtractor(**kwargs)
#
# return image_extractor
def get_dispatched_extract(**kwargs):
image_extractor = ParsablePDFImageExtractor(**kwargs) image_extractor = ParsablePDFImageExtractor(**kwargs)
return image_extractor def extract(pdf: bytes, page_range: range = None, metadata_per_image: Iterable[dict] = None):
if metadata_per_image:
image_pluses = extract_images_via_metadata(pdf, metadata_per_image)
yield from map(format_image_plus, image_pluses)
else:
yield from image_extractor.extract(pdf, page_range)
return extract
def get_formatter(): def get_formatter():
@ -47,24 +58,3 @@ def get_formatter():
def get_encoder(): def get_encoder():
return HashEncoder() return HashEncoder()
def extract_images_via_metadata_and_format_to_image_metadata_pair(pdf: bytes, metadata_per_image: Iterable[dict]):
image_pluses = extract_images_via_metadata(pdf, metadata_per_image)
def reformat(image: ImagePlus):
enum_metadata = {
Info.PAGE_WIDTH: image.info.pageInfo.width,
Info.PAGE_HEIGHT: image.info.pageInfo.height,
Info.PAGE_IDX: image.info.pageInfo.number,
Info.ALPHA: image.info.alpha,
Info.WIDTH: image.info.boundingBox.width,
Info.HEIGHT: image.info.boundingBox.height,
Info.X1: image.info.boundingBox.x0,
Info.X2: image.info.boundingBox.x1,
Info.Y1: image.info.boundingBox.y0,
Info.Y2: image.info.boundingBox.y1,
}
return ImageMetadataPair(image.aspil(), enum_metadata)
yield from map(reformat, image_pluses)

View File

@ -1,6 +1,10 @@
import abc import abc
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.transformer.transformer import Transformer from image_prediction.transformer.transformer import Transformer
from pdf2img.default_objects.image import ImagePlus
class Formatter(Transformer): class Formatter(Transformer):
@ -13,3 +17,19 @@ class Formatter(Transformer):
def __call__(self, obj): def __call__(self, obj):
return self.format(obj) return self.format(obj)
def format_image_plus(image: ImagePlus) -> ImageMetadataPair:
enum_metadata = {
Info.PAGE_WIDTH: image.info.pageInfo.width,
Info.PAGE_HEIGHT: image.info.pageInfo.height,
Info.PAGE_IDX: image.info.pageInfo.number,
Info.ALPHA: image.info.alpha,
Info.WIDTH: image.info.boundingBox.width,
Info.HEIGHT: image.info.boundingBox.height,
Info.X1: image.info.boundingBox.x0,
Info.X2: image.info.boundingBox.x1,
Info.Y1: image.info.boundingBox.y0,
Info.Y2: image.info.boundingBox.y1,
}
return ImageMetadataPair(image.aspil(), enum_metadata)

View File

@ -11,9 +11,8 @@ from image_prediction.default_objects import (
get_formatter, get_formatter,
get_mlflow_model_loader, get_mlflow_model_loader,
get_image_classifier, get_image_classifier,
get_extractor,
get_encoder, get_encoder,
extract_images_via_metadata_and_format_to_image_metadata_pair, get_dispatched_extract,
) )
from image_prediction.locations import MLRUNS_DIR from image_prediction.locations import MLRUNS_DIR
from image_prediction.utils.generic import lift, starlift from image_prediction.utils.generic import lift, starlift
@ -42,7 +41,7 @@ class Pipeline:
def __init__(self, model_loader, model_identifier, batch_size=16, verbose=True, **kwargs): def __init__(self, model_loader, model_identifier, batch_size=16, verbose=True, **kwargs):
self.verbose = verbose self.verbose = verbose
extract = get_extractor(**kwargs) extract = get_dispatched_extract(**kwargs)
classifier = get_image_classifier(model_loader, model_identifier) classifier = get_image_classifier(model_loader, model_identifier)
reformat = get_formatter() reformat = get_formatter()
represent = get_encoder() represent = get_encoder()
@ -63,25 +62,10 @@ class Pipeline:
join, # ... the streams by zipping join, # ... the streams by zipping
reformat, # ... the items reformat, # ... the items
) )
self.pipe2 = rcompose(
extract_images_via_metadata_and_format_to_image_metadata_pair,
split, # ... into an image stream and a metadata stream
pairwise_apply(classify, represent, identity), # ... apply functions to the streams pairwise
join, # ... the streams by zipping
reformat, # ... the items
)
def __call__(self, pdf: bytes, page_range: range = None): def __call__(self, pdf: bytes, page_range: range = None, metadata_per_image: Iterable[dict] = None):
yield from tqdm( yield from tqdm(
self.pipe(pdf, page_range=page_range), self.pipe(pdf, page_range=page_range, metadata_per_image=metadata_per_image),
desc="Processing images from document",
unit=" images",
disable=not self.verbose,
)
def extract_via_metadata(self, pdf: bytes, metadata_per_page: Iterable[dict]):
yield from tqdm(
self.pipe2(pdf, metadata_per_page),
desc="Processing images from document", desc="Processing images from document",
unit=" images", unit=" images",
disable=not self.verbose, disable=not self.verbose,

View File

@ -34,15 +34,12 @@ def process_request(request_message):
try: # TODO: add figure detection file target to request message to avoid this try: # TODO: add figure detection file target to request message to avoid this
metadata_bytes = storage.get_object(PYINFRA_CONFIG.storage_bucket, f"{dossier_id}/{file_id}.FIGURE.json.gz") metadata_bytes = storage.get_object(PYINFRA_CONFIG.storage_bucket, f"{dossier_id}/{file_id}.FIGURE.json.gz")
metadata_bytes = gzip.decompress(metadata_bytes) metadata_bytes = gzip.decompress(metadata_bytes)
metadata = json.load(io.BytesIO(metadata_bytes)) metadata_per_image = json.load(io.BytesIO(metadata_bytes))["data"]
logger.info("Metadata aquired") logger.info("Metadata acquired")
except: except:
metadata = None metadata_per_image = None
if metadata: classifications = list(pipeline(pdf=object_bytes, metadata_per_image=metadata_per_image))
classifications = list(pipeline.extract_via_metadata(object_bytes, metadata_per_page=metadata["data"]))
else:
classifications = list(pipeline(object_bytes))
result = {**request_message, "data": classifications} result = {**request_message, "data": classifications}