diff --git a/image_prediction/predictor.py b/image_prediction/predictor.py index 3d89a69..d994baf 100644 --- a/image_prediction/predictor.py +++ b/image_prediction/predictor.py @@ -1,13 +1,16 @@ import logging +from itertools import chain from operator import itemgetter -from typing import List, Dict +from typing import List, Dict, Iterable import numpy as np from image_prediction.config import CONFIG from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle +from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader +from incl.redai_image.redai.redai.utils.shared import chunk_iterable class Predictor: @@ -86,3 +89,30 @@ class Predictor: predictions = [{"class": c, "probabilities": c2p} for c, c2p in zip(classes, class2prob_per_item)] return predictions if probabilities else classes + + +def extract_image_metadata_pairs(pdf_path: str, **kwargs): + def image_is_large_enough(metadata: dict): + x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata) + + return abs(x1 - x2) > 2 and abs(y1 - y2) > 2 + + yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs) + + +def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size): + def process_chunk(chunk): + images, metadata = zip(*chunk) + predictions = predictor.predict(images, probabilities=True) + return predictions, metadata + + def predict(image_metadata_pair_generator): + chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size) + return map(chain.from_iterable, zip(*map(process_chunk, chunks))) + + try: + predictions, metadata = predict(image_metadata_pairs) + return predictions, metadata + + except ValueError: + return [], [] diff --git a/src/serve.py b/src/serve.py index 4c292d3..bc6bae2 100644 --- a/src/serve.py +++ b/src/serve.py @@ -1,38 +1,15 @@ -import argparse -import json import logging import tempfile -from itertools import chain -from operator import itemgetter -from typing import Iterable from flask import Flask, request, jsonify from waitress import serve from image_prediction.config import CONFIG -from image_prediction.predictor import Predictor +from image_prediction.predictor import Predictor, extract_image_metadata_pairs, classify_images from image_prediction.response import build_response -from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch -from incl.redai_image.redai.redai.utils.shared import chunk_iterable -def suppress_userwarnings(): - import warnings - - warnings.filterwarnings("ignore") - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--warnings", action="store_true", default=False) - args = parser.parse_args() - - return args - - -def main(args): - if not args.warnings: - suppress_userwarnings() +def main(): predictor = Predictor() logging.info("Predictor ready.") @@ -56,7 +33,6 @@ def main(args): pdf = request.data logging.debug("Running predictor on document...") - # extract images from pdfs with tempfile.NamedTemporaryFile() as tmp_file: tmp_file.write(pdf) image_metadata_pairs = extract_image_metadata_pairs(tmp_file.name) @@ -85,33 +61,6 @@ def run_prediction_server(app, mode="development"): serve(app, host=CONFIG.webserver.host, port=CONFIG.webserver.port) -def extract_image_metadata_pairs(pdf_path: str, **kwargs): - def image_is_large_enough(metadata: dict): - x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata) - - return abs(x1 - x2) > 2 and abs(y1 - y2) > 2 - - yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs) - - -def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size): - def process_chunk(chunk): - images, metadata = zip(*chunk) - predictions = predictor.predict(images, probabilities=True) - return predictions, metadata - - def predict(image_metadata_pair_generator): - chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size) - return map(chain.from_iterable, zip(*map(process_chunk, chunks))) - - try: - predictions, metadata = predict(image_metadata_pairs) - return predictions, metadata - - except ValueError: - return [], [] - - if __name__ == "__main__": logging_level = CONFIG.service.logging_level logging.basicConfig(level=logging_level) @@ -120,6 +69,4 @@ if __name__ == "__main__": logging.getLogger("werkzeug").setLevel(logging.ERROR) logging.getLogger("waitress").setLevel(logging.ERROR) - args = parse_args() - - main(args) + main()