From 1fcec06e91aa53cd20cf54a41b18d6bae18480fe Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Tue, 19 Apr 2022 15:09:14 +0200 Subject: [PATCH] batch size and progressbar message forwarding --- config.yaml | 2 +- .../extractor_classifier/extractor_classifier.py | 15 ++++++++------- .../image_extractor/extractors/parsable.py | 10 ++++++++-- image_prediction/pipeline.py | 2 +- scripts/run_pipeline.py | 2 +- src/serve.py | 2 +- 6 files changed, 20 insertions(+), 13 deletions(-) diff --git a/config.yaml b/config.yaml index ae589de..ab36d34 100644 --- a/config.yaml +++ b/config.yaml @@ -6,7 +6,7 @@ webserver: service: logging_level: INFO # Logging level for service logger progressbar: True # Whether a progress bar over the pages of a document is displayed while processing - batch_size: $BATCH_SIZE|32 # Number of images in memory simultaneously + batch_size: $BATCH_SIZE|16 # Number of images in memory simultaneously verbose: $VERBOSE|True # Service prints document processing progress to stdout run_id: $RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the service_estimator from diff --git a/image_prediction/extractor_classifier/extractor_classifier.py b/image_prediction/extractor_classifier/extractor_classifier.py index ba97b0e..1b1b5b1 100644 --- a/image_prediction/extractor_classifier/extractor_classifier.py +++ b/image_prediction/extractor_classifier/extractor_classifier.py @@ -1,7 +1,7 @@ from itertools import chain from typing import Iterable -from funcy import chunks +from funcy import chunks, rpartial from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.image_extractor.extractor import ImageExtractor @@ -18,15 +18,16 @@ class ExtractorClassifier: self.classifier = image_classifier self.extractor = image_extractor - def __process_batch(self, batch): + def __process_batch(self, batch, batch_size): images, metadata = zip(*batch) - predictions = self.classifier(images) + predictions = self.classifier(images, batch_size) responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata)) return responses - def __call__(self, obj, **kwargs) -> Iterable[dict]: + def __call__(self, obj, batch_size=16, **kwargs) -> Iterable[dict]: + image_metadata_pairs = self.extractor(obj, **kwargs) - batches = chunks(16, image_metadata_pairs) - predictions = chain.from_iterable(map(self.__process_batch, batches)) - return predictions + batches = chunks(batch_size, image_metadata_pairs) + predictions = chain.from_iterable(map(rpartial(self.__process_batch, batch_size), batches)) + yield from predictions diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py index 9ca98c9..dae6469 100644 --- a/image_prediction/image_extractor/extractors/parsable.py +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -18,7 +18,7 @@ from image_prediction.utils.generic import lift class ParsablePDFImageExtractor(ImageExtractor): - def __init__(self, verbose=False, tolerance=0): + def __init__(self, verbose=False, tolerance=0, progress_message=None): """ Args: @@ -28,6 +28,7 @@ class ParsablePDFImageExtractor(ImageExtractor): self.doc: fitz.fitz.Document = None self.verbose = verbose self.tolerance = tolerance + self.progress_message = progress_message def extract(self, pdf: bytes, page_range: range = None): self.doc = fitz.Document(stream=pdf) @@ -37,7 +38,12 @@ class ParsablePDFImageExtractor(ImageExtractor): image_metadata_pairs = chain.from_iterable( map( self.__process_images_on_page, - tqdm(pages, desc="Extracting", disable=not self.verbose, total=len(page_range) if page_range else None), + tqdm( + pages, + desc=self.progress_message, + disable=not self.verbose, + total=len(page_range) if page_range else None, + ), ) ) diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index 35ce8af..fad4145 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -13,7 +13,7 @@ def load_pipeline(**kwargs): model_loader = get_mlflow_model_loader(MLRUNS_DIR) model_identifier = CONFIG.service.run_id - pipeline = Pipeline(model_loader, model_identifier, **kwargs) + pipeline = Pipeline(model_loader, model_identifier, progress_message="Processing document", **kwargs) return pipeline diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index d7bf253..c2b4bb0 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -35,7 +35,7 @@ def process_pdf(pipeline, pdf_path, page_range=None): def main(args): - pipeline = load_pipeline(verbose=False, tolerance=3) + pipeline = load_pipeline(verbose=True, tolerance=3) if os.path.isfile(args.input): pdf_paths = [args.input] diff --git a/src/serve.py b/src/serve.py index b749a85..aad2a30 100644 --- a/src/serve.py +++ b/src/serve.py @@ -17,7 +17,7 @@ def main(): # therefore, we re-load the model (part of the pipeline) every time we process a new document. # https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python logger.debug("Loading pipeline...") - pipeline = load_pipeline(verbose=CONFIG.service.verbose) + pipeline = load_pipeline(verbose=CONFIG.service.verbose, batch_size=CONFIG.service.batch_size) logger.debug("Running pipeline...") return list(pipeline(pdf))