refactoring
This commit is contained in:
parent
dc1cdde458
commit
7ec7390e90
@ -1,13 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
|
from itertools import starmap
|
||||||
|
|
||||||
from funcy import rcompose
|
from funcy import rcompose, juxt, compose
|
||||||
|
|
||||||
from image_prediction.classifier.classifier import Classifier
|
from image_prediction.classifier.classifier import Classifier
|
||||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||||
from image_prediction.config import CONFIG
|
from image_prediction.config import CONFIG
|
||||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||||
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
|
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
|
||||||
from image_prediction.formatter.formatters.info_formatter import EnumFormatter
|
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||||
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
||||||
from image_prediction.locations import MLRUNS_DIR
|
from image_prediction.locations import MLRUNS_DIR
|
||||||
@ -18,13 +19,15 @@ from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
|||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
|
|
||||||
|
|
||||||
def get_image_classifier():
|
def get_image_classifier(model_loader, identifier):
|
||||||
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR)))
|
image_classifier = compose(ImageClassifier, Classifier)(
|
||||||
model = model_loader.load_model(CONFIG.service.run_id)
|
*juxt(
|
||||||
classes = model_loader.load_classes(CONFIG.service.run_id)
|
*starmap(
|
||||||
label_mapper = ProbabilityMapper(classes)
|
compose,
|
||||||
classifier = Classifier(EstimatorAdapter(model), label_mapper)
|
[(EstimatorAdapter, model_loader.load_model), (ProbabilityMapper, model_loader.load_classes)],
|
||||||
image_classifier = ImageClassifier(classifier)
|
)
|
||||||
|
)(identifier)
|
||||||
|
)
|
||||||
|
|
||||||
return image_classifier
|
return image_classifier
|
||||||
|
|
||||||
@ -35,8 +38,9 @@ def get_extractor():
|
|||||||
return image_extractor
|
return image_extractor
|
||||||
|
|
||||||
|
|
||||||
def get_extractor_classifier():
|
def get_extractor_classifier(model_loader, identifier):
|
||||||
extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier())
|
|
||||||
|
extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier(model_loader, identifier))
|
||||||
|
|
||||||
return extractor_classifier
|
return extractor_classifier
|
||||||
|
|
||||||
@ -49,7 +53,11 @@ def get_formatter():
|
|||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.pipe = rcompose(get_extractor_classifier(), get_formatter())
|
|
||||||
|
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR)))
|
||||||
|
identifier = CONFIG.service.run_id
|
||||||
|
|
||||||
|
self.pipe = rcompose(get_extractor_classifier(model_loader, identifier), get_formatter())
|
||||||
|
|
||||||
def __call__(self, pdf: bytes):
|
def __call__(self, pdf: bytes):
|
||||||
yield from self.pipe(pdf)
|
yield from self.pipe(pdf)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user