from typing import List, Tuple, Union

import numpy as np
from PIL.Image import Image
from funcy import rcompose

from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.label_mapper.mapper import LabelMapper
from image_prediction.utils import get_logger

logger = get_logger()


class Classifier:
    """Run a batch through an estimator adapter and map the raw output to labels."""

    def __init__(self, estimator_adapter: EstimatorAdapter, label_mapper: LabelMapper):
        """Abstraction layer over different estimator backends (e.g. keras or scikit-learn).

        For each backend to be used an EstimatorAdapter must be implemented.

        Args:
            estimator_adapter: adapter for a given estimator backend; called with the batch.
            label_mapper: called with the output of ``estimator_adapter``; its result is
                what :meth:`predict` returns.
        """
        self.__estimator_adapter = estimator_adapter
        self.__label_mapper = label_mapper
        # rcompose applies left-to-right: label_mapper(estimator_adapter(batch)).
        self.__pipe = rcompose(self.__estimator_adapter, self.__label_mapper)

    def predict(self, batch: Union[np.ndarray, Tuple[Image, ...]]) -> List[str]:
        """Predict labels for every item in ``batch``.

        Args:
            batch: either an array whose first axis indexes the batch items, or a
                tuple of PIL images. NOTE(review): which input kinds are actually
                accepted depends on the concrete EstimatorAdapter — confirm.

        Returns:
            The label mapper's output for the batch, or ``[]`` when the batch is
            empty (the estimator is not invoked in that case).
        """
        # Short-circuit empty batches so the estimator backend never sees zero items.
        # Arrays are checked along the first (batch) axis; tuples/lists by length.
        # Anything else (e.g. a generator) is passed through unchanged.
        if isinstance(batch, np.ndarray):
            if batch.shape[0] == 0:
                return []
        elif isinstance(batch, (tuple, list)) and len(batch) == 0:
            return []
        return self.__pipe(batch)

    def __call__(self, batch: Union[np.ndarray, Tuple[Image, ...]]) -> List[str]:
        """Alias for :meth:`predict` that also emits a debug log line."""
        logger.debug("Classifier.predict")
        return self.predict(batch)