31 lines
1.1 KiB
Python
31 lines
1.1 KiB
Python
from typing import List, Union, Tuple
|
|
|
|
import numpy as np
|
|
from PIL.Image import Image
|
|
from funcy import rcompose
|
|
|
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
|
from image_prediction.label_mapper.mapper import LabelMapper
|
|
|
|
|
|
class Classifier:
    """Abstraction layer over different estimator backends (e.g. keras or scikit-learn).

    For each backend to be used an EstimatorAdapter must be implemented. A batch is
    classified by passing it through the estimator adapter and then passing the adapter's
    output through the label mapper.
    """

    def __init__(self, estimator_adapter: EstimatorAdapter, label_mapper: LabelMapper):
        """Build the adapter -> label-mapper prediction pipeline.

        Args:
            estimator_adapter: adapter for a given estimator backend; it is called first,
                with the input batch.
            label_mapper: called on the estimator adapter's output to produce the labels.
        """
        self.__estimator_adapter = estimator_adapter
        self.__label_mapper = label_mapper
        # funcy.rcompose applies its arguments left to right, i.e. the pipe computes
        # label_mapper(estimator_adapter(batch)).
        self.__pipe = rcompose(self.__estimator_adapter, self.__label_mapper)

    @staticmethod
    def _is_empty(batch) -> bool:
        """Return True if ``batch`` contains no items.

        Array-like batches (anything exposing ``.shape``) are checked on their first axis,
        exactly as before; tuples of images (or other sized sequences) are checked with
        ``len`` so an empty tuple is treated the same way as an empty array.
        """
        shape = getattr(batch, "shape", None)
        if shape is not None:
            # Array-like input: the first axis is treated as the batch dimension.
            return shape[0] == 0
        return len(batch) == 0

    def predict(self, batch: Union[np.ndarray, Tuple[Image, ...]]) -> List[str]:
        """Classify a batch of inputs.

        Args:
            batch: an array whose first axis indexes the batch items, or a tuple of PIL
                images.

        Returns:
            The output of the label mapper collected into a list (presumably one label
            per batch item — depends on the adapter/mapper implementations). An empty
            batch returns ``[]`` without invoking the pipeline.
        """
        if self._is_empty(batch):
            # Short-circuit so the estimator backend is never invoked on an empty batch,
            # for arrays and tuples alike.
            return []

        return list(self.__pipe(batch))

    def __call__(self, batch: Union[np.ndarray, Tuple[Image, ...]]) -> List[str]:
        """Alias for :meth:`predict`, so a Classifier can be used as a plain callable."""
        return self.predict(batch)
|