This commit is contained in:
Matthias Bisping 2022-03-27 22:59:28 +02:00
parent 334dc79f7e
commit 4c939464b0
3 changed files with 6 additions and 6 deletions

View File

@ -1,7 +1,7 @@
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
class DummyEstimator:
class EstimatorMock:
@staticmethod
def predict(batch):
return [None for _ in batch]

View File

@ -7,12 +7,12 @@ from PIL import Image
from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.estimator.adapter.adapters.keras import KerasEstimatorAdapter
from image_prediction.estimator.adapter.adapters.mock import DummyEstimator, EstimatorAdapterMock
from image_prediction.estimator.adapter.adapters.mock import EstimatorMock, EstimatorAdapterMock
from image_prediction.exceptions import UnknownEstimatorAdapter
@pytest.fixture
def predictor(classifier, monkeypatch, expected_predictions):
def image_classifier(classifier, monkeypatch, expected_predictions):
return ImageClassifier(classifier)
@ -25,7 +25,7 @@ def classifier(estimator_adapter, classes):
@pytest.fixture
def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch):
if estimator_type == "mock":
estimator_adapter = EstimatorAdapterMock(DummyEstimator())
estimator_adapter = EstimatorAdapterMock(EstimatorMock())
elif estimator_type == "keras":
estimator_adapter = KerasEstimatorAdapter(keras_model)
else:

View File

@ -5,8 +5,8 @@ from image_prediction.utils import chunk_iterable
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
def test_predict(predictor, images, expected_predictions):
predictions = list(predictor.predict(images))
def test_predict(image_classifier, images, expected_predictions):
predictions = list(image_classifier.predict(images))
assert predictions == expected_predictions