21 lines
689 B
Python
21 lines
689 B
Python
from image_prediction.utils import get_logger
|
|
|
|
# Module-level logger obtained from image_prediction.utils.get_logger;
# PredictionModelHandle.__call__ writes its debug message through it.
logger = get_logger()
|
|
|
|
|
|
class PredictionModelHandle:
    """Prediction-only facade over a ModelHandle-like object.

    The wrapped handle must expose ``prep_images`` (a callable returning a
    ``(tensor, valid_mask)`` pair) and ``model.predict`` (a callable taking
    that tensor). Both callables are looked up once at construction time.
    """

    def __init__(self, model_handle):
        """Capture the preprocessing and prediction callables of *model_handle*."""
        # Bound once here, so later changes to the handle's attributes are
        # not picked up by this wrapper.
        self.__prep_images = model_handle.prep_images
        self.__predict = model_handle.model.predict

    def predict(self, *args, **kwargs):
        """Preprocess the inputs, run the model and mask out invalid entries.

        All arguments are forwarded unchanged to ``prep_images``.

        Returns:
            A list with one entry per (prediction, mask flag) pair: the
            prediction itself where the flag is truthy, ``None`` otherwise.
        """
        batch, is_valid = self.__prep_images(*args, **kwargs)
        raw_outputs = self.__predict(batch)

        results = []
        for output, keep in zip(raw_outputs, is_valid):
            # Positions flagged as invalid are reported as None.
            results.append(output if keep else None)
        return results

    def __call__(self, *args, **kwargs):
        """Log a debug message, then delegate to :meth:`predict`."""
        logger.debug("PredictionModelHandle.predict")
        outcome = self.predict(*args, **kwargs)
        return outcome
|