refactoring
This commit is contained in:
parent
7ec7390e90
commit
4ebb36247e
20
image_prediction/default_objects.py
Normal file
20
image_prediction/default_objects.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from image_prediction.config import CONFIG
|
||||||
|
from image_prediction.locations import MLRUNS_DIR
|
||||||
|
from image_prediction.model_loader.loader import ModelLoader
|
||||||
|
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||||
|
from image_prediction.pipeline import Pipeline
|
||||||
|
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
||||||
|
|
||||||
|
|
||||||
|
def get_mlflow_model_loader(mlruns_dir):
|
||||||
|
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(mlruns_dir)))
|
||||||
|
return model_loader
|
||||||
|
|
||||||
|
|
||||||
|
def load_up_pipeline():
|
||||||
|
model_loader = get_mlflow_model_loader(MLRUNS_DIR)
|
||||||
|
model_identifier = CONFIG.service.run_id
|
||||||
|
|
||||||
|
pipeline = Pipeline(model_loader, model_identifier)
|
||||||
|
|
||||||
|
return pipeline
|
||||||
@ -1,35 +1,21 @@
|
|||||||
import os
|
import os
|
||||||
from itertools import starmap
|
|
||||||
|
|
||||||
from funcy import rcompose, juxt, compose
|
from funcy import rcompose, juxt
|
||||||
|
|
||||||
from image_prediction.classifier.classifier import Classifier
|
from image_prediction.classifier.classifier import Classifier
|
||||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||||
from image_prediction.config import CONFIG
|
|
||||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||||
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
|
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
|
||||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||||
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
||||||
from image_prediction.locations import MLRUNS_DIR
|
|
||||||
from image_prediction.model_loader.loader import ModelLoader
|
|
||||||
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
|
||||||
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
|
||||||
|
|
||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
|
|
||||||
|
|
||||||
def get_image_classifier(model_loader, identifier):
|
def get_image_classifier(model_loader, model_identifier):
|
||||||
image_classifier = compose(ImageClassifier, Classifier)(
|
model, classes = juxt(model_loader.load_model, model_loader.load_classes)(model_identifier)
|
||||||
*juxt(
|
return ImageClassifier(Classifier(EstimatorAdapter(model), ProbabilityMapper(classes)))
|
||||||
*starmap(
|
|
||||||
compose,
|
|
||||||
[(EstimatorAdapter, model_loader.load_model), (ProbabilityMapper, model_loader.load_classes)],
|
|
||||||
)
|
|
||||||
)(identifier)
|
|
||||||
)
|
|
||||||
|
|
||||||
return image_classifier
|
|
||||||
|
|
||||||
|
|
||||||
def get_extractor():
|
def get_extractor():
|
||||||
@ -38,9 +24,8 @@ def get_extractor():
|
|||||||
return image_extractor
|
return image_extractor
|
||||||
|
|
||||||
|
|
||||||
def get_extractor_classifier(model_loader, identifier):
|
def get_extractor_classifier(model_loader, model_identifier):
|
||||||
|
extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier(model_loader, model_identifier))
|
||||||
extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier(model_loader, identifier))
|
|
||||||
|
|
||||||
return extractor_classifier
|
return extractor_classifier
|
||||||
|
|
||||||
@ -52,12 +37,8 @@ def get_formatter():
|
|||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
def __init__(self):
|
def __init__(self, model_loader, model_identifier):
|
||||||
|
self.pipe = rcompose(get_extractor_classifier(model_loader, model_identifier), get_formatter())
|
||||||
model_loader = ModelLoader(MlflowConnector(MlflowModelReader(MLRUNS_DIR)))
|
|
||||||
identifier = CONFIG.service.run_id
|
|
||||||
|
|
||||||
self.pipe = rcompose(get_extractor_classifier(model_loader, identifier), get_formatter())
|
|
||||||
|
|
||||||
def __call__(self, pdf: bytes):
|
def __call__(self, pdf: bytes):
|
||||||
yield from self.pipe(pdf)
|
yield from self.pipe(pdf)
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from image_prediction.pipeline import Pipeline
|
from image_prediction.default_objects import load_up_pipeline
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -14,14 +14,13 @@ def parse_args():
|
|||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
|
pipeline = load_up_pipeline()
|
||||||
pipeline = Pipeline()
|
|
||||||
|
|
||||||
with open(args.pdf, "rb") as f:
|
with open(args.pdf, "rb") as f:
|
||||||
predictions = pipeline(f.read())
|
predictions = pipeline(f.read())
|
||||||
|
|
||||||
for prd in predictions:
|
for prd in predictions:
|
||||||
print(prd)
|
print(json.dumps(prd, indent=1))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user