refactoring; added predictor; mocking of predict function is broken: fixing next commit

This commit is contained in:
Matthias Bisping 2022-03-26 21:18:42 +01:00
parent 6343229c1e
commit 0f9510906d
7 changed files with 58 additions and 20 deletions

View File

@ -10,6 +10,13 @@ logger = get_logger()
class Estimator:
def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
"""Abstraction layer over different estimator backends (e.g. keras or scikit-learn). For each backend to be used
an EstimatorAdapter must be implemented.
Args:
estimator_adapter: adapter for a given estimator backend
classes: mapping from a numerical label to a human-readable label for classes
"""
self.__estimator_adapter = estimator_adapter
self.__classes = classes

View File

@ -0,0 +1,26 @@
from itertools import chain
from typing import List
from PIL.Image import Image
from image_prediction.estimator.estimator import Estimator
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.utils import chunk_iterable
class Predictor:
def __init__(self, estimator: Estimator, preprocessor: Preprocessor = None):
self.estimator = estimator
self.preprocessor = preprocessor if preprocessor else BasicPreprocessor()
self.pipe = lambda batch: self.estimator(self.preprocessor(batch))
def predict_images(self, images: List[Image], batch_size=4):
batches = chunk_iterable(images, chunk_size=batch_size)
batches = list(batches)
print(list(map(len, batches)))
for batch in batches:
print(len(batch))
print(self.pipe(batch))
return chain(*map(self.pipe, batches))

View File

@ -16,27 +16,28 @@ def predictor(estimator):
@pytest.fixture
def estimator(estimator_adapter, classes):
service_estimator = Estimator(estimator_adapter, classes)
return service_estimator
estimator = Estimator(estimator_adapter, classes)
return estimator
@pytest.fixture
def estimator_adapter(estimator_type, keras_model, output_batch, monkeypatch):
if estimator_type == "mock":
estimator = EstimatorAdapterMock(DummyEstimator())
estimator_adapter = EstimatorAdapterMock(DummyEstimator())
elif estimator_type == "keras":
estimator = KerasEstimatorAdapter(keras_model)
estimator_adapter = KerasEstimatorAdapter(keras_model)
else:
raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")
def mock_predict(batch):
# assert len(batch) == len(output_batch)
_predict(batch)
return output_batch
_predict = estimator.predict
monkeypatch.setattr(estimator, "predict", mock_predict)
_predict = estimator_adapter.predict
monkeypatch.setattr(estimator_adapter, "predict", mock_predict)
return estimator
return estimator_adapter
@pytest.fixture
@ -62,11 +63,6 @@ def keras_model(input_size):
return model
@pytest.fixture
def batch_size():
return 4
@pytest.fixture
def images(input_batch):
return list(map(array_to_image, input_batch))

View File

@ -8,13 +8,14 @@ logger = get_logger()
logger.setLevel(logging.DEBUG)
@pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
def test_predict(estimator, input_batch, expected_predictions):
predictions = estimator.predict(input_batch)
assert predictions == expected_predictions
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
def test_batch_format(input_batch):
def channels_are_last(input_batch):

View File

@ -3,11 +3,11 @@ import pytest
from image_prediction.utils import chunk_iterable
# @pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
# @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
# def test_predict(predictor, images, expected_predictions):
# predictions = list(predictor.predict_images(images))
# assert predictions == expected_predictions
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
def test_predict(predictor, images, expected_predictions):
predictions = list(predictor.predict_images(images))
assert predictions == expected_predictions
def test_chunk_iterable_exact_split():

View File

@ -1,29 +1,37 @@
import numpy as np
import pytest
from PIL import Image
from image_prediction.estimator.preprocessor.preprocessors.tensor_conversion import BasicPreprocessor
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor, images_to_batch_tensor
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
def image_conversion_is_correct(image):
tensor = image_to_normalized_tensor(image)
image_re = Image.fromarray(np.uint8(tensor * 255), mode="RGB")
return image == image_re and tensor.ndim == 3
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
def images_conversion_is_correct(images, tensor):
if not (images or tensor):
return True
return all([isinstance(tensor, np.ndarray), tensor.ndim == 4, tensor.shape[0] == len(images)])
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
def test_image_to_tensor(images):
assert all(map(image_conversion_is_correct, images))
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
def test_images_to_batch_tensor(images):
tensor = images_to_batch_tensor(images)
assert images_conversion_is_correct(images, tensor)
@pytest.mark.parametrize("batch_size", [0, 1, 2, 4, 6], scope="session")
def test_basic_preprocessor(images):
tensor = BasicPreprocessor()(images)
assert images_conversion_is_correct(images, tensor)