renaming
This commit is contained in:
parent
0f811bdc56
commit
9d58ae714f
@ -8,13 +8,14 @@ from image_prediction.utils import get_logger
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class Estimator:
|
||||
class Classifier:
|
||||
def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
|
||||
"""Abstraction layer over different estimator backends (e.g. keras or scikit-learn). For each backend to be used
|
||||
an EstimatorAdapter must be implemented.
|
||||
|
||||
Args:
|
||||
estimator_adapter: adapter for a given estimator backend
|
||||
estimator_adapter: adapter for a given estimator backend; expected to be a classifier that returns numeric
|
||||
labels as predictions
|
||||
classes: mapping from a numerical label to a human-readable label for classes
|
||||
"""
|
||||
self.__estimator_adapter = estimator_adapter
|
||||
@ -3,18 +3,18 @@ from typing import List
|
||||
|
||||
from PIL.Image import Image
|
||||
|
||||
from image_prediction.estimator.estimator import Estimator
|
||||
from image_prediction.classifier.classifier import Classifier
|
||||
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
||||
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||
from image_prediction.utils import chunk_iterable
|
||||
|
||||
|
||||
class Predictor:
|
||||
def __init__(self, estimator: Estimator, preprocessor: Preprocessor = None):
|
||||
class ImageClassifier:
|
||||
def __init__(self, estimator: Classifier, preprocessor: Preprocessor = None):
|
||||
self.estimator = estimator
|
||||
self.preprocessor = preprocessor if preprocessor else BasicPreprocessor()
|
||||
self.pipe = lambda batch: self.estimator(self.preprocessor(batch))
|
||||
|
||||
def predict(self, images: List[Image], batch_size=2):
|
||||
def predict(self, images: List[Image], batch_size=16):
|
||||
batches = chunk_iterable(images, chunk_size=batch_size)
|
||||
return chain(*map(self.pipe, batches))
|
||||
@ -49,7 +49,7 @@ def make_prediction_server(predict_fn: Callable):
|
||||
except KeyError:
|
||||
raise
|
||||
|
||||
logger.debug("Running predictor on document...")
|
||||
logger.debug("Running classifier on document...")
|
||||
try:
|
||||
predictions = process()
|
||||
response = jsonify(predictions)
|
||||
|
||||
@ -10,8 +10,8 @@ logger.setLevel(logging.DEBUG)
|
||||
|
||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||
def test_predict(estimator, input_batch, expected_predictions):
|
||||
predictions = estimator.predict(input_batch)
|
||||
def test_predict(classifier, input_batch, expected_predictions):
|
||||
predictions = classifier.predict(input_batch)
|
||||
assert predictions == expected_predictions
|
||||
|
||||
|
||||
@ -4,22 +4,22 @@ import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from image_prediction.classifier.classifier import Classifier
|
||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
|
||||
from image_prediction.estimator.adapter.adapters.mock import DummyEstimator, EstimatorAdapterMock
|
||||
from image_prediction.estimator.estimator import Estimator
|
||||
from image_prediction.exceptions import UnknownEstimatorAdapter
|
||||
from image_prediction.predictor.predictor import Predictor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def predictor(estimator, monkeypatch, expected_predictions):
|
||||
return Predictor(estimator)
|
||||
def predictor(classifier, monkeypatch, expected_predictions):
|
||||
return ImageClassifier(classifier)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def estimator(estimator_adapter, classes):
|
||||
estimator = Estimator(estimator_adapter, classes)
|
||||
return estimator
|
||||
def classifier(estimator_adapter, classes):
|
||||
classifier = Classifier(estimator_adapter, classes)
|
||||
return classifier
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -4,8 +4,7 @@ from image_prediction.utils import chunk_iterable
|
||||
|
||||
|
||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
||||
# @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 4])
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||
def test_predict(predictor, images, expected_predictions):
|
||||
predictions = list(predictor.predict(images))
|
||||
assert predictions == expected_predictions
|
||||
Loading…
x
Reference in New Issue
Block a user