diff --git a/image_prediction/service_estimator/service_estimator.py b/image_prediction/service_estimator/service_estimator.py index 623ca01..1704fb1 100644 --- a/image_prediction/service_estimator/service_estimator.py +++ b/image_prediction/service_estimator/service_estimator.py @@ -1,5 +1,7 @@ from typing import Mapping, List +import numpy as np + from image_prediction.estimator.adapter.adapter import EstimatorAdapter from image_prediction.utils import get_logger @@ -11,10 +13,10 @@ class ServiceEstimator: self.__estimator_adapter = estimator_adapter self.__classes = classes - def predict(self, batch) -> List[str]: + def predict(self, batch: np.array) -> List[str]: if batch.shape[0] == 0: - logger.warning("ServiceEstimator received empty batch") + logger.warning("ServiceEstimator received empty batch.") return [] return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]