From 1c6f5749ddcd82330333e8bf3a727b637a58b6c8 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Wed, 30 Mar 2022 16:04:13 +0200 Subject: [PATCH] updated classifier test for label mappers --- image_prediction/classifier/classifier.py | 44 +++---------------- image_prediction/estimator/adapter/adapter.py | 3 ++ test/unit_tests/conftest.py | 10 ++++- 3 files changed, 18 insertions(+), 39 deletions(-) diff --git a/image_prediction/classifier/classifier.py b/image_prediction/classifier/classifier.py index 546c8f2..ddbd8c0 100644 --- a/image_prediction/classifier/classifier.py +++ b/image_prediction/classifier/classifier.py @@ -1,63 +1,33 @@ -from operator import itemgetter -from typing import Mapping, List, Union, Tuple +from typing import List, Union, Tuple import numpy as np from PIL.Image import Image +from funcy import rcompose from image_prediction.estimator.adapter.adapter import EstimatorAdapter -from image_prediction.exceptions import UnexpectedLabelFormat +from image_prediction.label_mapper.mapper import LabelMapper from image_prediction.utils import get_logger logger = get_logger() class Classifier: - def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]): + def __init__(self, estimator_adapter: EstimatorAdapter, label_mapper: LabelMapper): """Abstraction layer over different estimator backends (e.g. keras or scikit-learn). For each backend to be used an EstimatorAdapter must be implemented. Args: estimator_adapter: adapter for a given estimator backend - classes: mapping from a numerical label to a human-readable label for classes """ self.__estimator_adapter = estimator_adapter - self._classes = classes - - def __validate_array_prediction_format(self, prediction): - if not len(prediction) == len(self._classes): - raise UnexpectedLabelFormat( - f"Received fewer probabilities ({len(prediction)}) than classes were specified ({len(self._classes)})." - ) - - def __validate_int_prediction_format(self, prediction): - if not 0 <= prediction <= len(self._classes): - raise UnexpectedLabelFormat( - f"Received class index '{prediction}' as prediction that has no associated class label." - ) - - def __format_array_prediction_format(self, prediction): - cls2prob = dict(sorted(zip(self._classes, prediction), key=itemgetter(1), reverse=True)) - most_likely = [*cls2prob][0] - return {"label": most_likely, "probabilities": cls2prob} - - def __format_prediction(self, prediction): - if isinstance(prediction, int): - self.__validate_int_prediction_format(prediction) - return self._classes[prediction] - - elif isinstance(prediction, np.ndarray): - self.__validate_array_prediction_format(prediction) - return self.__format_array_prediction_format(prediction) - - else: - return prediction + self.__label_mapper = label_mapper + self.__pipe = rcompose(self.__estimator_adapter, self.__label_mapper) def predict(self, batch: Union[np.array, Tuple[Image]]) -> List[str]: - if not isinstance(batch, tuple) and batch.shape[0] == 0: return [] - return list(map(self.__format_prediction, self.__estimator_adapter.predict(batch))) + return list(self.__pipe(batch)) def __call__(self, batch: np.array) -> List[str]: return self.predict(batch) diff --git a/image_prediction/estimator/adapter/adapter.py b/image_prediction/estimator/adapter/adapter.py index 692178e..7ae5bc6 100644 --- a/image_prediction/estimator/adapter/adapter.py +++ b/image_prediction/estimator/adapter/adapter.py @@ -4,3 +4,6 @@ class EstimatorAdapter: def predict(self, batch): return self.estimator(batch) + + def __call__(self, batch): + return self.predict(batch) diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index 557cec7..5fc55d6 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -18,6 +18,7 @@ from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractors.mock import ImageExtractorMock from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.info import Info +from image_prediction.label_mapper.mappers.numeric import IndexMapper from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock from image_prediction.model_loader.loader import ModelLoader from image_prediction.model_loader.loaders.mlflow import MlflowConnector @@ -42,8 +43,8 @@ def image_classifier(classifier, monkeypatch, batch_of_expected_string_labels): @pytest.fixture -def classifier(estimator_adapter, classes): - classifier = Classifier(estimator_adapter, classes) +def classifier(estimator_adapter, label_mapper): + classifier = Classifier(estimator_adapter, label_mapper) return classifier @@ -56,6 +57,11 @@ class EstimatorMock: return self.predict(batch) +@pytest.fixture() +def label_mapper(classes): + return IndexMapper(classes) + + @pytest.fixture def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch): if estimator_type == "mock":