from typing import Mapping, List

import numpy as np

from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.utils import get_logger

logger = get_logger()


class Classifier:
    """Map an estimator backend's numeric predictions to human-readable class names."""

    def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
        """Abstraction layer over different estimator backends (e.g. keras or scikit-learn).

        For each backend to be used an EstimatorAdapter must be implemented.

        Args:
            estimator_adapter: adapter for a given estimator backend; expected to be a classifier that returns
                numeric labels as predictions
            classes: mapping from a numerical label to a human-readable label for classes
        """
        self.__estimator_adapter = estimator_adapter
        self.__classes = classes

    def predict(self, batch: np.ndarray) -> List[str]:
        """Classify a batch and return one human-readable label per prediction.

        Args:
            batch: samples to classify; the first axis (``batch.shape[0]``) is treated as the sample count.

        Returns:
            The human-readable class name for each numeric label yielded by the estimator adapter, in the
            order the adapter yields them. An empty batch yields ``[]`` without calling the adapter.

        Raises:
            KeyError: if the adapter yields a numeric label that is not a key of ``classes``.
        """
        # Skip the backend entirely for empty input. NOTE(review): presumably some backends reject
        # zero-length batches — confirm before relying on it.
        if batch.shape[0] == 0:
            return []
        return [self.__label_for(numeric_label) for numeric_label in self.__estimator_adapter.predict(batch)]

    def __label_for(self, numeric_label) -> str:
        """Translate one numeric label into its human-readable name, with a descriptive error if unknown."""
        try:
            return self.__classes[numeric_label]
        except KeyError as err:
            raise KeyError(
                f"Estimator returned label {numeric_label!r}, which has no entry in the class mapping"
            ) from err

    def __call__(self, batch: np.ndarray) -> List[str]:
        """Alias for :meth:`predict`, so a Classifier can be used directly as a callable."""
        return self.predict(batch)