commit 1e5da128f1 (parent 372d6645d7)
Author: cdietrich
Date:   2022-03-02 10:24:33 +01:00
2 changed files with 34 additions and 57 deletions

Changed file 1 of 2:

@@ -1,13 +1,16 @@
import logging
from itertools import chain
from operator import itemgetter
from typing import List, Dict
from typing import List, Dict, Iterable
import numpy as np
from image_prediction.config import CONFIG
from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
from incl.redai_image.redai.redai.utils.shared import chunk_iterable
class Predictor:
@@ -86,3 +89,30 @@ class Predictor:
        predictions = [{"class": c, "probabilities": c2p} for c, c2p in zip(classes, class2prob_per_item)]
        return predictions if probabilities else classes


def extract_image_metadata_pairs(pdf_path: str, **kwargs):
    def image_is_large_enough(metadata: dict):
        x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
        return abs(x1 - x2) > 2 and abs(y1 - y2) > 2

    yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
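A quick illustration of the size filter (the coordinate values below are hypothetical, not taken from the commit): an image is kept only if its bounding box is larger than 2 units in both dimensions.

# Hypothetical metadata entry: width 4.5, height 1.0
metadata = {"x1": 10.0, "x2": 14.5, "y1": 20.0, "y2": 21.0}
abs(metadata["x1"] - metadata["x2"]) > 2  # True  -> wide enough
abs(metadata["y1"] - metadata["y2"]) > 2  # False -> too flat, so this image would be filtered out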
def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
    def process_chunk(chunk):
        images, metadata = zip(*chunk)
        predictions = predictor.predict(images, probabilities=True)
        return predictions, metadata

    def predict(image_metadata_pair_generator):
        chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size)
        return map(chain.from_iterable, zip(*map(process_chunk, chunks)))

    try:
        predictions, metadata = predict(image_metadata_pairs)
        return predictions, metadata
    except ValueError:
        return [], []
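The functional pipeline in predict is dense: zip(*map(process_chunk, chunks)) transposes the per-chunk (predictions, metadata) pairs into two sequences, chain.from_iterable flattens each of them, and unpacking the empty zip produced by a PDF with no usable images raises the ValueError caught above. The following is a rough, eager equivalent for readers skimming the diff; it is an illustrative sketch only, not part of the commit, the name classify_images_explicit is hypothetical, and it reuses the module's chunk_iterable and CONFIG imports shown above.

def classify_images_explicit(predictor, image_metadata_pairs, batch_size=CONFIG.service.batch_size):
    # Batch the (image, metadata) pairs, run one predict() call per batch,
    # then flatten the per-batch results back into two flat lists.
    all_predictions, all_metadata = [], []
    for chunk in chunk_iterable(image_metadata_pairs, n=batch_size):
        images, metadata = zip(*chunk)
        all_predictions.extend(predictor.predict(images, probabilities=True))
        all_metadata.extend(metadata)
    return all_predictions, all_metadata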

Changed file 2 of 2:

@@ -1,38 +1,15 @@
import argparse
import json
import logging
import tempfile
from itertools import chain
from operator import itemgetter
from typing import Iterable
from flask import Flask, request, jsonify
from waitress import serve
from image_prediction.config import CONFIG
from image_prediction.predictor import Predictor
from image_prediction.predictor import Predictor, extract_image_metadata_pairs, classify_images
from image_prediction.response import build_response
from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
from incl.redai_image.redai.redai.utils.shared import chunk_iterable
def suppress_userwarnings():
    import warnings

    warnings.filterwarnings("ignore")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--warnings", action="store_true", default=False)
    args = parser.parse_args()
    return args


def main(args):
    if not args.warnings:
        suppress_userwarnings()


def main():
    predictor = Predictor()
    logging.info("Predictor ready.")
@@ -56,7 +33,6 @@ def main(args):
        pdf = request.data
        logging.debug("Running predictor on document...")

        # extract images from pdfs
        with tempfile.NamedTemporaryFile() as tmp_file:
            tmp_file.write(pdf)
            image_metadata_pairs = extract_image_metadata_pairs(tmp_file.name)
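The rest of the handler is unchanged and not shown in this hunk; given the imports added above, it presumably hands the pairs to the relocated classify_images helper. A hedged sketch of that continuation follows; the build_response call and the response shape are assumptions, not taken from the diff.

            # Assumed continuation inside the with-block (illustrative only):
            predictions, metadata = classify_images(predictor, image_metadata_pairs)
            return jsonify(build_response(predictions, metadata))  # build_response signature assumed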
@@ -85,33 +61,6 @@ def run_prediction_server(app, mode="development"):
    serve(app, host=CONFIG.webserver.host, port=CONFIG.webserver.port)
def extract_image_metadata_pairs(pdf_path: str, **kwargs):
    def image_is_large_enough(metadata: dict):
        x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
        return abs(x1 - x2) > 2 and abs(y1 - y2) > 2

    yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)


def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
    def process_chunk(chunk):
        images, metadata = zip(*chunk)
        predictions = predictor.predict(images, probabilities=True)
        return predictions, metadata

    def predict(image_metadata_pair_generator):
        chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size)
        return map(chain.from_iterable, zip(*map(process_chunk, chunks)))

    try:
        predictions, metadata = predict(image_metadata_pairs)
        return predictions, metadata
    except ValueError:
        return [], []
if __name__ == "__main__":
    logging_level = CONFIG.service.logging_level
    logging.basicConfig(level=logging_level)
@@ -120,6 +69,4 @@ if __name__ == "__main__":
    logging.getLogger("werkzeug").setLevel(logging.ERROR)
    logging.getLogger("waitress").setLevel(logging.ERROR)

    args = parse_args()
    main(args)
    main()