diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index 8718f0a..16189bc 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -1,13 +1,14 @@ import os +from itertools import starmap -from funcy import rcompose +from funcy import rcompose, juxt, compose from image_prediction.classifier.classifier import Classifier from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.config import CONFIG from image_prediction.estimator.adapter.adapter import EstimatorAdapter from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier -from image_prediction.formatter.formatters.info_formatter import EnumFormatter +from image_prediction.formatter.formatters.enum import EnumFormatter from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.label_mapper.mappers.probability import ProbabilityMapper from image_prediction.locations import MLRUNS_DIR @@ -18,13 +19,15 @@ from image_prediction.redai_adapter.mlflow import MlflowModelReader os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" -def get_image_classifier(): - model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR))) - model = model_loader.load_model(CONFIG.service.run_id) - classes = model_loader.load_classes(CONFIG.service.run_id) - label_mapper = ProbabilityMapper(classes) - classifier = Classifier(EstimatorAdapter(model), label_mapper) - image_classifier = ImageClassifier(classifier) +def get_image_classifier(model_loader, identifier): + image_classifier = compose(ImageClassifier, Classifier)( + *juxt( + *starmap( + compose, + [(EstimatorAdapter, model_loader.load_model), (ProbabilityMapper, model_loader.load_classes)], + ) + )(identifier) + ) return image_classifier @@ -35,8 +38,9 @@ def get_extractor(): return image_extractor -def get_extractor_classifier(): - extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier()) +def get_extractor_classifier(model_loader, identifier): + + extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier(model_loader, identifier)) return extractor_classifier @@ -49,7 +53,11 @@ def get_formatter(): class Pipeline: def __init__(self): - self.pipe = rcompose(get_extractor_classifier(), get_formatter()) + + model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR))) + identifier = CONFIG.service.run_id + + self.pipe = rcompose(get_extractor_classifier(model_loader, identifier), get_formatter()) def __call__(self, pdf: bytes): yield from self.pipe(pdf) diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 9668ce9..891e29f 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -26,4 +26,4 @@ def main(args): if __name__ == "__main__": args = parse_args() - main(args) \ No newline at end of file + main(args)