From 5c5d132d7fdae1e97a05810b6ab14751934e1c71 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Sun, 27 Mar 2022 01:13:21 +0100 Subject: [PATCH] fixed batching issue in prediction monkey patch by introducinbg an output generator, that yields the expected predictions --- .../estimator/adapter/adapters/keras.py | 2 +- .../estimator/adapter/adapters/mock.py | 4 +-- image_prediction/predictor/predictor.py | 8 +----- test/unit_tests/conftest.py | 28 ++++++++++++++----- test/unit_tests/predictor_test.py | 5 ++-- test/unit_tests/preprocessor_test.py | 2 +- 6 files changed, 29 insertions(+), 20 deletions(-) diff --git a/image_prediction/estimator/adapter/adapters/keras.py b/image_prediction/estimator/adapter/adapters/keras.py index 294ddfd..d54d925 100644 --- a/image_prediction/estimator/adapter/adapters/keras.py +++ b/image_prediction/estimator/adapter/adapters/keras.py @@ -8,4 +8,4 @@ class KerasEstimatorAdapter(EstimatorAdapter): super().__init__(estimator) def predict(self, batch: np.array): - self.estimator.predict(batch) + return self.estimator.predict(batch) diff --git a/image_prediction/estimator/adapter/adapters/mock.py b/image_prediction/estimator/adapter/adapters/mock.py index a2d6d6d..a8d5b6e 100644 --- a/image_prediction/estimator/adapter/adapters/mock.py +++ b/image_prediction/estimator/adapter/adapters/mock.py @@ -3,8 +3,8 @@ from image_prediction.estimator.adapter.adapter import EstimatorAdapter class DummyEstimator: @staticmethod - def predict(_): - return True + def predict(batch): + return [None for _ in batch] class EstimatorAdapterMock(EstimatorAdapter): diff --git a/image_prediction/predictor/predictor.py b/image_prediction/predictor/predictor.py index a9eda45..af14760 100644 --- a/image_prediction/predictor/predictor.py +++ b/image_prediction/predictor/predictor.py @@ -15,12 +15,6 @@ class Predictor: self.preprocessor = preprocessor if preprocessor else BasicPreprocessor() self.pipe = lambda batch: self.estimator(self.preprocessor(batch)) - def predict_images(self, images: List[Image], batch_size=4): + def predict(self, images: List[Image], batch_size=2): batches = chunk_iterable(images, chunk_size=batch_size) - batches = list(batches) - print(list(map(len, batches))) - for batch in batches: - print(len(batch)) - print(self.pipe(batch)) - return chain(*map(self.pipe, batches)) diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index fd9f4dc..5de1f69 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -1,3 +1,5 @@ +import random + import numpy as np import pytest from PIL import Image @@ -10,7 +12,7 @@ from image_prediction.predictor.predictor import Predictor @pytest.fixture -def predictor(estimator): +def predictor(estimator, monkeypatch, expected_predictions): return Predictor(estimator) @@ -21,7 +23,14 @@ def estimator(estimator_adapter, classes): @pytest.fixture -def estimator_adapter(estimator_type, keras_model, output_batch, monkeypatch): +def input_output_mapper(input_batch, classes): + """Mocks the internal, real estimator of an EstimatorAdapter object.""" + outputs = random.choices(range(len(classes)), k=len(input_batch)) + return outputs + + +@pytest.fixture +def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch): if estimator_type == "mock": estimator_adapter = EstimatorAdapterMock(DummyEstimator()) elif estimator_type == "keras": @@ -30,9 +39,9 @@ def estimator_adapter(estimator_type, keras_model, output_batch, monkeypatch): raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.") def mock_predict(batch): - # assert len(batch) == len(output_batch) - _predict(batch) - return output_batch + # Run real predict function to test for mechanical issues, but return externally defined + # predictions to test the callers of the estimator adapter against the expected predictions + return [next(output_batch_generator) for _ in _predict(batch)] _predict = estimator_adapter.predict monkeypatch.setattr(estimator_adapter, "predict", mock_predict) @@ -90,8 +99,13 @@ def expected_predictions(output_batch, classes): @pytest.fixture -def output_batch(batch_size, classes): - return np.random.randint(low=0, high=len(classes), size=batch_size) +def output_batch(input_batch, classes): + return random.choices(range(len(classes)), k=len(input_batch)) + + +@pytest.fixture +def output_batch_generator(output_batch): + return iter(output_batch) @pytest.fixture diff --git a/test/unit_tests/predictor_test.py b/test/unit_tests/predictor_test.py index c1acb77..0a2441d 100644 --- a/test/unit_tests/predictor_test.py +++ b/test/unit_tests/predictor_test.py @@ -4,9 +4,10 @@ from image_prediction.utils import chunk_iterable @pytest.mark.parametrize("estimator_type", ["mock", "keras"]) -@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64]) +# @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64]) +@pytest.mark.parametrize("batch_size", [0, 1, 2, 4]) def test_predict(predictor, images, expected_predictions): - predictions = list(predictor.predict_images(images)) + predictions = list(predictor.predict(images)) assert predictions == expected_predictions diff --git a/test/unit_tests/preprocessor_test.py b/test/unit_tests/preprocessor_test.py index f2f2fd1..72e8618 100644 --- a/test/unit_tests/preprocessor_test.py +++ b/test/unit_tests/preprocessor_test.py @@ -15,7 +15,7 @@ def image_conversion_is_correct(image): @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64]) def images_conversion_is_correct(images, tensor): - if not (images or tensor): + if not (images or tensor.size > 0): return True return all([isinstance(tensor, np.ndarray), tensor.ndim == 4, tensor.shape[0] == len(images)])