from image_prediction.utils import get_logger

logger = get_logger()


class PredictionModelHandle:
    """Simplifies usage of ModelHandle instances for prediction purposes.

    Wraps a model handle so that callers can go straight from raw inputs to
    per-item predictions: inputs are first run through the handle's
    ``prep_images`` step, the resulting batch is fed to ``model.predict``, and
    any item flagged invalid during preparation is reported as ``None``.
    """

    def __init__(self, model_handle):
        """Capture the preparation and prediction callables of *model_handle*.

        Args:
            model_handle: Object exposing a ``prep_images`` callable and a
                ``model`` attribute with a ``predict`` callable. Both are
                looked up once, here, so a handle missing either attribute
                fails at construction time rather than on first prediction.
        """
        self.__prep_images = model_handle.prep_images
        self.__predict = model_handle.model.predict

    def predict(self, *args, **kwargs):
        """Prepare the inputs, run the model, and mask out invalid items.

        All positional and keyword arguments are forwarded unchanged to the
        handle's ``prep_images``, which must return a pair
        ``(tensor, valid_mask)``.

        Returns:
            A list holding, for each ``(prediction, flag)`` pair, the model's
            prediction when ``flag`` is truthy and ``None`` otherwise.
        """
        batch, mask = self.__prep_images(*args, **kwargs)
        outputs = self.__predict(batch)

        # NOTE(review): pairing relies on ``outputs`` lining up one-to-one with
        # ``mask``; zip() silently stops at the shorter of the two, so a length
        # mismatch would drop trailing items instead of raising — confirm the
        # model returns one prediction per prepared input.
        results = []
        for output, is_valid in zip(outputs, mask):
            results.append(output if is_valid else None)
        return results

    def __call__(self, *args, **kwargs):
        """Log the invocation, then delegate to :meth:`predict`."""
        logger.debug("PredictionModelHandle.predict")
        return self.predict(*args, **kwargs)