Forward batch size and progress bar message through the pipeline

This commit is contained in:
Matthias Bisping 2022-04-19 15:09:14 +02:00
parent 2619831986
commit 1fcec06e91
6 changed files with 20 additions and 13 deletions

View File

@ -6,7 +6,7 @@ webserver:
service:
logging_level: INFO # Logging level for service logger
progressbar: True # Whether a progress bar over the pages of a document is displayed while processing
batch_size: $BATCH_SIZE|32 # Number of images in memory simultaneously
batch_size: $BATCH_SIZE|16 # Number of images in memory simultaneously
verbose: $VERBOSE|True # Service prints document processing progress to stdout
run_id: $RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the service_estimator from

View File

@ -1,7 +1,7 @@
from itertools import chain
from typing import Iterable
from funcy import chunks
from funcy import chunks, rpartial
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.image_extractor.extractor import ImageExtractor
@ -18,15 +18,16 @@ class ExtractorClassifier:
self.classifier = image_classifier
self.extractor = image_extractor
def __process_batch(self, batch):
def __process_batch(self, batch, batch_size):
images, metadata = zip(*batch)
predictions = self.classifier(images)
predictions = self.classifier(images, batch_size)
responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))
return responses
def __call__(self, obj, **kwargs) -> Iterable[dict]:
def __call__(self, obj, batch_size=16, **kwargs) -> Iterable[dict]:
image_metadata_pairs = self.extractor(obj, **kwargs)
batches = chunks(16, image_metadata_pairs)
predictions = chain.from_iterable(map(self.__process_batch, batches))
return predictions
batches = chunks(batch_size, image_metadata_pairs)
predictions = chain.from_iterable(map(rpartial(self.__process_batch, batch_size), batches))
yield from predictions

View File

@ -18,7 +18,7 @@ from image_prediction.utils.generic import lift
class ParsablePDFImageExtractor(ImageExtractor):
def __init__(self, verbose=False, tolerance=0):
def __init__(self, verbose=False, tolerance=0, progress_message=None):
"""
Args:
@ -28,6 +28,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
self.doc: fitz.fitz.Document = None
self.verbose = verbose
self.tolerance = tolerance
self.progress_message = progress_message
def extract(self, pdf: bytes, page_range: range = None):
self.doc = fitz.Document(stream=pdf)
@ -37,7 +38,12 @@ class ParsablePDFImageExtractor(ImageExtractor):
image_metadata_pairs = chain.from_iterable(
map(
self.__process_images_on_page,
tqdm(pages, desc="Extracting", disable=not self.verbose, total=len(page_range) if page_range else None),
tqdm(
pages,
desc=self.progress_message,
disable=not self.verbose,
total=len(page_range) if page_range else None,
),
)
)

View File

@ -13,7 +13,7 @@ def load_pipeline(**kwargs):
model_loader = get_mlflow_model_loader(MLRUNS_DIR)
model_identifier = CONFIG.service.run_id
pipeline = Pipeline(model_loader, model_identifier, **kwargs)
pipeline = Pipeline(model_loader, model_identifier, progress_message="Processing document", **kwargs)
return pipeline

View File

@ -35,7 +35,7 @@ def process_pdf(pipeline, pdf_path, page_range=None):
def main(args):
pipeline = load_pipeline(verbose=False, tolerance=3)
pipeline = load_pipeline(verbose=True, tolerance=3)
if os.path.isfile(args.input):
pdf_paths = [args.input]

View File

@ -17,7 +17,7 @@ def main():
# therefore, we re-load the model (part of the pipeline) every time we process a new document.
# https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python
logger.debug("Loading pipeline...")
pipeline = load_pipeline(verbose=CONFIG.service.verbose)
pipeline = load_pipeline(verbose=CONFIG.service.verbose, batch_size=CONFIG.service.batch_size)
logger.debug("Running pipeline...")
return list(pipeline(pdf))