renaming
This commit is contained in:
parent
6835394d30
commit
99d8e921db
@ -9,9 +9,9 @@ logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
||||
def test_predict(classifier, input_batch, expected_predictions):
|
||||
def test_predict(classifier, input_batch, expected_batch_string_labels):
|
||||
predictions = classifier.predict(input_batch)
|
||||
assert predictions == expected_predictions
|
||||
assert predictions == expected_batch_string_labels
|
||||
|
||||
|
||||
def test_batch_format(input_batch):
|
||||
|
||||
@ -37,7 +37,7 @@ def image_extractor(extractor_type):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image_classifier(classifier, monkeypatch, expected_predictions):
|
||||
def image_classifier(classifier, monkeypatch, expected_batch_string_labels):
|
||||
return ImageClassifier(classifier, preprocessor=BasicPreprocessor())
|
||||
|
||||
|
||||
@ -129,18 +129,18 @@ def input_size(request):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_predictions(output_batch, classes):
|
||||
return map_labels(output_batch, classes)
|
||||
def expected_batch_string_labels(expected_batch_numeric_labels, classes):
|
||||
return map_labels(expected_batch_numeric_labels, classes)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def output_batch(input_batch, classes):
|
||||
def expected_batch_numeric_labels(input_batch, classes):
|
||||
return random.choices(range(len(classes)), k=len(input_batch))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def output_batch_generator(output_batch):
|
||||
return iter(output_batch)
|
||||
def output_batch_generator(expected_batch_numeric_labels):
|
||||
return iter(expected_batch_numeric_labels)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -7,8 +7,8 @@ from image_prediction.extractor_classifier.extractor_classifier import Extractor
|
||||
|
||||
@pytest.mark.parametrize("extractor_type", ["mock"])
|
||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
||||
def test_extractor_classifier(image_extractor, image_classifier, images, expected_predictions):
|
||||
def test_extractor_classifier(image_extractor, image_classifier, images, expected_batch_string_labels):
|
||||
extractor_classifier = ExtractorClassifier(image_extractor, image_classifier)
|
||||
results = extractor_classifier(images)
|
||||
labels = list(map(itemgetter("prediction"), results))
|
||||
assert labels == expected_predictions
|
||||
assert labels == expected_batch_string_labels
|
||||
|
||||
@ -4,9 +4,9 @@ from image_prediction.utils import chunk_iterable
|
||||
|
||||
|
||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
||||
def test_predict(image_classifier, images, expected_predictions):
|
||||
def test_predict(image_classifier, images, expected_batch_string_labels):
|
||||
predictions = list(image_classifier.predict(images))
|
||||
assert predictions == expected_predictions
|
||||
assert predictions == expected_batch_string_labels
|
||||
|
||||
|
||||
def test_chunk_iterable_exact_split():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user