updated classifier test for label mappers
This commit is contained in:
parent
8bccec277f
commit
1c6f5749dd
@ -1,63 +1,33 @@
|
||||
from operator import itemgetter
|
||||
from typing import Mapping, List, Union, Tuple
|
||||
from typing import List, Union, Tuple
|
||||
|
||||
import numpy as np
|
||||
from PIL.Image import Image
|
||||
from funcy import rcompose
|
||||
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
from image_prediction.exceptions import UnexpectedLabelFormat
|
||||
from image_prediction.label_mapper.mapper import LabelMapper
|
||||
from image_prediction.utils import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class Classifier:
|
||||
def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
|
||||
def __init__(self, estimator_adapter: EstimatorAdapter, label_mapper: LabelMapper):
|
||||
"""Abstraction layer over different estimator backends (e.g. keras or scikit-learn). For each backend to be used
|
||||
an EstimatorAdapter must be implemented.
|
||||
|
||||
Args:
|
||||
estimator_adapter: adapter for a given estimator backend
|
||||
classes: mapping from a numerical label to a human-readable label for classes
|
||||
"""
|
||||
self.__estimator_adapter = estimator_adapter
|
||||
self._classes = classes
|
||||
|
||||
def __validate_array_prediction_format(self, prediction):
|
||||
if not len(prediction) == len(self._classes):
|
||||
raise UnexpectedLabelFormat(
|
||||
f"Received fewer probabilities ({len(prediction)}) than classes were specified ({len(self._classes)})."
|
||||
)
|
||||
|
||||
def __validate_int_prediction_format(self, prediction):
|
||||
if not 0 <= prediction <= len(self._classes):
|
||||
raise UnexpectedLabelFormat(
|
||||
f"Received class index '{prediction}' as prediction that has no associated class label."
|
||||
)
|
||||
|
||||
def __format_array_prediction_format(self, prediction):
|
||||
cls2prob = dict(sorted(zip(self._classes, prediction), key=itemgetter(1), reverse=True))
|
||||
most_likely = [*cls2prob][0]
|
||||
return {"label": most_likely, "probabilities": cls2prob}
|
||||
|
||||
def __format_prediction(self, prediction):
|
||||
if isinstance(prediction, int):
|
||||
self.__validate_int_prediction_format(prediction)
|
||||
return self._classes[prediction]
|
||||
|
||||
elif isinstance(prediction, np.ndarray):
|
||||
self.__validate_array_prediction_format(prediction)
|
||||
return self.__format_array_prediction_format(prediction)
|
||||
|
||||
else:
|
||||
return prediction
|
||||
self.__label_mapper = label_mapper
|
||||
self.__pipe = rcompose(self.__estimator_adapter, self.__label_mapper)
|
||||
|
||||
def predict(self, batch: Union[np.array, Tuple[Image]]) -> List[str]:
|
||||
|
||||
if not isinstance(batch, tuple) and batch.shape[0] == 0:
|
||||
return []
|
||||
|
||||
return list(map(self.__format_prediction, self.__estimator_adapter.predict(batch)))
|
||||
return list(self.__pipe(batch))
|
||||
|
||||
def __call__(self, batch: np.array) -> List[str]:
|
||||
return self.predict(batch)
|
||||
|
||||
@ -4,3 +4,6 @@ class EstimatorAdapter:
|
||||
|
||||
def predict(self, batch):
|
||||
return self.estimator(batch)
|
||||
|
||||
def __call__(self, batch):
|
||||
return self.predict(batch)
|
||||
|
||||
@ -18,6 +18,7 @@ from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.label_mapper.mappers.numeric import IndexMapper
|
||||
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
|
||||
from image_prediction.model_loader.loader import ModelLoader
|
||||
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||
@ -42,8 +43,8 @@ def image_classifier(classifier, monkeypatch, batch_of_expected_string_labels):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def classifier(estimator_adapter, classes):
|
||||
classifier = Classifier(estimator_adapter, classes)
|
||||
def classifier(estimator_adapter, label_mapper):
|
||||
classifier = Classifier(estimator_adapter, label_mapper)
|
||||
return classifier
|
||||
|
||||
|
||||
@ -56,6 +57,11 @@ class EstimatorMock:
|
||||
return self.predict(batch)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def label_mapper(classes):
|
||||
return IndexMapper(classes)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def estimator_adapter(estimator_type, keras_model, output_batch_generator, monkeypatch):
|
||||
if estimator_type == "mock":
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user