fixed batching issue in prediction monkey patch by introducinbg an output generator, that yields the expected predictions
This commit is contained in:
parent
0f9510906d
commit
5c5d132d7f
@ -8,4 +8,4 @@ class KerasEstimatorAdapter(EstimatorAdapter):
|
||||
super().__init__(estimator)
|
||||
|
||||
def predict(self, batch: np.array):
|
||||
self.estimator.predict(batch)
|
||||
return self.estimator.predict(batch)
|
||||
|
||||
@ -3,8 +3,8 @@ from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
|
||||
class DummyEstimator:
|
||||
@staticmethod
|
||||
def predict(_):
|
||||
return True
|
||||
def predict(batch):
|
||||
return [None for _ in batch]
|
||||
|
||||
|
||||
class EstimatorAdapterMock(EstimatorAdapter):
|
||||
|
||||
@ -15,12 +15,6 @@ class Predictor:
|
||||
self.preprocessor = preprocessor if preprocessor else BasicPreprocessor()
|
||||
self.pipe = lambda batch: self.estimator(self.preprocessor(batch))
|
||||
|
||||
def predict_images(self, images: List[Image], batch_size=4):
|
||||
def predict(self, images: List[Image], batch_size=2):
|
||||
batches = chunk_iterable(images, chunk_size=batch_size)
|
||||
batches = list(batches)
|
||||
print(list(map(len, batches)))
|
||||
for batch in batches:
|
||||
print(len(batch))
|
||||
print(self.pipe(batch))
|
||||
|
||||
return chain(*map(self.pipe, batches))
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
@ -10,7 +12,7 @@ from image_prediction.predictor.predictor import Predictor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def predictor(estimator):
|
||||
def predictor(estimator, monkeypatch, expected_predictions):
|
||||
return Predictor(estimator)
|
||||
|
||||
|
||||
@ -21,7 +23,14 @@ def estimator(estimator_adapter, classes):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def estimator_adapter(estimator_type, keras_model, output_batch, monkeypatch):
|
||||
def input_output_mapper(input_batch, classes):
|
||||
"""Mocks the internal, real estimator of an EstimatorAdapter object."""
|
||||
outputs = random.choices(range(len(classes)), k=len(input_batch))
|
||||
return outputs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch):
|
||||
if estimator_type == "mock":
|
||||
estimator_adapter = EstimatorAdapterMock(DummyEstimator())
|
||||
elif estimator_type == "keras":
|
||||
@ -30,9 +39,9 @@ def estimator_adapter(estimator_type, keras_model, output_batch, monkeypatch):
|
||||
raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")
|
||||
|
||||
def mock_predict(batch):
|
||||
# assert len(batch) == len(output_batch)
|
||||
_predict(batch)
|
||||
return output_batch
|
||||
# Run real predict function to test for mechanical issues, but return externally defined
|
||||
# predictions to test the callers of the estimator adapter against the expected predictions
|
||||
return [next(output_batch_generator) for _ in _predict(batch)]
|
||||
|
||||
_predict = estimator_adapter.predict
|
||||
monkeypatch.setattr(estimator_adapter, "predict", mock_predict)
|
||||
@ -90,8 +99,13 @@ def expected_predictions(output_batch, classes):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def output_batch(batch_size, classes):
|
||||
return np.random.randint(low=0, high=len(classes), size=batch_size)
|
||||
def output_batch(input_batch, classes):
|
||||
return random.choices(range(len(classes)), k=len(input_batch))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def output_batch_generator(output_batch):
|
||||
return iter(output_batch)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -4,9 +4,10 @@ from image_prediction.utils import chunk_iterable
|
||||
|
||||
|
||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||
# @pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 4])
|
||||
def test_predict(predictor, images, expected_predictions):
|
||||
predictions = list(predictor.predict_images(images))
|
||||
predictions = list(predictor.predict(images))
|
||||
assert predictions == expected_predictions
|
||||
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ def image_conversion_is_correct(image):
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
|
||||
def images_conversion_is_correct(images, tensor):
|
||||
if not (images or tensor):
|
||||
if not (images or tensor.size > 0):
|
||||
return True
|
||||
return all([isinstance(tensor, np.ndarray), tensor.ndim == 4, tensor.shape[0] == len(images)])
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user