integrate new pyinfra logic
This commit is contained in:
parent
4692607834
commit
3225cefaa2
@ -63,12 +63,6 @@ class Pipeline:
|
|||||||
join, # ... the streams by zipping
|
join, # ... the streams by zipping
|
||||||
reformat, # ... the items
|
reformat, # ... the items
|
||||||
)
|
)
|
||||||
self.pipe_for_scanned_pdf = rcompose(
|
|
||||||
extract_from_scanned,
|
|
||||||
pairwise_apply(classify, represent, identity), # ... apply functions to the streams pairwise
|
|
||||||
join, # ... the streams by zipping
|
|
||||||
reformat, # ... the items
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, pdf: bytes, page_range: range = None):
|
def __call__(self, pdf: bytes, page_range: range = None):
|
||||||
yield from tqdm(
|
yield from tqdm(
|
||||||
@ -77,10 +71,3 @@ class Pipeline:
|
|||||||
unit=" images",
|
unit=" images",
|
||||||
disable=not self.verbose,
|
disable=not self.verbose,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_from_scanned(pdf: bytes, bbox_info_per_page: Iterable[dict]):
|
|
||||||
images = extract_images_per_page(pdf, bbox_info_per_page)
|
|
||||||
for page_images, page_info in zip(images, bbox_info_per_page):
|
|
||||||
metadata_per_image = page_info["bboxes"]
|
|
||||||
metadata_per_image =
|
|
||||||
@ -4,8 +4,7 @@ from image_prediction.locations import BANNER_FILE
|
|||||||
|
|
||||||
|
|
||||||
def show_banner():
|
def show_banner():
|
||||||
with open(BANNER_FILE) as f:
|
banner = load_banner()
|
||||||
banner = "\n" + "".join(f.readlines()) + "\n"
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.propagate = False
|
logger.propagate = False
|
||||||
@ -19,3 +18,9 @@ def show_banner():
|
|||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
|
|
||||||
logger.info(banner)
|
logger.info(banner)
|
||||||
|
|
||||||
|
|
||||||
|
def load_banner():
|
||||||
|
with open(BANNER_FILE) as f:
|
||||||
|
banner = "\n" + "".join(f.readlines()) + "\n"
|
||||||
|
return banner
|
||||||
|
|||||||
68
src/serve.py
68
src/serve.py
@ -1,36 +1,52 @@
|
|||||||
|
import gzip
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from waitress import serve
|
from image_prediction.config import Config
|
||||||
|
from image_prediction.locations import CONFIG_FILE
|
||||||
from image_prediction.config import CONFIG
|
|
||||||
from image_prediction.flask import make_prediction_server
|
|
||||||
from image_prediction.pipeline import load_pipeline
|
from image_prediction.pipeline import load_pipeline
|
||||||
from image_prediction.utils import get_logger
|
from image_prediction.utils.banner import show_banner, load_banner
|
||||||
from image_prediction.utils.banner import show_banner
|
from pyinfra import config
|
||||||
|
from pyinfra.queue.queue_manager import QueueManager
|
||||||
|
from pyinfra.storage.storage import get_storage
|
||||||
|
|
||||||
|
PYINFRA_CONFIG = config.get_config()
|
||||||
|
IMAGE_CONFIG = Config(CONFIG_FILE)
|
||||||
|
|
||||||
|
logging.getLogger().addHandler(logging.StreamHandler())
|
||||||
|
logger = logging.getLogger("main")
|
||||||
|
logger.setLevel(PYINFRA_CONFIG.logging_level_root)
|
||||||
|
|
||||||
|
|
||||||
|
def process_request(request_message):
|
||||||
|
pipeline = load_pipeline(verbose=IMAGE_CONFIG.service.verbose, batch_size=IMAGE_CONFIG.service.batch_size)
|
||||||
|
|
||||||
|
target_file_extension = request_message["targetFileExtension"]
|
||||||
|
dossier_id = request_message["dossierId"]
|
||||||
|
file_id = request_message["fileId"]
|
||||||
|
|
||||||
|
storage = get_storage(PYINFRA_CONFIG)
|
||||||
|
|
||||||
|
object_bytes = storage.get_object(PYINFRA_CONFIG.storage_bucket, f"{dossier_id}/{file_id}.{target_file_extension}")
|
||||||
|
object_bytes = gzip.decompress(object_bytes)
|
||||||
|
|
||||||
|
classifications = list(pipeline(object_bytes))
|
||||||
|
|
||||||
|
result = {**request_message, "data": classifications}
|
||||||
|
|
||||||
|
response_file_extension = request_message["responseFileExtension"]
|
||||||
|
storage_bytes = gzip.compress(json.dumps(result).encode("utf-8"))
|
||||||
|
storage.put_object(PYINFRA_CONFIG.storage_bucket, f"{dossier_id}/{file_id}.{response_file_extension}", storage_bytes)
|
||||||
|
|
||||||
|
return {"dossierId": dossier_id, "fileId": file_id}
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
logger.info(load_banner())
|
||||||
|
|
||||||
def predict(pdf):
|
queue_manager = QueueManager(PYINFRA_CONFIG)
|
||||||
# Keras service_estimator.predict stalls when service_estimator was loaded in different process;
|
queue_manager.start_consuming(process_request)
|
||||||
# therefore, we re-load the model (part of the pipeline) every time we process a new document.
|
|
||||||
# https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python
|
|
||||||
logger.debug("Loading pipeline...")
|
|
||||||
pipeline = load_pipeline(verbose=CONFIG.service.verbose, batch_size=CONFIG.service.batch_size)
|
|
||||||
logger.debug("Running pipeline...")
|
|
||||||
return list(pipeline(pdf))
|
|
||||||
|
|
||||||
prediction_server = make_prediction_server(predict)
|
|
||||||
serve(prediction_server, host=CONFIG.webserver.host, port=CONFIG.webserver.port, _quiet=False)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
logging.basicConfig(level=CONFIG.service.logging_level)
|
|
||||||
logging.getLogger("PIL").setLevel(logging.ERROR)
|
|
||||||
logging.getLogger("h5py").setLevel(logging.ERROR)
|
|
||||||
logging.getLogger("pillow").setLevel(logging.ERROR)
|
|
||||||
logger = get_logger()
|
|
||||||
|
|
||||||
show_banner()
|
|
||||||
|
|
||||||
main()
|
main()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user