diff --git a/config.yaml b/config.yaml
index 7569e6b..2a202b4 100644
--- a/config.yaml
+++ b/config.yaml
@@ -11,5 +11,6 @@ webserver:
   mode: $SERVER_MODE|production # webserver mode: {development, production}
 
 service:
-  logging_level: DEBUG
-  batch_size: $BATCH_SIZE|2 # Number of images in memory simultaneously per service instance
+  logging_level: $LOGGING_LEVEL_ROOT|DEBUG # Logging level for service logger
+  batch_size: $BATCH_SIZE|2 # Number of images in memory simultaneously
+  verbose: $VERBOSE|True # Service prints document processing progress to stdout
diff --git a/fb_detr/predictor.py b/fb_detr/predictor.py
index a8f89eb..e083dbf 100644
--- a/fb_detr/predictor.py
+++ b/fb_detr/predictor.py
@@ -139,6 +139,15 @@ class Predictor:
         return predictions
 
     def predict_pdf(self, pdf: bytes):
+        def progress(generator):
+            """Yield the items of *generator*, behind a tqdm bar when CONFIG.service.verbose is set."""
+            page_count = get_page_count(pdf)
+            # Round up, not down: assumes chunk_iterable emits a trailing partial batch -- TODO confirm.
+            batch_count = -(-page_count // CONFIG.service.batch_size)
+            yield from tqdm(
+                generator, total=batch_count, position=1, leave=True
+            ) if CONFIG.service.verbose else generator
+
         def predict_batch(batch_idx, batch):
             predictions = self.predict(batch)
             for p in predictions:
@@ -146,11 +155,8 @@ class Predictor:
 
             return predictions
 
-        page_count = get_page_count(pdf)
-        batch_count = int(page_count / CONFIG.service.batch_size)
-
         page_stream = stream_pages(pdf)
         page_batches = chunk_iterable(page_stream, CONFIG.service.batch_size)
-        predictions = list(chain(*starmap(predict_batch, tqdm(enumerate(page_batches), total=batch_count))))
+        predictions = list(chain(*starmap(predict_batch, progress(enumerate(page_batches)))))
 
         return predictions
diff --git a/src/serve.py b/src/serve.py
index c1c9deb..fa44ef9 100644
--- a/src/serve.py
+++ b/src/serve.py
@@ -1,16 +1,12 @@
 import argparse
-import json
 import logging
-from itertools import chain
 from typing import Callable
 
 from flask import Flask, request, jsonify
-from pdf2image import pdf2image
 from waitress import serve
 
 from fb_detr.config import CONFIG
 from fb_detr.utils.estimator import suppress_userwarnings, initialize_predictor
-from fb_detr.utils.stream import stream_pages, chunk_iterable
 
 
 def parse_args():