diff --git a/image_prediction/default_objects.py b/image_prediction/default_objects.py new file mode 100644 index 0000000..5580428 --- /dev/null +++ b/image_prediction/default_objects.py @@ -0,0 +1,20 @@ +from image_prediction.config import CONFIG +from image_prediction.locations import MLRUNS_DIR +from image_prediction.model_loader.loader import ModelLoader +from image_prediction.model_loader.loaders.mlflow import MlflowConnector +from image_prediction.pipeline import Pipeline +from image_prediction.redai_adapter.mlflow import MlflowModelReader + + +def get_mlflow_model_loader(mlruns_dir): + model_loader = ModelLoader(MlflowConnector(MlflowModelReader(mlruns_dir))) + return model_loader + + +def load_up_pipeline(): + model_loader = get_mlflow_model_loader(MLRUNS_DIR) + model_identifier = CONFIG.service.run_id + + pipeline = Pipeline(model_loader, model_identifier) + + return pipeline diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index 16189bc..3f964d3 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -1,35 +1,21 @@ import os -from itertools import starmap -from funcy import rcompose, juxt, compose +from funcy import rcompose, juxt from image_prediction.classifier.classifier import Classifier from image_prediction.classifier.image_classifier import ImageClassifier -from image_prediction.config import CONFIG from image_prediction.estimator.adapter.adapter import EstimatorAdapter from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier from image_prediction.formatter.formatters.enum import EnumFormatter from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.label_mapper.mappers.probability import ProbabilityMapper -from image_prediction.locations import MLRUNS_DIR -from image_prediction.model_loader.loader import ModelLoader -from image_prediction.model_loader.loaders.mlflow import MlflowConnector -from image_prediction.redai_adapter.mlflow import MlflowModelReader os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" -def get_image_classifier(model_loader, identifier): - image_classifier = compose(ImageClassifier, Classifier)( - *juxt( - *starmap( - compose, - [(EstimatorAdapter, model_loader.load_model), (ProbabilityMapper, model_loader.load_classes)], - ) - )(identifier) - ) - - return image_classifier +def get_image_classifier(model_loader, model_identifier): + model, classes = juxt(model_loader.load_model, model_loader.load_classes)(model_identifier) + return ImageClassifier(Classifier(EstimatorAdapter(model), ProbabilityMapper(classes))) def get_extractor(): @@ -38,9 +24,8 @@ def get_extractor(): return image_extractor -def get_extractor_classifier(model_loader, identifier): - - extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier(model_loader, identifier)) +def get_extractor_classifier(model_loader, model_identifier): + extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier(model_loader, model_identifier)) return extractor_classifier @@ -52,12 +37,8 @@ def get_formatter(): class Pipeline: - def __init__(self): - - model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR))) - identifier = CONFIG.service.run_id - - self.pipe = rcompose(get_extractor_classifier(model_loader, identifier), get_formatter()) + def __init__(self, model_loader, model_identifier): + self.pipe = rcompose(get_extractor_classifier(model_loader, model_identifier), get_formatter()) def __call__(self, pdf: bytes): yield from self.pipe(pdf) diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 891e29f..0b22300 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -1,7 +1,7 @@ import argparse import json -from image_prediction.pipeline import Pipeline +from image_prediction.default_objects import load_up_pipeline def parse_args(): @@ -14,14 +14,13 @@ def parse_args(): def main(args): - - pipeline = Pipeline() + pipeline = load_up_pipeline() with open(args.pdf, "rb") as f: predictions = pipeline(f.read()) for prd in predictions: - print(prd) + print(json.dumps(prd, indent=1)) if __name__ == "__main__":