diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index f9383a1..fa53b5d 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -1,5 +1,5 @@ import os -from functools import partial +from functools import partial, lru_cache from itertools import chain, tee from typing import Iterable @@ -20,6 +20,7 @@ from image_prediction.utils.generic import lift, starlift os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +@lru_cache() def load_pipeline(**kwargs): model_loader = get_mlflow_model_loader(MLRUNS_DIR) model_identifier = CONFIG.service.mlflow_run_id