from funcy import juxt

from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.compositor.compositor import TransformerCompositor
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
from image_prediction.formatter.formatters.enum import EnumFormatter
from image_prediction.formatter.formatters.response import ResponseTransformer
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.redai_adapter.mlflow import MlflowModelReader


def get_mlflow_model_loader(mlruns_dir):
    """Return a ModelLoader wired to an MLflow reader for ``mlruns_dir``.

    The chain is MlflowModelReader -> MlflowConnector -> ModelLoader.
    """
    reader = MlflowModelReader(mlruns_dir)
    connector = MlflowConnector(reader)
    return ModelLoader(connector)


def get_image_classifier(model_loader, model_identifier):
    """Load a model and its class list, then wrap them in an ImageClassifier.

    Args:
        model_loader: object exposing ``load_model`` and ``load_classes``,
            both called with ``model_identifier``.
        model_identifier: key passed unchanged to both loader methods.

    Returns:
        ImageClassifier around a Classifier that combines an EstimatorAdapter
        (over the loaded model) and a ProbabilityMapper (over the loaded classes).
    """
    # juxt calls load_model then load_classes with the same argument and
    # returns their results in that order.
    load_model_and_classes = juxt(model_loader.load_model, model_loader.load_classes)
    model, classes = load_model_and_classes(model_identifier)

    estimator = EstimatorAdapter(model)
    label_mapper = ProbabilityMapper(classes)
    return ImageClassifier(Classifier(estimator, label_mapper))


def get_extractor(**kwargs):
    """Return a ParsablePDFImageExtractor built from the given keyword arguments."""
    return ParsablePDFImageExtractor(**kwargs)


def get_extractor_classifier(model_loader, model_identifier, **kwargs):
    """Pair an image extractor with an image classifier in an ExtractorClassifier.

    ``kwargs`` go only to the extractor; ``model_loader`` and
    ``model_identifier`` go only to the classifier.
    """
    # The extractor is built before the classifier (loading the model),
    # matching the original argument evaluation order.
    extractor = get_extractor(**kwargs)
    classifier = get_image_classifier(model_loader, model_identifier)
    return ExtractorClassifier(extractor, classifier)


def get_formatter():
    """Return a TransformerCompositor of the enum, response and camel-case formatters.

    The components are passed in the order EnumFormatter, ResponseTransformer,
    Snake2CamelCaseKeyFormatter. Presumably they are applied in that order;
    confirm against TransformerCompositor.
    """
    stages = (EnumFormatter(), ResponseTransformer(), Snake2CamelCaseKeyFormatter())
    return TransformerCompositor(*stages)