batch size and progressbar message forwarding
This commit is contained in:
parent
2619831986
commit
1fcec06e91
@ -6,7 +6,7 @@ webserver:
|
|||||||
service:
|
service:
|
||||||
logging_level: INFO # Logging level for service logger
|
logging_level: INFO # Logging level for service logger
|
||||||
progressbar: True # Whether a progress bar over the pages of a document is displayed while processing
|
progressbar: True # Whether a progress bar over the pages of a document is displayed while processing
|
||||||
batch_size: $BATCH_SIZE|32 # Number of images in memory simultaneously
|
batch_size: $BATCH_SIZE|16 # Number of images in memory simultaneously
|
||||||
verbose: $VERBOSE|True # Service prints document processing progress to stdout
|
verbose: $VERBOSE|True # Service prints document processing progress to stdout
|
||||||
run_id: $RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the service_estimator from
|
run_id: $RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the service_estimator from
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
||||||
from funcy import chunks
|
from funcy import chunks, rpartial
|
||||||
|
|
||||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||||
from image_prediction.image_extractor.extractor import ImageExtractor
|
from image_prediction.image_extractor.extractor import ImageExtractor
|
||||||
@ -18,15 +18,16 @@ class ExtractorClassifier:
|
|||||||
self.classifier = image_classifier
|
self.classifier = image_classifier
|
||||||
self.extractor = image_extractor
|
self.extractor = image_extractor
|
||||||
|
|
||||||
def __process_batch(self, batch):
|
def __process_batch(self, batch, batch_size):
|
||||||
images, metadata = zip(*batch)
|
images, metadata = zip(*batch)
|
||||||
|
|
||||||
predictions = self.classifier(images)
|
predictions = self.classifier(images, batch_size)
|
||||||
responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))
|
responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
def __call__(self, obj, **kwargs) -> Iterable[dict]:
|
def __call__(self, obj, batch_size=16, **kwargs) -> Iterable[dict]:
|
||||||
|
|
||||||
image_metadata_pairs = self.extractor(obj, **kwargs)
|
image_metadata_pairs = self.extractor(obj, **kwargs)
|
||||||
batches = chunks(16, image_metadata_pairs)
|
batches = chunks(batch_size, image_metadata_pairs)
|
||||||
predictions = chain.from_iterable(map(self.__process_batch, batches))
|
predictions = chain.from_iterable(map(rpartial(self.__process_batch, batch_size), batches))
|
||||||
return predictions
|
yield from predictions
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from image_prediction.utils.generic import lift
|
|||||||
|
|
||||||
|
|
||||||
class ParsablePDFImageExtractor(ImageExtractor):
|
class ParsablePDFImageExtractor(ImageExtractor):
|
||||||
def __init__(self, verbose=False, tolerance=0):
|
def __init__(self, verbose=False, tolerance=0, progress_message=None):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -28,6 +28,7 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
|||||||
self.doc: fitz.fitz.Document = None
|
self.doc: fitz.fitz.Document = None
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.tolerance = tolerance
|
self.tolerance = tolerance
|
||||||
|
self.progress_message = progress_message
|
||||||
|
|
||||||
def extract(self, pdf: bytes, page_range: range = None):
|
def extract(self, pdf: bytes, page_range: range = None):
|
||||||
self.doc = fitz.Document(stream=pdf)
|
self.doc = fitz.Document(stream=pdf)
|
||||||
@ -37,7 +38,12 @@ class ParsablePDFImageExtractor(ImageExtractor):
|
|||||||
image_metadata_pairs = chain.from_iterable(
|
image_metadata_pairs = chain.from_iterable(
|
||||||
map(
|
map(
|
||||||
self.__process_images_on_page,
|
self.__process_images_on_page,
|
||||||
tqdm(pages, desc="Extracting", disable=not self.verbose, total=len(page_range) if page_range else None),
|
tqdm(
|
||||||
|
pages,
|
||||||
|
desc=self.progress_message,
|
||||||
|
disable=not self.verbose,
|
||||||
|
total=len(page_range) if page_range else None,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -13,7 +13,7 @@ def load_pipeline(**kwargs):
|
|||||||
model_loader = get_mlflow_model_loader(MLRUNS_DIR)
|
model_loader = get_mlflow_model_loader(MLRUNS_DIR)
|
||||||
model_identifier = CONFIG.service.run_id
|
model_identifier = CONFIG.service.run_id
|
||||||
|
|
||||||
pipeline = Pipeline(model_loader, model_identifier, **kwargs)
|
pipeline = Pipeline(model_loader, model_identifier, progress_message="Processing document", **kwargs)
|
||||||
|
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|
||||||
|
|||||||
@ -35,7 +35,7 @@ def process_pdf(pipeline, pdf_path, page_range=None):
|
|||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
pipeline = load_pipeline(verbose=False, tolerance=3)
|
pipeline = load_pipeline(verbose=True, tolerance=3)
|
||||||
|
|
||||||
if os.path.isfile(args.input):
|
if os.path.isfile(args.input):
|
||||||
pdf_paths = [args.input]
|
pdf_paths = [args.input]
|
||||||
|
|||||||
@ -17,7 +17,7 @@ def main():
|
|||||||
# therefore, we re-load the model (part of the pipeline) every time we process a new document.
|
# therefore, we re-load the model (part of the pipeline) every time we process a new document.
|
||||||
# https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python
|
# https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python
|
||||||
logger.debug("Loading pipeline...")
|
logger.debug("Loading pipeline...")
|
||||||
pipeline = load_pipeline(verbose=CONFIG.service.verbose)
|
pipeline = load_pipeline(verbose=CONFIG.service.verbose, batch_size=CONFIG.service.batch_size)
|
||||||
logger.debug("Running pipeline...")
|
logger.debug("Running pipeline...")
|
||||||
return list(pipeline(pdf))
|
return list(pipeline(pdf))
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user