71 lines
2.9 KiB
Python
71 lines
2.9 KiB
Python
from typing import Iterable
|
|
|
|
from funcy import juxt
|
|
|
|
from image_prediction.classifier.classifier import Classifier
|
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
|
from image_prediction.compositor.compositor import TransformerCompositor
|
|
from image_prediction.encoder.encoders.hash_encoder import HashEncoder
|
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
|
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
|
|
from image_prediction.formatter.formatters.enum import EnumFormatter
|
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
|
from image_prediction.info import Info
|
|
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
|
from image_prediction.model_loader.loader import ModelLoader
|
|
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
|
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
|
from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer
|
|
from image_prediction.transformer.transformers.response import ResponseTransformer
|
|
from pdf2img.default_objects.image import ImagePlus
|
|
from pdf2img.extraction import extract_images_via_metadata
|
|
|
|
|
|
def get_mlflow_model_loader(mlruns_dir):
    """Return a ModelLoader backed by MLflow runs found under `mlruns_dir`.

    The loader wraps an MlflowConnector, which in turn wraps an
    MlflowModelReader constructed from `mlruns_dir`.
    """
    reader = MlflowModelReader(mlruns_dir)
    connector = MlflowConnector(reader)
    return ModelLoader(connector)
|
|
|
|
|
|
def get_image_classifier(model_loader, model_identifier):
    """Assemble an ImageClassifier for the model named by `model_identifier`.

    The model and its classes are both fetched from `model_loader` using the
    same identifier (model first, then classes). The model is wrapped in an
    EstimatorAdapter and the classes in a ProbabilityMapper before both are
    handed to a Classifier, which the returned ImageClassifier wraps.
    """
    # Load order (model, then classes) mirrors the original evaluation order.
    loaded_model = model_loader.load_model(model_identifier)
    loaded_classes = model_loader.load_classes(model_identifier)

    estimator = EstimatorAdapter(loaded_model)
    label_mapper = ProbabilityMapper(loaded_classes)
    return ImageClassifier(Classifier(estimator, label_mapper))
|
|
|
|
|
|
def get_extractor(**kwargs):
    """Return a ParsablePDFImageExtractor built from the given keyword arguments.

    All keyword arguments are forwarded unchanged to the extractor's
    constructor.
    """
    return ParsablePDFImageExtractor(**kwargs)
|
|
|
|
|
|
def get_formatter():
    """Return a TransformerCompositor over the response-formatting stages.

    The stages are passed to the compositor in this order:
    PDFNetCoordinateTransformer, EnumFormatter, ResponseTransformer,
    Snake2CamelCaseKeyFormatter.
    """
    stages = (
        PDFNetCoordinateTransformer(),
        EnumFormatter(),
        ResponseTransformer(),
        Snake2CamelCaseKeyFormatter(),
    )
    return TransformerCompositor(*stages)
|
|
|
|
|
|
def get_encoder():
    """Return a new HashEncoder instance."""
    encoder = HashEncoder()
    return encoder
|
|
|
|
|
|
def extract_images_via_metadata_and_format_to_image_metadata_pair(pdf: bytes, metadata_per_image: Iterable[dict]):
    """Lazily yield one ImageMetadataPair per image extracted from `pdf`.

    Images are obtained from `extract_images_via_metadata(pdf, metadata_per_image)`.
    For each extracted image, its `info` is re-keyed into a dict using `Info`
    enum members (page size and index, alpha, bounding-box size and corner
    coordinates), and paired with the result of the image's `aspil()` call.

    This is a generator: nothing is extracted until iteration starts.
    """
    for extracted in extract_images_via_metadata(pdf, metadata_per_image):
        details = extracted.info
        page = details.pageInfo
        box = details.boundingBox

        # Build the metadata before converting the image, as the original did.
        metadata_by_info = {
            Info.PAGE_WIDTH: page.width,
            Info.PAGE_HEIGHT: page.height,
            Info.PAGE_IDX: page.number,
            Info.ALPHA: details.alpha,
            Info.WIDTH: box.width,
            Info.HEIGHT: box.height,
            # The bounding box uses 0-based corner names (x0/x1, y0/y1);
            # Info uses 1-based names (X1/X2, Y1/Y2).
            Info.X1: box.x0,
            Info.X2: box.x1,
            Info.Y1: box.y0,
            Info.Y2: box.y1,
        }
        yield ImageMetadataPair(extracted.aspil(), metadata_by_info)
|