refactoring

This commit is contained in:
Matthias Bisping 2022-03-31 14:49:46 +02:00
parent 7ec7390e90
commit 4ebb36247e
3 changed files with 31 additions and 31 deletions

View File

@ -0,0 +1,20 @@
from image_prediction.config import CONFIG
from image_prediction.locations import MLRUNS_DIR
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.pipeline import Pipeline
from image_prediction.redai_adapter.mlflow import MlflowModelReader
def get_mlflow_model_loader(mlruns_dir):
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(mlruns_dir)))
return model_loader
def load_up_pipeline():
model_loader = get_mlflow_model_loader(MLRUNS_DIR)
model_identifier = CONFIG.service.run_id
pipeline = Pipeline(model_loader, model_identifier)
return pipeline

View File

@ -1,35 +1,21 @@
import os
from itertools import starmap
from funcy import rcompose, juxt, compose
from funcy import rcompose, juxt
from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.config import CONFIG
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
from image_prediction.formatter.formatters.enum import EnumFormatter
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
from image_prediction.locations import MLRUNS_DIR
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
from image_prediction.redai_adapter.mlflow import MlflowModelReader
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
def get_image_classifier(model_loader, identifier):
image_classifier = compose(ImageClassifier, Classifier)(
*juxt(
*starmap(
compose,
[(EstimatorAdapter, model_loader.load_model), (ProbabilityMapper, model_loader.load_classes)],
)
)(identifier)
)
return image_classifier
def get_image_classifier(model_loader, model_identifier):
model, classes = juxt(model_loader.load_model, model_loader.load_classes)(model_identifier)
return ImageClassifier(Classifier(EstimatorAdapter(model), ProbabilityMapper(classes)))
def get_extractor():
@ -38,9 +24,8 @@ def get_extractor():
return image_extractor
def get_extractor_classifier(model_loader, identifier):
extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier(model_loader, identifier))
def get_extractor_classifier(model_loader, model_identifier):
extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier(model_loader, model_identifier))
return extractor_classifier
@ -52,12 +37,8 @@ def get_formatter():
class Pipeline:
def __init__(self):
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR)))
identifier = CONFIG.service.run_id
self.pipe = rcompose(get_extractor_classifier(model_loader, identifier), get_formatter())
def __init__(self, model_loader, model_identifier):
self.pipe = rcompose(get_extractor_classifier(model_loader, model_identifier), get_formatter())
def __call__(self, pdf: bytes):
yield from self.pipe(pdf)

View File

@ -1,7 +1,7 @@
import argparse
import json
from image_prediction.pipeline import Pipeline
from image_prediction.default_objects import load_up_pipeline
def parse_args():
@ -14,14 +14,13 @@ def parse_args():
def main(args):
pipeline = Pipeline()
pipeline = load_up_pipeline()
with open(args.pdf, "rb") as f:
predictions = pipeline(f.read())
for prd in predictions:
print(prd)
print(json.dumps(prd, indent=1))
if __name__ == "__main__":