integrate new pyinfra logic

This commit is contained in:
Julius Unverfehrt 2022-08-10 10:37:31 +02:00
parent 4692607834
commit 3225cefaa2
3 changed files with 49 additions and 41 deletions

View File

@ -63,12 +63,6 @@ class Pipeline:
join, # ... the streams by zipping join, # ... the streams by zipping
reformat, # ... the items reformat, # ... the items
) )
self.pipe_for_scanned_pdf = rcompose(
extract_from_scanned,
pairwise_apply(classify, represent, identity), # ... apply functions to the streams pairwise
join, # ... the streams by zipping
reformat, # ... the items
)
def __call__(self, pdf: bytes, page_range: range = None): def __call__(self, pdf: bytes, page_range: range = None):
yield from tqdm( yield from tqdm(
@ -77,10 +71,3 @@ class Pipeline:
unit=" images", unit=" images",
disable=not self.verbose, disable=not self.verbose,
) )
def extract_from_scanned(pdf: bytes, bbox_info_per_page: Iterable[dict]):
images = extract_images_per_page(pdf, bbox_info_per_page)
for page_images, page_info in zip(images, bbox_info_per_page):
metadata_per_image = page_info["bboxes"]
metadata_per_image =

View File

@ -4,8 +4,7 @@ from image_prediction.locations import BANNER_FILE
def show_banner(): def show_banner():
with open(BANNER_FILE) as f: banner = load_banner()
banner = "\n" + "".join(f.readlines()) + "\n"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.propagate = False logger.propagate = False
@ -19,3 +18,9 @@ def show_banner():
logger.addHandler(handler) logger.addHandler(handler)
logger.info(banner) logger.info(banner)
def load_banner():
with open(BANNER_FILE) as f:
banner = "\n" + "".join(f.readlines()) + "\n"
return banner

View File

@ -1,36 +1,52 @@
import gzip
import json
import logging import logging
from waitress import serve from image_prediction.config import Config
from image_prediction.locations import CONFIG_FILE
from image_prediction.config import CONFIG
from image_prediction.flask import make_prediction_server
from image_prediction.pipeline import load_pipeline from image_prediction.pipeline import load_pipeline
from image_prediction.utils import get_logger from image_prediction.utils.banner import show_banner, load_banner
from image_prediction.utils.banner import show_banner from pyinfra import config
from pyinfra.queue.queue_manager import QueueManager
from pyinfra.storage.storage import get_storage
PYINFRA_CONFIG = config.get_config()
IMAGE_CONFIG = Config(CONFIG_FILE)
logging.getLogger().addHandler(logging.StreamHandler())
logger = logging.getLogger("main")
logger.setLevel(PYINFRA_CONFIG.logging_level_root)
def process_request(request_message):
pipeline = load_pipeline(verbose=IMAGE_CONFIG.service.verbose, batch_size=IMAGE_CONFIG.service.batch_size)
target_file_extension = request_message["targetFileExtension"]
dossier_id = request_message["dossierId"]
file_id = request_message["fileId"]
storage = get_storage(PYINFRA_CONFIG)
object_bytes = storage.get_object(PYINFRA_CONFIG.storage_bucket, f"{dossier_id}/{file_id}.{target_file_extension}")
object_bytes = gzip.decompress(object_bytes)
classifications = list(pipeline(object_bytes))
result = {**request_message, "data": classifications}
response_file_extension = request_message["responseFileExtension"]
storage_bytes = gzip.compress(json.dumps(result).encode("utf-8"))
storage.put_object(PYINFRA_CONFIG.storage_bucket, f"{dossier_id}/{file_id}.{response_file_extension}", storage_bytes)
return {"dossierId": dossier_id, "fileId": file_id}
def main(): def main():
logger.info(load_banner())
def predict(pdf): queue_manager = QueueManager(PYINFRA_CONFIG)
# Keras service_estimator.predict stalls when service_estimator was loaded in different process; queue_manager.start_consuming(process_request)
# therefore, we re-load the model (part of the pipeline) every time we process a new document.
# https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python
logger.debug("Loading pipeline...")
pipeline = load_pipeline(verbose=CONFIG.service.verbose, batch_size=CONFIG.service.batch_size)
logger.debug("Running pipeline...")
return list(pipeline(pdf))
prediction_server = make_prediction_server(predict)
serve(prediction_server, host=CONFIG.webserver.host, port=CONFIG.webserver.port, _quiet=False)
if __name__ == "__main__": if __name__ == '__main__':
logging.basicConfig(level=CONFIG.service.logging_level)
logging.getLogger("PIL").setLevel(logging.ERROR)
logging.getLogger("h5py").setLevel(logging.ERROR)
logging.getLogger("pillow").setLevel(logging.ERROR)
logger = get_logger()
show_banner()
main() main()