refactoring; added predictor; mocking of predict function is broken: fixing next commit
This commit is contained in:
parent
6343229c1e
commit
0f9510906d
@ -10,6 +10,13 @@ logger = get_logger()
|
|||||||
|
|
||||||
class Estimator:
|
class Estimator:
|
||||||
def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
|
def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
|
||||||
|
"""Abstraction layer over different estimator backends (e.g. keras or scikit-learn). For each backend to be used
|
||||||
|
an EstimatorAdapter must be implemented.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
estimator_adapter: adapter for a given estimator backend
|
||||||
|
classes: mapping from a numerical label to a human-readable label for classes
|
||||||
|
"""
|
||||||
self.__estimator_adapter = estimator_adapter
|
self.__estimator_adapter = estimator_adapter
|
||||||
self.__classes = classes
|
self.__classes = classes
|
||||||
|
|
||||||
|
|||||||
26
image_prediction/predictor/predictor.py
Normal file
26
image_prediction/predictor/predictor.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from itertools import chain
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from PIL.Image import Image
|
||||||
|
|
||||||
|
from image_prediction.estimator.estimator import Estimator
|
||||||
|
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
||||||
|
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||||
|
from image_prediction.utils import chunk_iterable
|
||||||
|
|
||||||
|
|
||||||
|
class Predictor:
|
||||||
|
def __init__(self, estimator: Estimator, preprocessor: Preprocessor = None):
|
||||||
|
self.estimator = estimator
|
||||||
|
self.preprocessor = preprocessor if preprocessor else BasicPreprocessor()
|
||||||
|
self.pipe = lambda batch: self.estimator(self.preprocessor(batch))
|
||||||
|
|
||||||
|
def predict_images(self, images: List[Image], batch_size=4):
|
||||||
|
batches = chunk_iterable(images, chunk_size=batch_size)
|
||||||
|
batches = list(batches)
|
||||||
|
print(list(map(len, batches)))
|
||||||
|
for batch in batches:
|
||||||
|
print(len(batch))
|
||||||
|
print(self.pipe(batch))
|
||||||
|
|
||||||
|
return chain(*map(self.pipe, batches))
|
||||||
@ -16,27 +16,28 @@ def predictor(estimator):
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def estimator(estimator_adapter, classes):
|
def estimator(estimator_adapter, classes):
|
||||||
service_estimator = Estimator(estimator_adapter, classes)
|
estimator = Estimator(estimator_adapter, classes)
|
||||||
return service_estimator
|
return estimator
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def estimator_adapter(estimator_type, keras_model, output_batch, monkeypatch):
|
def estimator_adapter(estimator_type, keras_model, output_batch, monkeypatch):
|
||||||
if estimator_type == "mock":
|
if estimator_type == "mock":
|
||||||
estimator = EstimatorAdapterMock(DummyEstimator())
|
estimator_adapter = EstimatorAdapterMock(DummyEstimator())
|
||||||
elif estimator_type == "keras":
|
elif estimator_type == "keras":
|
||||||
estimator = KerasEstimatorAdapter(keras_model)
|
estimator_adapter = KerasEstimatorAdapter(keras_model)
|
||||||
else:
|
else:
|
||||||
raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")
|
raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")
|
||||||
|
|
||||||
def mock_predict(batch):
|
def mock_predict(batch):
|
||||||
|
# assert len(batch) == len(output_batch)
|
||||||
_predict(batch)
|
_predict(batch)
|
||||||
return output_batch
|
return output_batch
|
||||||
|
|
||||||
_predict = estimator.predict
|
_predict = estimator_adapter.predict
|
||||||
monkeypatch.setattr(estimator, "predict", mock_predict)
|
monkeypatch.setattr(estimator_adapter, "predict", mock_predict)
|
||||||
|
|
||||||
return estimator
|
return estimator_adapter
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -62,11 +63,6 @@ def keras_model(input_size):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def batch_size():
|
|
||||||
return 4
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def images(input_batch):
|
def images(input_batch):
|
||||||
return list(map(array_to_image, input_batch))
|
return list(map(array_to_image, input_batch))
|
||||||
|
|||||||
@ -8,13 +8,14 @@ logger = get_logger()
|
|||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
|
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
||||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
|
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||||
def test_predict(estimator, input_batch, expected_predictions):
|
def test_predict(estimator, input_batch, expected_predictions):
|
||||||
predictions = estimator.predict(input_batch)
|
predictions = estimator.predict(input_batch)
|
||||||
assert predictions == expected_predictions
|
assert predictions == expected_predictions
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||||
def test_batch_format(input_batch):
|
def test_batch_format(input_batch):
|
||||||
|
|
||||||
def channels_are_last(input_batch):
|
def channels_are_last(input_batch):
|
||||||
|
|||||||
@ -3,11 +3,11 @@ import pytest
|
|||||||
from image_prediction.utils import chunk_iterable
|
from image_prediction.utils import chunk_iterable
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.parametrize("estimator_type", ["mock", "keras"], scope="session")
|
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
||||||
# @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64], scope="session")
|
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||||
# def test_predict(predictor, images, expected_predictions):
|
def test_predict(predictor, images, expected_predictions):
|
||||||
# predictions = list(predictor.predict_images(images))
|
predictions = list(predictor.predict_images(images))
|
||||||
# assert predictions == expected_predictions
|
assert predictions == expected_predictions
|
||||||
|
|
||||||
|
|
||||||
def test_chunk_iterable_exact_split():
|
def test_chunk_iterable_exact_split():
|
||||||
|
|||||||
@ -1,29 +1,37 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from image_prediction.estimator.preprocessor.preprocessors.tensor_conversion import BasicPreprocessor
|
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor, images_to_batch_tensor
|
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor, images_to_batch_tensor
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||||
def image_conversion_is_correct(image):
|
def image_conversion_is_correct(image):
|
||||||
tensor = image_to_normalized_tensor(image)
|
tensor = image_to_normalized_tensor(image)
|
||||||
image_re = Image.fromarray(np.uint8(tensor * 255), mode="RGB")
|
image_re = Image.fromarray(np.uint8(tensor * 255), mode="RGB")
|
||||||
return image == image_re and tensor.ndim == 3
|
return image == image_re and tensor.ndim == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||||
def images_conversion_is_correct(images, tensor):
|
def images_conversion_is_correct(images, tensor):
|
||||||
|
if not (images or tensor):
|
||||||
|
return True
|
||||||
return all([isinstance(tensor, np.ndarray), tensor.ndim == 4, tensor.shape[0] == len(images)])
|
return all([isinstance(tensor, np.ndarray), tensor.ndim == 4, tensor.shape[0] == len(images)])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||||
def test_image_to_tensor(images):
|
def test_image_to_tensor(images):
|
||||||
assert all(map(image_conversion_is_correct, images))
|
assert all(map(image_conversion_is_correct, images))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||||
def test_images_to_batch_tensor(images):
|
def test_images_to_batch_tensor(images):
|
||||||
tensor = images_to_batch_tensor(images)
|
tensor = images_to_batch_tensor(images)
|
||||||
assert images_conversion_is_correct(images, tensor)
|
assert images_conversion_is_correct(images, tensor)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [0, 1, 2, 4, 6], scope="session")
|
||||||
def test_basic_preprocessor(images):
|
def test_basic_preprocessor(images):
|
||||||
tensor = BasicPreprocessor()(images)
|
tensor = BasicPreprocessor()(images)
|
||||||
assert images_conversion_is_correct(images, tensor)
|
assert images_conversion_is_correct(images, tensor)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user