updated classifier test for label mappers

This commit is contained in:
Matthias Bisping 2022-03-30 16:04:13 +02:00
parent 8bccec277f
commit 1c6f5749dd
3 changed files with 18 additions and 39 deletions

View File

@ -1,63 +1,33 @@
from operator import itemgetter
from typing import Mapping, List, Union, Tuple
from typing import List, Union, Tuple
import numpy as np
from PIL.Image import Image
from funcy import rcompose
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.exceptions import UnexpectedLabelFormat
from image_prediction.label_mapper.mapper import LabelMapper
from image_prediction.utils import get_logger
logger = get_logger()
class Classifier:
def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
def __init__(self, estimator_adapter: EstimatorAdapter, label_mapper: LabelMapper):
"""Abstraction layer over different estimator backends (e.g. keras or scikit-learn). For each backend to be used
an EstimatorAdapter must be implemented.
Args:
estimator_adapter: adapter for a given estimator backend
classes: mapping from a numerical label to a human-readable label for classes
"""
self.__estimator_adapter = estimator_adapter
self._classes = classes
def __validate_array_prediction_format(self, prediction):
if not len(prediction) == len(self._classes):
raise UnexpectedLabelFormat(
f"Received fewer probabilities ({len(prediction)}) than classes were specified ({len(self._classes)})."
)
def __validate_int_prediction_format(self, prediction):
if not 0 <= prediction <= len(self._classes):
raise UnexpectedLabelFormat(
f"Received class index '{prediction}' as prediction that has no associated class label."
)
def __format_array_prediction_format(self, prediction):
cls2prob = dict(sorted(zip(self._classes, prediction), key=itemgetter(1), reverse=True))
most_likely = [*cls2prob][0]
return {"label": most_likely, "probabilities": cls2prob}
def __format_prediction(self, prediction):
if isinstance(prediction, int):
self.__validate_int_prediction_format(prediction)
return self._classes[prediction]
elif isinstance(prediction, np.ndarray):
self.__validate_array_prediction_format(prediction)
return self.__format_array_prediction_format(prediction)
else:
return prediction
self.__label_mapper = label_mapper
self.__pipe = rcompose(self.__estimator_adapter, self.__label_mapper)
def predict(self, batch: Union[np.array, Tuple[Image]]) -> List[str]:
if not isinstance(batch, tuple) and batch.shape[0] == 0:
return []
return list(map(self.__format_prediction, self.__estimator_adapter.predict(batch)))
return list(self.__pipe(batch))
def __call__(self, batch: np.array) -> List[str]:
return self.predict(batch)

View File

@ -4,3 +4,6 @@ class EstimatorAdapter:
def predict(self, batch):
return self.estimator(batch)
def __call__(self, batch):
return self.predict(batch)

View File

@ -18,6 +18,7 @@ from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
from image_prediction.info import Info
from image_prediction.label_mapper.mappers.numeric import IndexMapper
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
from image_prediction.model_loader.loader import ModelLoader
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
@ -42,8 +43,8 @@ def image_classifier(classifier, monkeypatch, batch_of_expected_string_labels):
@pytest.fixture
def classifier(estimator_adapter, classes):
classifier = Classifier(estimator_adapter, classes)
def classifier(estimator_adapter, label_mapper):
classifier = Classifier(estimator_adapter, label_mapper)
return classifier
@ -56,6 +57,11 @@ class EstimatorMock:
return self.predict(batch)
@pytest.fixture()
def label_mapper(classes):
return IndexMapper(classes)
@pytest.fixture
def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch):
if estimator_type == "mock":