Matthias Bisping 9d58ae714f renaming
2022-03-27 17:55:01 +02:00

33 lines
1.1 KiB
Python

from typing import Mapping, List
import numpy as np
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.utils import get_logger
logger = get_logger()
class Classifier:
    """Abstraction layer over different estimator backends (e.g. keras or scikit-learn).

    For each backend to be used an EstimatorAdapter must be implemented. The classifier delegates
    prediction to that adapter and translates the numeric labels it returns into human-readable
    class names via the ``classes`` mapping.
    """

    def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
        """
        Args:
            estimator_adapter: adapter for a given estimator backend; expected to be a classifier that returns numeric
                labels as predictions
            classes: mapping from a numerical label to a human-readable label for classes
        """
        self.__estimator_adapter = estimator_adapter
        self.__classes = classes

    def predict(self, batch: np.ndarray) -> List[str]:
        """Predict human-readable class labels for a batch of inputs.

        Args:
            batch: input array; its first axis is treated as the batch dimension

        Returns:
            One class label per numeric label produced by the estimator adapter, in the order the
            adapter returns them. An empty list if the batch has no samples.

        Raises:
            KeyError: if the adapter returns a numeric label that is not a key of ``classes``.
        """
        # Empty batches short-circuit so the adapter is never invoked with zero samples
        # (presumably some backends reject empty input — TODO confirm).
        if batch.shape[0] == 0:
            return []
        numeric_labels = self.__estimator_adapter.predict(batch)
        return [self.__classes[numeric_label] for numeric_label in numeric_labels]

    def __call__(self, batch: np.ndarray) -> List[str]:
        """Alias for :meth:`predict`, so a Classifier can be used as a plain callable."""
        return self.predict(batch)