update server logic for new pyinfra, add extraction from scanned PDF with figure detection logic
This commit is contained in:
parent
3225cefaa2
commit
287b0ebc8a
@ -1,3 +1,5 @@
|
||||
from typing import Iterable
|
||||
|
||||
from funcy import juxt
|
||||
|
||||
from image_prediction.classifier.classifier import Classifier
|
||||
@ -7,13 +9,17 @@ from image_prediction.encoder.encoders.hash_encoder import HashEncoder
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
|
||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
||||
from image_prediction.model_loader.loader import ModelLoader
|
||||
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
||||
from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer
|
||||
from image_prediction.transformer.transformers.response import ResponseTransformer
|
||||
from pdf2img.default_objects.image import ImagePlus
|
||||
from pdf2img.extraction import extract_images_via_metadata
|
||||
|
||||
|
||||
def get_mlflow_model_loader(mlruns_dir):
|
||||
@ -41,3 +47,24 @@ def get_formatter():
|
||||
|
||||
def get_encoder():
    """Return a freshly constructed ``HashEncoder``."""
    encoder = HashEncoder()
    return encoder
|
||||
|
||||
|
||||
def extract_images_via_metadata_and_format_to_image_metadata_pair(pdf: bytes, metadata_per_image: Iterable[dict]):
    """Extract images from a PDF via metadata and yield them as ``ImageMetadataPair`` objects.

    Args:
        pdf: Raw bytes of the PDF document.
        metadata_per_image: Metadata dicts, forwarded unchanged to
            ``extract_images_via_metadata`` (schema is defined by pdf2img).

    Yields:
        One ``ImageMetadataPair`` per extracted image. It holds the result of
        the image's ``aspil()`` call and a dict keyed by ``Info`` members that
        describes the page and the bounding box of the image.
    """
    for image_plus in extract_images_via_metadata(pdf, metadata_per_image):
        info = image_plus.info
        page = info.pageInfo
        box = info.boundingBox

        # Translate pdf2img's attribute-style info into the Info-keyed metadata
        # used by the rest of this package. Note the naming shift: the box's
        # x0/x1 and y0/y1 become X1/X2 and Y1/Y2.
        enum_metadata = {
            Info.PAGE_WIDTH: page.width,
            Info.PAGE_HEIGHT: page.height,
            Info.PAGE_IDX: page.number,
            Info.ALPHA: info.alpha,
            Info.WIDTH: box.width,
            Info.HEIGHT: box.height,
            Info.X1: box.x0,
            Info.X2: box.x1,
            Info.Y1: box.y0,
            Info.Y2: box.y1,
        }

        yield ImageMetadataPair(image_plus.aspil(), enum_metadata)
|
||||
|
||||
@ -13,10 +13,10 @@ from image_prediction.default_objects import (
|
||||
get_image_classifier,
|
||||
get_extractor,
|
||||
get_encoder,
|
||||
extract_images_via_metadata_and_format_to_image_metadata_pair,
|
||||
)
|
||||
from image_prediction.locations import MLRUNS_DIR
|
||||
from image_prediction.utils.generic import lift, starlift
|
||||
from pdf2img.extraction import extract_images_per_page
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
|
||||
@ -63,6 +63,13 @@ class Pipeline:
|
||||
join, # ... the streams by zipping
|
||||
reformat, # ... the items
|
||||
)
|
||||
self.pipe2 = rcompose(
|
||||
extract_images_via_metadata_and_format_to_image_metadata_pair,
|
||||
split, # ... into an image stream and a metadata stream
|
||||
pairwise_apply(classify, represent, identity), # ... apply functions to the streams pairwise
|
||||
join, # ... the streams by zipping
|
||||
reformat, # ... the items
|
||||
)
|
||||
|
||||
def __call__(self, pdf: bytes, page_range: range = None):
|
||||
yield from tqdm(
|
||||
@ -71,3 +78,11 @@ class Pipeline:
|
||||
unit=" images",
|
||||
disable=not self.verbose,
|
||||
)
|
||||
|
||||
def extract_via_metadata(self, pdf: bytes, metadata_per_page: Iterable[dict]):
    """Run the metadata-driven pipeline (``self.pipe2``) on a PDF and yield its results.

    Args:
        pdf: Raw bytes of the PDF document.
        metadata_per_page: Metadata dicts passed to ``self.pipe2`` together with ``pdf``.

    Yields:
        The items produced by ``self.pipe2``, wrapped in a ``tqdm`` progress
        bar that is disabled unless ``self.verbose`` is truthy.
    """
    results = self.pipe2(pdf, metadata_per_page)
    progress = tqdm(
        results,
        desc="Processing images from document",
        unit=" images",
        disable=not self.verbose,
    )
    yield from progress
|
||||
|
||||
@ -1 +1 @@
|
||||
Subproject commit 17965e4578818b16cbd1638dfde1c58cbea55954
|
||||
Subproject commit 699568875683ba727ec9759c8bea85e0d3e1d369
|
||||
22
src/serve.py
22
src/serve.py
@ -1,11 +1,12 @@
|
||||
import gzip
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
|
||||
from image_prediction.config import Config
|
||||
from image_prediction.locations import CONFIG_FILE
|
||||
from image_prediction.pipeline import load_pipeline
|
||||
from image_prediction.utils.banner import show_banner, load_banner
|
||||
from image_prediction.utils.banner import load_banner
|
||||
from pyinfra import config
|
||||
from pyinfra.queue.queue_manager import QueueManager
|
||||
from pyinfra.storage.storage import get_storage
|
||||
@ -30,13 +31,26 @@ def process_request(request_message):
|
||||
object_bytes = storage.get_object(PYINFRA_CONFIG.storage_bucket, f"{dossier_id}/{file_id}.{target_file_extension}")
|
||||
object_bytes = gzip.decompress(object_bytes)
|
||||
|
||||
classifications = list(pipeline(object_bytes))
|
||||
try: # TODO: add figure detection file target to request message to avoid this
|
||||
metadata_bytes = storage.get_object(PYINFRA_CONFIG.storage_bucket, f"{dossier_id}/{file_id}.FIGURE.json.gz")
|
||||
metadata_bytes = gzip.decompress(metadata_bytes)
|
||||
metadata = json.load(io.BytesIO(metadata_bytes))
|
||||
logger.info("Metadata aquired")
|
||||
except:
|
||||
metadata = None
|
||||
|
||||
if metadata:
|
||||
classifications = list(pipeline.extract_via_metadata(object_bytes, metadata_per_page=metadata["data"]))
|
||||
else:
|
||||
classifications = list(pipeline(object_bytes))
|
||||
|
||||
result = {**request_message, "data": classifications}
|
||||
|
||||
response_file_extension = request_message["responseFileExtension"]
|
||||
storage_bytes = gzip.compress(json.dumps(result).encode("utf-8"))
|
||||
storage.put_object(PYINFRA_CONFIG.storage_bucket, f"{dossier_id}/{file_id}.{response_file_extension}", storage_bytes)
|
||||
storage.put_object(
|
||||
PYINFRA_CONFIG.storage_bucket, f"{dossier_id}/{file_id}.{response_file_extension}", storage_bytes
|
||||
)
|
||||
|
||||
return {"dossierId": dossier_id, "fileId": file_id}
|
||||
|
||||
@ -48,5 +62,5 @@ def main():
|
||||
queue_manager.start_consuming(process_request)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user