batch size and progressbar message forwarding
This commit is contained in:
parent
2619831986
commit
1fcec06e91
@ -6,7 +6,7 @@ webserver:
|
||||
service:
|
||||
logging_level: INFO # Logging level for service logger
|
||||
progressbar: True # Whether a progress bar over the pages of a document is displayed while processing
|
||||
batch_size: $BATCH_SIZE|32 # Number of images in memory simultaneously
|
||||
batch_size: $BATCH_SIZE|16 # Number of images in memory simultaneously
|
||||
verbose: $VERBOSE|True # Service prints document processing progress to stdout
|
||||
run_id: $RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the service_estimator from
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from itertools import chain
|
||||
from typing import Iterable
|
||||
|
||||
from funcy import chunks
|
||||
from funcy import chunks, rpartial
|
||||
|
||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||
from image_prediction.image_extractor.extractor import ImageExtractor
|
||||
@ -18,15 +18,16 @@ class ExtractorClassifier:
|
||||
self.classifier = image_classifier
|
||||
self.extractor = image_extractor
|
||||
|
||||
def __process_batch(self, batch):
|
||||
def __process_batch(self, batch, batch_size):
|
||||
images, metadata = zip(*batch)
|
||||
|
||||
predictions = self.classifier(images)
|
||||
predictions = self.classifier(images, batch_size)
|
||||
responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))
|
||||
return responses
|
||||
|
||||
def __call__(self, obj, **kwargs) -> Iterable[dict]:
|
||||
def __call__(self, obj, batch_size=16, **kwargs) -> Iterable[dict]:
|
||||
|
||||
image_metadata_pairs = self.extractor(obj, **kwargs)
|
||||
batches = chunks(16, image_metadata_pairs)
|
||||
predictions = chain.from_iterable(map(self.__process_batch, batches))
|
||||
return predictions
|
||||
batches = chunks(batch_size, image_metadata_pairs)
|
||||
predictions = chain.from_iterable(map(rpartial(self.__process_batch, batch_size), batches))
|
||||
yield from predictions
|
||||
|
||||
@ -18,7 +18,7 @@ from image_prediction.utils.generic import lift
|
||||
|
||||
|
||||
class ParsablePDFImageExtractor(ImageExtractor):
|
||||
def __init__(self, verbose=False, tolerance=0):
|
||||
def __init__(self, verbose=False, tolerance=0, progress_message=None):
|
||||
"""
|
||||
|
||||
Args:
|
||||
@ -28,6 +28,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
self.doc: fitz.fitz.Document = None
|
||||
self.verbose = verbose
|
||||
self.tolerance = tolerance
|
||||
self.progress_message = progress_message
|
||||
|
||||
def extract(self, pdf: bytes, page_range: range = None):
|
||||
self.doc = fitz.Document(stream=pdf)
|
||||
@ -37,7 +38,12 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
||||
image_metadata_pairs = chain.from_iterable(
|
||||
map(
|
||||
self.__process_images_on_page,
|
||||
tqdm(pages, desc="Extracting", disable=not self.verbose, total=len(page_range) if page_range else None),
|
||||
tqdm(
|
||||
pages,
|
||||
desc=self.progress_message,
|
||||
disable=not self.verbose,
|
||||
total=len(page_range) if page_range else None,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ def load_pipeline(**kwargs):
|
||||
model_loader = get_mlflow_model_loader(MLRUNS_DIR)
|
||||
model_identifier = CONFIG.service.run_id
|
||||
|
||||
pipeline = Pipeline(model_loader, model_identifier, **kwargs)
|
||||
pipeline = Pipeline(model_loader, model_identifier, progress_message="Processing document", **kwargs)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
@ -35,7 +35,7 @@ def process_pdf(pipeline, pdf_path, page_range=None):
|
||||
|
||||
|
||||
def main(args):
|
||||
pipeline = load_pipeline(verbose=False, tolerance=3)
|
||||
pipeline = load_pipeline(verbose=True, tolerance=3)
|
||||
|
||||
if os.path.isfile(args.input):
|
||||
pdf_paths = [args.input]
|
||||
|
||||
@ -17,7 +17,7 @@ def main():
|
||||
# therefore, we re-load the model (part of the pipeline) every time we process a new document.
|
||||
# https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python
|
||||
logger.debug("Loading pipeline...")
|
||||
pipeline = load_pipeline(verbose=CONFIG.service.verbose)
|
||||
pipeline = load_pipeline(verbose=CONFIG.service.verbose, batch_size=CONFIG.service.batch_size)
|
||||
logger.debug("Running pipeline...")
|
||||
return list(pipeline(pdf))
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user