25 lines
697 B
Python
25 lines
697 B
Python
from typing import Mapping, List
|
|
|
|
import numpy as np
|
|
|
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
|
from image_prediction.utils import get_logger
|
|
|
|
# Module-level logger obtained from image_prediction.utils.get_logger.
logger = get_logger()
|
|
|
|
|
|
class Estimator:
    """Translate an estimator adapter's numeric predictions into class names.

    The adapter performs the actual inference; this wrapper only maps each
    numeric label the adapter returns to a name via the ``classes`` mapping.
    """

    # The adapter type is given as a string (forward reference) so the project
    # type is not resolved when the signature is evaluated.
    def __init__(self, estimator_adapter: "EstimatorAdapter", classes: Mapping[int, str]):
        """Store the adapter and the label-to-name mapping.

        Args:
            estimator_adapter: Object exposing ``predict(batch)`` that returns
                an iterable of numeric labels.
            classes: Lookup from numeric label to class name.
        """
        self.__estimator_adapter = estimator_adapter
        self.__classes = classes

    def predict(self, batch: np.ndarray) -> List[str]:
        """Return one class name per label the adapter predicts for ``batch``.

        A batch of size 0 along its first axis short-circuits to ``[]``
        without calling the adapter.

        Args:
            batch: Input array handed to the adapter unchanged. Its first axis
                is presumably the batch dimension (TODO confirm with the
                adapter implementation).

        Returns:
            Class names looked up in ``classes``, in the order the adapter
            returned their numeric labels.

        Raises:
            KeyError: If the adapter returns a label absent from ``classes``
                (assuming ``classes`` follows the standard Mapping contract).
        """
        if batch.shape[0] == 0:
            return []
        numeric_labels = self.__estimator_adapter.predict(batch)
        return [self.__classes[label] for label in numeric_labels]

    def __call__(self, batch: np.ndarray) -> List[str]:
        """Alias for :meth:`predict` so the estimator can be called directly."""
        return self.predict(batch)
|