2022-03-25 14:46:04 +01:00

21 lines
652 B
Python

from typing import Mapping, List
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.utils import get_logger
# Module-level logger obtained from the project's get_logger helper;
# ServiceEstimator.predict uses it to warn when it receives an empty batch.
logger = get_logger()
class ServiceEstimator:
    """Translate an estimator's numeric predictions into class-name strings.

    Wraps an ``EstimatorAdapter`` and maps each numeric label it predicts
    to a name through the ``classes`` lookup table supplied at construction.
    """

    def __init__(self, estimator_adapter: EstimatorAdapter, classes: Mapping[int, str]):
        """Store the adapter and the label-to-name table.

        Args:
            estimator_adapter: Adapter whose ``predict`` yields numeric labels.
            classes: Lookup from numeric label to class name.
        """
        # Double-underscore names are name-mangled to
        # _ServiceEstimator__estimator_adapter / _ServiceEstimator__classes.
        self.__estimator_adapter = estimator_adapter
        self.__classes = classes

    def predict(self, batch) -> List[str]:
        """Return the predicted class name for each sample in ``batch``.

        Args:
            batch: Input passed unchanged to the adapter. It must expose
                ``.shape``, and ``shape[0]`` is treated as the sample count
                (presumably a numpy-like array; confirm against callers).

        Returns:
            One class name per label yielded by the adapter's ``predict``.
            An empty list, after logging a warning, when ``batch.shape[0]``
            is 0 (the adapter is not called in that case).

        Raises:
            Whatever ``classes[label]`` raises for an unknown label
            (``KeyError`` for a plain dict).
        """
        if batch.shape[0] != 0:
            numeric_labels = self.__estimator_adapter.predict(batch)
            label_names = self.__classes
            return [label_names[label] for label in numeric_labels]

        # Short-circuit: nothing to classify, so skip the adapter entirely.
        logger.warning("ServiceEstimator received empty batch")
        return []