17 lines
659 B
Python
17 lines
659 B
Python
from operator import itemgetter
|
|
|
|
import pytest
|
|
|
|
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
|
|
|
|
|
|
@pytest.mark.parametrize("extractor_type", ["mock"])
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
def test_extractor_classifier(image_extractor, image_classifier, images, expected_predictions):
    """Check that ExtractorClassifier assigns ``expected_predictions`` as labels to ``images``.

    The classifier is built from the ``image_extractor`` and ``image_classifier``
    fixtures and called on ``images``. The "label" entry of each result is
    compared, in order, against ``expected_predictions``.

    NOTE(review): ``extractor_type``, ``estimator_type`` and ``batch_size`` are
    parametrized but are not arguments of this function. pytest only accepts
    that when a requested fixture depends on them (presumably the fixtures
    defined in conftest.py) — confirm.
    """
    extractor_classifier = ExtractorClassifier(image_extractor, image_classifier)

    # Materialize the output so it can be both checked and reported on failure.
    results = list(extractor_classifier(images))
    labels = list(map(itemgetter("label"), results))

    # The full results go into the failure message instead of an
    # unconditional debug print, so they surface only when the check fails.
    assert labels == expected_predictions, f"unexpected labels; full results: {results}"
|