refactoring; added predictor; mocking of predict function is broken: fixing next commit
This commit is contained in:
parent
6343229c1e
commit
0f9510906d
@ -10,6 +10,13 @@ logger = get_logger()
|
||||
|
||||
class Estimator:
|
||||
def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
|
||||
"""Abstraction layer over different estimator backends (e.g. keras or scikit-learn). For each backend to be used
|
||||
an EstimatorAdapter must be implemented.
|
||||
|
||||
Args:
|
||||
estimator_adapter: adapter for a given estimator backend
|
||||
classes: mapping from a numerical label to a human-readable label for classes
|
||||
"""
|
||||
self.__estimator_adapter = estimator_adapter
|
||||
self.__classes = classes
|
||||
|
||||
|
||||
26
image_prediction/predictor/predictor.py
Normal file
26
image_prediction/predictor/predictor.py
Normal file
@ -0,0 +1,26 @@
|
||||
from itertools import chain
|
||||
from typing import List
|
||||
|
||||
from PIL.Image import Image
|
||||
|
||||
from image_prediction.estimator.estimator import Estimator
|
||||
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
||||
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||
from image_prediction.utils import chunk_iterable
|
||||
|
||||
|
||||
class Predictor:
|
||||
def __init__(self, estimator: Estimator, preprocessor: Preprocessor = None):
|
||||
self.estimator = estimator
|
||||
self.preprocessor = preprocessor if preprocessor else BasicPreprocessor()
|
||||
self.pipe = lambda batch: self.estimator(self.preprocessor(batch))
|
||||
|
||||
def predict_images(self, images: List[Image], batch_size=4):
|
||||
batches = chunk_iterable(images, chunk_size=batch_size)
|
||||
batches = list(batches)
|
||||
print(list(map(len, batches)))
|
||||
for batch in batches:
|
||||
print(len(batch))
|
||||
print(self.pipe(batch))
|
||||
|
||||
return chain(*map(self.pipe, batches))
|
||||
@ -16,27 +16,28 @@ def predictor(estimator):
|
||||
|
||||
@pytest.fixture
|
||||
def estimator(estimator_adapter, classes):
|
||||
service_estimator = Estimator(estimator_adapter, classes)
|
||||
return service_estimator
|
||||
estimator = Estimator(estimator_adapter, classes)
|
||||
return estimator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def estimator_adapter(estimator_type, keras_model, output_batch, monkeypatch):
|
||||
if estimator_type == "mock":
|
||||
estimator = EstimatorAdapterMock(DummyEstimator())
|
||||
estimator_adapter = EstimatorAdapterMock(DummyEstimator())
|
||||
elif estimator_type == "keras":
|
||||
estimator = KerasEstimatorAdapter(keras_model)
|
||||
estimator_adapter = KerasEstimatorAdapter(keras_model)
|
||||
else:
|
||||
raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")
|
||||
|
||||
def mock_predict(batch):
|
||||
# assert len(batch) == len(output_batch)
|
||||
_predict(batch)
|
||||
return output_batch
|
||||
|
||||
_predict = estimator.predict
|
||||
monkeypatch.setattr(estimator, "predict", mock_predict)
|
||||
_predict = estimator_adapter.predict
|
||||
monkeypatch.setattr(estimator_adapter, "predict", mock_predict)
|
||||
|
||||
return estimator
|
||||
return estimator_adapter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -62,11 +63,6 @@ def keras_model(input_size):
|
||||
return model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def batch_size():
|
||||
return 4
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def images(input_batch):
|
||||
return list(map(array_to_image, input_batch))
|
||||
|
||||
@ -8,13 +8,14 @@ logger = get_logger()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
|
||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||
def test_predict(estimator, input_batch, expected_predictions):
|
||||
predictions = estimator.predict(input_batch)
|
||||
assert predictions == expected_predictions
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||
def test_batch_format(input_batch):
|
||||
|
||||
def channels_are_last(input_batch):
|
||||
|
||||
@ -3,11 +3,11 @@ import pytest
|
||||
from image_prediction.utils import chunk_iterable
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
|
||||
# @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
|
||||
# def test_predict(predictor, images, expected_predictions):
|
||||
# predictions = list(predictor.predict_images(images))
|
||||
# assert predictions == expected_predictions
|
||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||
def test_predict(predictor, images, expected_predictions):
|
||||
predictions = list(predictor.predict_images(images))
|
||||
assert predictions == expected_predictions
|
||||
|
||||
|
||||
def test_chunk_iterable_exact_split():
|
||||
|
||||
@ -1,29 +1,37 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from image_prediction.estimator.preprocessor.preprocessors.tensor_conversion import BasicPreprocessor
|
||||
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor, images_to_batch_tensor
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||
def image_conversion_is_correct(image):
|
||||
tensor = image_to_normalized_tensor(image)
|
||||
image_re = Image.fromarray(np.uint8(tensor * 255), mode="RGB")
|
||||
return image == image_re and tensor.ndim == 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||
def images_conversion_is_correct(images, tensor):
|
||||
if not (images or tensor):
|
||||
return True
|
||||
return all([isinstance(tensor, np.ndarray), tensor.ndim == 4, tensor.shape[0] == len(images)])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||
def test_image_to_tensor(images):
|
||||
assert all(map(image_conversion_is_correct, images))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||
def test_images_to_batch_tensor(images):
|
||||
tensor = images_to_batch_tensor(images)
|
||||
assert images_conversion_is_correct(images, tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 4, 6], scope="session")
|
||||
def test_basic_preprocessor(images):
|
||||
tensor = BasicPreprocessor()(images)
|
||||
assert images_conversion_is_correct(images, tensor)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user