added type hint
This commit is contained in:
parent
03f269c2d7
commit
2e36a9d46d
@ -1,5 +1,7 @@
|
|||||||
from typing import Mapping, List
|
from typing import Mapping, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||||
from image_prediction.utils import get_logger
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
@ -11,10 +13,10 @@ class ServiceEstimator:
|
|||||||
self.__estimator_adapter = estimator_adapter
|
self.__estimator_adapter = estimator_adapter
|
||||||
self.__classes = classes
|
self.__classes = classes
|
||||||
|
|
||||||
def predict(self, batch) -> List[str]:
|
def predict(self, batch: np.array) -> List[str]:
|
||||||
|
|
||||||
if batch.shape[0] == 0:
|
if batch.shape[0] == 0:
|
||||||
logger.warning("ServiceEstimator received empty batch")
|
logger.warning("ServiceEstimator received empty batch.")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]
|
return [self.__classes[numeric_label] for numeric_label in self.__estimator_adapter.predict(batch)]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user