2022-08-30 13:20:04 +02:00

21 lines
689 B
Python

from image_prediction.utils import get_logger
# Module-level logger; PredictionModelHandle.__call__ emits its debug message through it.
logger = get_logger()
class PredictionModelHandle:
    """Callable wrapper that reduces a model handle to a single prediction step.

    The wrapped handle must expose ``prep_images`` (a callable returning a
    ``(tensor, valid_mask)`` pair) and ``model.predict`` (a callable taking
    that tensor). Instances can be used via :meth:`predict` or called
    directly.
    """

    def __init__(self, model_handle):
        """Capture the preprocessing and prediction callables of *model_handle*.

        Both callables are looked up once, here, and stored on the instance;
        rebinding ``model_handle.prep_images`` or ``model_handle.model``
        afterwards does not affect this object.
        """
        self.__prep_images = model_handle.prep_images
        self.__predict = model_handle.model.predict

    def predict(self, *args, **kwargs):
        """Preprocess the inputs, run the model and blank out invalid entries.

        All positional and keyword arguments are forwarded unchanged to the
        handle's ``prep_images``.

        Returns:
            A list pairing each prediction with its mask entry in order: the
            prediction itself where the mask entry is truthy, ``None``
            otherwise. Pairing stops at the shorter of the two sequences.
        """
        prepared = self.__prep_images(*args, **kwargs)
        batch, keep_flags = prepared
        raw_outputs = self.__predict(batch)

        results = []
        for output, keep in zip(raw_outputs, keep_flags):
            # Invalid inputs keep their slot in the result but carry no prediction.
            results.append(output if keep else None)
        return results

    def __call__(self, *args, **kwargs):
        """Log a debug message, then delegate to :meth:`predict`."""
        logger.debug("PredictionModelHandle.predict")
        outcome = self.predict(*args, **kwargs)
        return outcome