From a192e05be2621d5e6ebf71b6f30d997801e709f8 Mon Sep 17 00:00:00 2001
From: Julius Unverfehrt
Date: Fri, 6 Sep 2024 15:51:14 +0200
Subject: [PATCH] feat: parameterize image stitching tolerance

Also sets image stitching tolerance default to one (pixel) and adds
informative log of which settings are loaded when initializing the
image classification pipeline.
---
 config/settings.toml             | 1 +
 scripts/run_pipeline.py          | 2 +-
 src/image_prediction/pipeline.py | 2 ++
 src/serve.py                     | 2 +-
 4 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/config/settings.toml b/config/settings.toml
index 0950aed..40d76a2 100644
--- a/config/settings.toml
+++ b/config/settings.toml
@@ -5,6 +5,7 @@ level = "INFO"
 # Print document processing progress to stdout
 verbose = false
 batch_size = 16
+image_stitching_tolerance = 1 # in pixels
 mlflow_run_id = "fabfb1f192c745369b88cab34471aba7"
 
 # These variables control filters that are applied to either images, image metadata or service_estimator predictions.
diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py
index f12fb1d..613b49f 100644
--- a/scripts/run_pipeline.py
+++ b/scripts/run_pipeline.py
@@ -36,7 +36,7 @@ def process_pdf(pipeline, pdf_path, page_range=None):
 
 
 def main(args):
-    pipeline = load_pipeline(verbose=True, batch_size=CONFIG.service.batch_size)
+    pipeline = load_pipeline(verbose=CONFIG.service.verbose, batch_size=CONFIG.service.batch_size, tolerance=CONFIG.service.image_stitching_tolerance)
 
     if os.path.isfile(args.input):
         pdf_paths = [args.input]
diff --git a/src/image_prediction/pipeline.py b/src/image_prediction/pipeline.py
index 2bff17a..4a8a62d 100644
--- a/src/image_prediction/pipeline.py
+++ b/src/image_prediction/pipeline.py
@@ -3,6 +3,7 @@
 from functools import lru_cache, partial
 from itertools import chain, tee
 from funcy import rcompose, first, compose, second, chunks, identity, rpartial
+from kn_utils.logging import logger
 from tqdm import tqdm
 
 from image_prediction.config import CONFIG
@@ -21,6 +22,7 @@
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 
 @lru_cache(maxsize=None)
 def load_pipeline(**kwargs):
+    logger.info(f"Loading pipeline with kwargs: {kwargs}")
     model_loader = get_mlflow_model_loader(MLRUNS_DIR)
     model_identifier = CONFIG.service.mlflow_run_id
diff --git a/src/serve.py b/src/serve.py
index 848175c..506b818 100644
--- a/src/serve.py
+++ b/src/serve.py
@@ -18,7 +18,7 @@ logger.reconfigure(sink=stdout, level=CONFIG.logging.level)
 
 
 # FIXME: Find more fine-grained solution or if the problem occurs persistently for python services,
 @wrap_in_process
 def process_data(data: bytes, _message: dict) -> list:
-    pipeline = load_pipeline(verbose=CONFIG.service.verbose, batch_size=CONFIG.service.batch_size)
+    pipeline = load_pipeline(verbose=CONFIG.service.verbose, batch_size=CONFIG.service.batch_size, tolerance=CONFIG.service.image_stitching_tolerance)
     return list(pipeline(data))