integrate new pyinfra logic
This commit is contained in:
parent
4692607834
commit
3225cefaa2
@ -63,12 +63,6 @@ class Pipeline:
|
||||
join, # ... the streams by zipping
|
||||
reformat, # ... the items
|
||||
)
|
||||
self.pipe_for_scanned_pdf = rcompose(
|
||||
extract_from_scanned,
|
||||
pairwise_apply(classify, represent, identity), # ... apply functions to the streams pairwise
|
||||
join, # ... the streams by zipping
|
||||
reformat, # ... the items
|
||||
)
|
||||
|
||||
def __call__(self, pdf: bytes, page_range: range = None):
|
||||
yield from tqdm(
|
||||
@ -77,10 +71,3 @@ class Pipeline:
|
||||
unit=" images",
|
||||
disable=not self.verbose,
|
||||
)
|
||||
|
||||
|
||||
def extract_from_scanned(pdf: bytes, bbox_info_per_page: Iterable[dict]):
|
||||
images = extract_images_per_page(pdf, bbox_info_per_page)
|
||||
for page_images, page_info in zip(images, bbox_info_per_page):
|
||||
metadata_per_image = page_info["bboxes"]
|
||||
metadata_per_image =
|
||||
@ -4,8 +4,7 @@ from image_prediction.locations import BANNER_FILE
|
||||
|
||||
|
||||
def show_banner():
|
||||
with open(BANNER_FILE) as f:
|
||||
banner = "\n" + "".join(f.readlines()) + "\n"
|
||||
banner = load_banner()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.propagate = False
|
||||
@ -19,3 +18,9 @@ def show_banner():
|
||||
logger.addHandler(handler)
|
||||
|
||||
logger.info(banner)
|
||||
|
||||
|
||||
def load_banner():
    """Return the contents of ``BANNER_FILE`` wrapped in newlines.

    The file is read as text in one go; a single newline is prepended and
    another one is appended to whatever the file contains.
    """
    with open(BANNER_FILE) as banner_file:
        contents = banner_file.read()
    return f"\n{contents}\n"
|
||||
|
||||
68
src/serve.py
68
src/serve.py
@ -1,36 +1,52 @@
|
||||
import gzip
|
||||
import json
|
||||
import logging
|
||||
|
||||
from waitress import serve
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.flask import make_prediction_server
|
||||
from image_prediction.config import Config
|
||||
from image_prediction.locations import CONFIG_FILE
|
||||
from image_prediction.pipeline import load_pipeline
|
||||
from image_prediction.utils import get_logger
|
||||
from image_prediction.utils.banner import show_banner
|
||||
from image_prediction.utils.banner import show_banner, load_banner
|
||||
from pyinfra import config
|
||||
from pyinfra.queue.queue_manager import QueueManager
|
||||
from pyinfra.storage.storage import get_storage
|
||||
|
||||
PYINFRA_CONFIG = config.get_config()
|
||||
IMAGE_CONFIG = Config(CONFIG_FILE)
|
||||
|
||||
logging.getLogger().addHandler(logging.StreamHandler())
|
||||
logger = logging.getLogger("main")
|
||||
logger.setLevel(PYINFRA_CONFIG.logging_level_root)
|
||||
|
||||
|
||||
def process_request(request_message):
    """Run the image pipeline on one stored document and upload the outcome.

    The document is fetched from ``PYINFRA_CONFIG.storage_bucket`` under
    ``{dossierId}/{fileId}.{targetFileExtension}``, gunzipped and handed to the
    pipeline. The fully materialised pipeline output is attached to a copy of
    the request under the ``"data"`` key, serialised as gzip-compressed UTF-8
    JSON and stored under ``{dossierId}/{fileId}.{responseFileExtension}``.

    Args:
        request_message: Mapping providing the keys ``targetFileExtension``,
            ``dossierId``, ``fileId`` and ``responseFileExtension``.

    Returns:
        A dict holding the ``dossierId`` and ``fileId`` of the processed file.
    """
    # A new pipeline is loaded on every call.
    # NOTE(review): presumably deliberate (model re-load per document) - confirm.
    pipeline = load_pipeline(verbose=IMAGE_CONFIG.service.verbose, batch_size=IMAGE_CONFIG.service.batch_size)

    source_extension = request_message["targetFileExtension"]
    dossier_id = request_message["dossierId"]
    file_id = request_message["fileId"]

    storage = get_storage(PYINFRA_CONFIG)
    bucket = PYINFRA_CONFIG.storage_bucket
    # Input and output objects share this prefix and differ only in extension.
    object_stem = f"{dossier_id}/{file_id}"

    compressed_document = storage.get_object(bucket, f"{object_stem}.{source_extension}")
    document = gzip.decompress(compressed_document)

    # The pipeline output is consumed completely before anything is uploaded.
    predictions = list(pipeline(document))

    # The response echoes the request and carries the predictions under "data".
    response = {**request_message}
    response["data"] = predictions

    result_extension = request_message["responseFileExtension"]
    payload = gzip.compress(json.dumps(response).encode("utf-8"))
    storage.put_object(bucket, f"{object_stem}.{result_extension}", payload)

    return {"dossierId": dossier_id, "fileId": file_id}
|
||||
|
||||
|
||||
def main():
|
||||
logger.info(load_banner())
|
||||
|
||||
def predict(pdf):
|
||||
# Keras service_estimator.predict stalls when service_estimator was loaded in different process;
|
||||
# therefore, we re-load the model (part of the pipeline) every time we process a new document.
|
||||
# https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python
|
||||
logger.debug("Loading pipeline...")
|
||||
pipeline = load_pipeline(verbose=CONFIG.service.verbose, batch_size=CONFIG.service.batch_size)
|
||||
logger.debug("Running pipeline...")
|
||||
return list(pipeline(pdf))
|
||||
|
||||
prediction_server = make_prediction_server(predict)
|
||||
serve(prediction_server, host=CONFIG.webserver.host, port=CONFIG.webserver.port, _quiet=False)
|
||||
queue_manager = QueueManager(PYINFRA_CONFIG)
|
||||
queue_manager.start_consuming(process_request)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=CONFIG.service.logging_level)
|
||||
logging.getLogger("PIL").setLevel(logging.ERROR)
|
||||
logging.getLogger("h5py").setLevel(logging.ERROR)
|
||||
logging.getLogger("pillow").setLevel(logging.ERROR)
|
||||
logger = get_logger()
|
||||
|
||||
show_banner()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user