refactoring

This commit is contained in:
Matthias Bisping 2022-03-31 12:52:35 +02:00
parent dc1cdde458
commit 7ec7390e90
2 changed files with 21 additions and 13 deletions

View File

@ -1,13 +1,14 @@
import os import os
from itertools import starmap
from funcy import rcompose from funcy import rcompose, juxt, compose
from image_prediction.classifier.classifier import Classifier from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.config import CONFIG from image_prediction.config import CONFIG
from image_prediction.estimator.adapter.adapter import EstimatorAdapter from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
from image_prediction.formatter.formatters.info_formatter import EnumFormatter from image_prediction.formatter.formatters.enum import EnumFormatter
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
from image_prediction.locations import MLRUNS_DIR from image_prediction.locations import MLRUNS_DIR
@ -18,13 +19,15 @@ from image_prediction.redai_adapter.mlflow import MlflowModelReader
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
def get_image_classifier(): def get_image_classifier(model_loader, identifier):
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR))) image_classifier = compose(ImageClassifier, Classifier)(
model = model_loader.load_model(CONFIG.service.run_id) *juxt(
classes = model_loader.load_classes(CONFIG.service.run_id) *starmap(
label_mapper = ProbabilityMapper(classes) compose,
classifier = Classifier(EstimatorAdapter(model), label_mapper) [(EstimatorAdapter, model_loader.load_model), (ProbabilityMapper, model_loader.load_classes)],
image_classifier = ImageClassifier(classifier) )
)(identifier)
)
return image_classifier return image_classifier
@ -35,8 +38,9 @@ def get_extractor():
return image_extractor return image_extractor
def get_extractor_classifier(): def get_extractor_classifier(model_loader, identifier):
extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier())
extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier(model_loader, identifier))
return extractor_classifier return extractor_classifier
@ -49,7 +53,11 @@ def get_formatter():
class Pipeline: class Pipeline:
def __init__(self): def __init__(self):
self.pipe = rcompose(get_extractor_classifier(), get_formatter())
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR)))
identifier = CONFIG.service.run_id
self.pipe = rcompose(get_extractor_classifier(model_loader, identifier), get_formatter())
def __call__(self, pdf: bytes): def __call__(self, pdf: bytes):
yield from self.pipe(pdf) yield from self.pipe(pdf)