refactor
This commit is contained in:
parent
372d6645d7
commit
1e5da128f1
@ -1,13 +1,16 @@
|
||||
import logging
|
||||
from itertools import chain
|
||||
from operator import itemgetter
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Iterable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
|
||||
from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
|
||||
from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
|
||||
from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
|
||||
from incl.redai_image.redai.redai.utils.shared import chunk_iterable
|
||||
|
||||
|
||||
class Predictor:
|
||||
@ -86,3 +89,30 @@ class Predictor:
|
||||
predictions = [{"class": c, "probabilities": c2p} for c, c2p in zip(classes, class2prob_per_item)]
|
||||
|
||||
return predictions if probabilities else classes
|
||||
|
||||
|
||||
def extract_image_metadata_pairs(pdf_path: str, **kwargs):
    """Yield the items produced by `extract_and_stitch` for the PDF at `pdf_path`.

    `extract_and_stitch` is called with `convert_to_rgb=True` and a size filter
    that only accepts metadata whose x- and y-extent both exceed 2. Any extra
    keyword arguments are forwarded to `extract_and_stitch` unchanged.

    This is a generator function: nothing is extracted until the result is
    iterated, so `pdf_path` must still be readable at that point.
    """
    read_box = itemgetter("x1", "x2", "y1", "y2")

    def image_is_large_enough(metadata: dict):
        # All four keys are looked up first, so a missing key raises KeyError
        # before any size comparison happens.
        xa, xb, ya, yb = read_box(metadata)
        x_extent_ok = abs(xa - xb) > 2
        return x_extent_ok and abs(ya - yb) > 2

    yield from extract_and_stitch(
        pdf_path,
        convert_to_rgb=True,
        filter_fn=image_is_large_enough,
        **kwargs,
    )
|
||||
|
||||
|
||||
def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
    """Classify images batch-wise and return flattened predictions and metadata.

    The pairs are grouped with `chunk_iterable(..., n=batch_size)`. Each group
    is split into its images and its metadata, and the images are passed to
    `predictor.predict(images, probabilities=True)`.

    Returns a 2-tuple `(predictions, metadata)`. On success both elements are
    `itertools.chain` iterators concatenating the per-batch results; if a
    `ValueError` is raised along the way, `([], [])` is returned instead.

    The default `batch_size` is read from `CONFIG.service.batch_size` once,
    when the function is defined.
    """

    def classify_batch(batch):
        batch_images, batch_metadata = zip(*batch)
        batch_predictions = predictor.predict(batch_images, probabilities=True)
        return batch_predictions, batch_metadata

    try:
        batches = chunk_iterable(image_metadata_pairs, n=batch_size)
        # Every batch is processed here, before the results are transposed into
        # one tuple of prediction lists and one tuple of metadata tuples.
        per_batch_results = [classify_batch(batch) for batch in batches]
        # With no batches the transposition is empty and the two-name unpacking
        # below raises ValueError, which is mapped to the empty result.
        # NOTE(review): a ValueError raised inside predictor.predict or
        # chunk_iterable is swallowed by the same handler - confirm this is intended.
        predictions, metadata = map(chain.from_iterable, zip(*per_batch_results))
    except ValueError:
        return [], []

    return predictions, metadata
|
||||
|
||||
59
src/serve.py
59
src/serve.py
@ -1,38 +1,15 @@
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from itertools import chain
|
||||
from operator import itemgetter
|
||||
from typing import Iterable
|
||||
|
||||
from flask import Flask, request, jsonify
|
||||
from waitress import serve
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.predictor import Predictor
|
||||
from image_prediction.predictor import Predictor, extract_image_metadata_pairs, classify_images
|
||||
from image_prediction.response import build_response
|
||||
from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
|
||||
from incl.redai_image.redai.redai.utils.shared import chunk_iterable
|
||||
|
||||
|
||||
def suppress_userwarnings():
    """Install a global "ignore" warnings filter.

    Despite the name, no category is passed to the filter, so it applies to
    every `Warning` subclass and not only to `UserWarning`.
    """
    from warnings import filterwarnings

    filterwarnings("ignore")
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command line and return the resulting namespace.

    The only option is the boolean flag `--warnings`, which is False unless
    given on the command line.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--warnings", action="store_true", default=False)
    return cli.parse_args()
