Julius Unverfehrt 265c61df1a RED-5107 add handling for 'broken' images: broken image parts are
replaced by blank images in the stitching process and completely broken
images are also replaced by blank images which are passed through and
are classified as 'other' with all_pased == False. This should be
changed in the future by introducing a new key to the response,
indicating that the image is not valid.
2022-08-30 12:30:23 +02:00

21 lines
699 B
Python

from image_prediction.utils import get_logger
# Module-level logger from the project's logging helper; PredictionModelHandle.__call__ logs through it.
logger = get_logger()
class PredictionModelHandle:
    """Thin adapter that runs a ModelHandle's preprocessing and model in one call.

    Only two things are taken from the wrapped handle: its ``prep_images``
    callable and its ``model.predict`` callable. Both are looked up once, when
    the adapter is constructed, and reused for every prediction.
    """

    def __init__(self, model_handle):
        # Bind the two callables up front; later calls go straight to them.
        self.__prep_images = model_handle.prep_images
        self.__predict = model_handle.model.predict

    def predict(self, *args, **kwargs):
        """Preprocess the inputs, run the model, and blank out invalid entries.

        Every positional and keyword argument is handed to ``prep_images``. The
        keyword arguments are also forwarded to ``model.predict``.

        ``prep_images`` must return a pair ``(batch, flags)``. ``batch`` is fed
        to the model. ``flags`` holds one validity flag per model output.
        Presumably a falsy flag marks an input that could not be read properly;
        confirm against ``prep_images``.

        Returns:
            A list holding the model output where the matching flag is truthy
            and ``None`` where it is falsy. The list stops at the shorter of the
            outputs and the flags (plain ``zip`` pairing).
        """
        batch, flags = self.__prep_images(*args, **kwargs)
        outputs = self.__predict(batch, **kwargs)
        return [output if flag else None for output, flag in zip(outputs, flags)]

    def __call__(self, *args, **kwargs):
        """Log the invocation at debug level, then delegate to ``predict``."""
        logger.debug("PredictionModelHandle.predict")
        return self.predict(*args, **kwargs)