refactoring
This commit is contained in:
parent
7ec7390e90
commit
4ebb36247e
20
image_prediction/default_objects.py
Normal file
20
image_prediction/default_objects.py
Normal file
@ -0,0 +1,20 @@
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.locations import MLRUNS_DIR
|
||||
from image_prediction.model_loader.loader import ModelLoader
|
||||
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||
from image_prediction.pipeline import Pipeline
|
||||
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
||||
|
||||
|
||||
def get_mlflow_model_loader(mlruns_dir):
|
||||
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(mlruns_dir)))
|
||||
return model_loader
|
||||
|
||||
|
||||
def load_up_pipeline():
|
||||
model_loader = get_mlflow_model_loader(MLRUNS_DIR)
|
||||
model_identifier = CONFIG.service.run_id
|
||||
|
||||
pipeline = Pipeline(model_loader, model_identifier)
|
||||
|
||||
return pipeline
|
||||
@ -1,35 +1,21 @@
|
||||
import os
|
||||
from itertools import starmap
|
||||
|
||||
from funcy import rcompose, juxt, compose
|
||||
from funcy import rcompose, juxt
|
||||
|
||||
from image_prediction.classifier.classifier import Classifier
|
||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
|
||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
||||
from image_prediction.locations import MLRUNS_DIR
|
||||
from image_prediction.model_loader.loader import ModelLoader
|
||||
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
|
||||
|
||||
def get_image_classifier(model_loader, identifier):
|
||||
image_classifier = compose(ImageClassifier, Classifier)(
|
||||
*juxt(
|
||||
*starmap(
|
||||
compose,
|
||||
[(EstimatorAdapter, model_loader.load_model), (ProbabilityMapper, model_loader.load_classes)],
|
||||
)
|
||||
)(identifier)
|
||||
)
|
||||
|
||||
return image_classifier
|
||||
def get_image_classifier(model_loader, model_identifier):
|
||||
model, classes = juxt(model_loader.load_model, model_loader.load_classes)(model_identifier)
|
||||
return ImageClassifier(Classifier(EstimatorAdapter(model), ProbabilityMapper(classes)))
|
||||
|
||||
|
||||
def get_extractor():
|
||||
@ -38,9 +24,8 @@ def get_extractor():
|
||||
return image_extractor
|
||||
|
||||
|
||||
def get_extractor_classifier(model_loader, identifier):
|
||||
|
||||
extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier(model_loader, identifier))
|
||||
def get_extractor_classifier(model_loader, model_identifier):
|
||||
extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier(model_loader, model_identifier))
|
||||
|
||||
return extractor_classifier
|
||||
|
||||
@ -52,12 +37,8 @@ def get_formatter():
|
||||
|
||||
|
||||
class Pipeline:
|
||||
def __init__(self):
|
||||
|
||||
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR)))
|
||||
identifier = CONFIG.service.run_id
|
||||
|
||||
self.pipe = rcompose(get_extractor_classifier(model_loader, identifier), get_formatter())
|
||||
def __init__(self, model_loader, model_identifier):
|
||||
self.pipe = rcompose(get_extractor_classifier(model_loader, model_identifier), get_formatter())
|
||||
|
||||
def __call__(self, pdf: bytes):
|
||||
yield from self.pipe(pdf)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from image_prediction.pipeline import Pipeline
|
||||
from image_prediction.default_objects import load_up_pipeline
|
||||
|
||||
|
||||
def parse_args():
|
||||
@ -14,14 +14,13 @@ def parse_args():
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline = load_up_pipeline()
|
||||
|
||||
with open(args.pdf, "rb") as f:
|
||||
predictions = pipeline(f.read())
|
||||
|
||||
for prd in predictions:
|
||||
print(prd)
|
||||
print(json.dumps(prd, indent=1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user