from operator import itemgetter

import pytest

from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier


@pytest.mark.parametrize("extractor_type", ["mock"])
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
@pytest.mark.parametrize("batch_size", [0, 1, 2, 16, 32, 64])
def test_extractor_classifier(image_extractor, image_classifier, images, expected_predictions):
    """Check that the extractor/classifier pipeline yields the expected labels.

    An ``ExtractorClassifier`` is built from the ``image_extractor`` and
    ``image_classifier`` fixtures and called on ``images``. The ``"label"``
    entry of every result must equal ``expected_predictions``, in order.

    The parametrized names (``extractor_type``, ``estimator_type``,
    ``batch_size``) are not arguments of this test; presumably they are
    consumed by the fixtures defined elsewhere (e.g. a conftest) -- verify
    against the fixture definitions.
    """
    pipeline = ExtractorClassifier(image_extractor, image_classifier)

    # Each result is indexed by the "label" key; all other keys are ignored.
    get_label = itemgetter("label")
    predicted_labels = [get_label(prediction) for prediction in pipeline(images)]

    assert predicted_labels == expected_predictions