This commit is contained in:
Matthias Bisping 2022-03-27 17:55:01 +02:00
parent 0f811bdc56
commit 9d58ae714f
7 changed files with 18 additions and 18 deletions

View File

@ -8,13 +8,14 @@ from image_prediction.utils import get_logger
logger = get_logger()
class Estimator:
class Classifier:
def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
"""Abstraction layer over different estimator backends (e.g. keras or scikit-learn). For each backend to be used
an EstimatorAdapter must be implemented.
Args:
estimator_adapter: adapter for a given estimator backend
estimator_adapter: adapter for a given estimator backend; expected to be a classifier that returns numeric
labels as predictions
classes: mapping from a numerical label to a human-readable label for classes
"""
self.__estimator_adapter = estimator_adapter

View File

@ -3,18 +3,18 @@ from typing import List
from PIL.Image import Image
from image_prediction.estimator.estimator import Estimator
from image_prediction.classifier.classifier import Classifier
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.utils import chunk_iterable
class Predictor:
def __init__(self, estimator: Estimator, preprocessor: Preprocessor = None):
class ImageClassifier:
def __init__(self, estimator: Classifier, preprocessor: Preprocessor = None):
self.estimator = estimator
self.preprocessor = preprocessor if preprocessor else BasicPreprocessor()
self.pipe = lambda batch: self.estimator(self.preprocessor(batch))
def predict(self, images: List[Image], batch_size=2):
def predict(self, images: List[Image], batch_size=16):
batches = chunk_iterable(images, chunk_size=batch_size)
return chain(*map(self.pipe, batches))

View File

@ -49,7 +49,7 @@ def make_prediction_server(predict_fn: Callable):
except KeyError:
raise
logger.debug("Running predictor on document...")
logger.debug("Running classifier on document...")
try:
predictions = process()
response = jsonify(predictions)

View File

@ -10,8 +10,8 @@ logger.setLevel(logging.DEBUG)
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
def test_predict(estimator, input_batch, expected_predictions):
predictions = estimator.predict(input_batch)
def test_predict(classifier, input_batch, expected_predictions):
predictions = classifier.predict(input_batch)
assert predictions == expected_predictions

View File

@ -4,22 +4,22 @@ import numpy as np
import pytest
from PIL import Image
from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
from image_prediction.estimator.adapter.adapters.mock import DummyEstimator, EstimatorAdapterMock
from image_prediction.estimator.estimator import Estimator
from image_prediction.exceptions import UnknownEstimatorAdapter
from image_prediction.predictor.predictor import Predictor
@pytest.fixture
def predictor(estimator, monkeypatch, expected_predictions):
return Predictor(estimator)
def predictor(classifier, monkeypatch, expected_predictions):
return ImageClassifier(classifier)
@pytest.fixture
def estimator(estimator_adapter, classes):
estimator = Estimator(estimator_adapter, classes)
return estimator
def classifier(estimator_adapter, classes):
classifier = Classifier(estimator_adapter, classes)
return classifier
@pytest.fixture

View File

@ -4,8 +4,7 @@ from image_prediction.utils import chunk_iterable
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
# @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
@pytest.mark.parametrize("batch_size", [0, 1, 2, 4])
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
def test_predict(predictor, images, expected_predictions):
predictions = list(predictor.predict(images))
assert predictions == expected_predictions