|
||||
|
||||
|
||||
def main(args):
|
||||
if not args.warnings:
|
||||
suppress_userwarnings()
|
||||
def main():
|
||||
|
||||
predictor = Predictor()
|
||||
logging.info("Predictor ready.")
|
||||
@ -56,7 +33,6 @@ def main(args):
|
||||
pdf = request.data
|
||||
|
||||
logging.debug("Running predictor on document...")
|
||||
# extract images from pdfs
|
||||
with tempfile.NamedTemporaryFile() as tmp_file:
|
||||
tmp_file.write(pdf)
|
||||
image_metadata_pairs = extract_image_metadata_pairs(tmp_file.name)
|
||||
@ -85,33 +61,6 @@ def run_prediction_server(app, mode="development"):
|
||||
serve(app, host=CONFIG.webserver.host, port=CONFIG.webserver.port)
|
||||
|
||||
|
||||
def extract_image_metadata_pairs(pdf_path: str, **kwargs):
    """Yield the items produced by `extract_and_stitch` for the PDF at `pdf_path`.

    `extract_and_stitch` is called with `convert_to_rgb=True` and a size filter
    that only accepts metadata whose x- and y-extent both exceed 2. Any extra
    keyword arguments are forwarded to `extract_and_stitch` unchanged.

    This is a generator function: nothing is extracted until the result is
    iterated, so `pdf_path` must still be readable at that point.
    """
    read_box = itemgetter("x1", "x2", "y1", "y2")

    def image_is_large_enough(metadata: dict):
        # All four keys are looked up first, so a missing key raises KeyError
        # before any size comparison happens.
        xa, xb, ya, yb = read_box(metadata)
        x_extent_ok = abs(xa - xb) > 2
        return x_extent_ok and abs(ya - yb) > 2

    yield from extract_and_stitch(
        pdf_path,
        convert_to_rgb=True,
        filter_fn=image_is_large_enough,
        **kwargs,
    )
|
||||
|
||||
|
||||
def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
    """Classify images batch-wise and return flattened predictions and metadata.

    The pairs are grouped with `chunk_iterable(..., n=batch_size)`. Each group
    is split into its images and its metadata, and the images are passed to
    `predictor.predict(images, probabilities=True)`.

    Returns a 2-tuple `(predictions, metadata)`. On success both elements are
    `itertools.chain` iterators concatenating the per-batch results; if a
    `ValueError` is raised along the way, `([], [])` is returned instead.

    The default `batch_size` is read from `CONFIG.service.batch_size` once,
    when the function is defined.
    """

    def classify_batch(batch):
        batch_images, batch_metadata = zip(*batch)
        batch_predictions = predictor.predict(batch_images, probabilities=True)
        return batch_predictions, batch_metadata

    try:
        batches = chunk_iterable(image_metadata_pairs, n=batch_size)
        # Every batch is processed here, before the results are transposed into
        # one tuple of prediction lists and one tuple of metadata tuples.
        per_batch_results = [classify_batch(batch) for batch in batches]
        # With no batches the transposition is empty and the two-name unpacking
        # below raises ValueError, which is mapped to the empty result.
        # NOTE(review): a ValueError raised inside predictor.predict or
        # chunk_iterable is swallowed by the same handler - confirm this is intended.
        predictions, metadata = map(chain.from_iterable, zip(*per_batch_results))
    except ValueError:
        return [], []

    return predictions, metadata
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging_level = CONFIG.service.logging_level
|
||||
logging.basicConfig(level=logging_level)
|
||||
@ -120,6 +69,4 @@ if __name__ == "__main__":
|
||||
logging.getLogger("werkzeug").setLevel(logging.ERROR)
|
||||
logging.getLogger("waitress").setLevel(logging.ERROR)
|
||||
|
||||
args = parse_args()
|
||||
|
||||
main(args)
|
||||
main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user