RED-3501: adapt service-container to image-service-v2

cdietrich 2022-03-01 16:48:03 +01:00
parent 684aca364f
commit 42ae5793e0
14 changed files with 191 additions and 94 deletions

.gitignore (vendored, 2 changes)

@@ -171,3 +171,5 @@ fabric.properties
.idea/codestream.xml
# End of https://www.toptal.com/developers/gitignore/api/linux,pycharm
/image_prediction/data/mlruns/
/data/mlruns/

.gitmodules (vendored, 2 changes)

@@ -1,3 +1,3 @@
[submodule "incl/redai_image"]
path = incl/redai_image
url = ssh://git@git.iqser.com:2222/rr/redai_image.git
url = ssh://git@git.iqser.com:2222/rr/redai_image.git


@@ -11,13 +11,14 @@ COPY image_prediction ./image_prediction
COPY ./setup.py ./setup.py
COPY ./requirements.txt ./requirements.txt
COPY ./config.yaml ./config.yaml
COPY data data
# Install dependencies that differ from the base image.
RUN python3 -m pip install -r requirements.txt
RUN python3 -m pip install -e .
WORKDIR /app/service/incl/redai_image
WORKDIR /app/service/incl/redai_image/redai
RUN python3 -m pip install -e .
WORKDIR /app/service


@@ -10,11 +10,13 @@ RUN python -m pip install --upgrade pip
# Make a directory for the service files and copy the service repo into the container.
WORKDIR /app/service
COPY ./requirements.txt ./requirements.txt
COPY ./data ./data
COPY ./incl/redai_image/redai/requirements_user.txt ./requirements_redai.txt
# Install dependencies.
RUN python3 -m pip install -r requirements.txt
RUN python3 -m pip install -r requirements_redai.txt
# Make a new container and copy all relevant files over to filter out temporary files
# produced during setup to reduce the final container's size.
FROM python:3.8


@@ -1,10 +1,3 @@
estimator:
checkpoint: checkpoint.pth
classes: ["logo", "other", "formula", "signature", "handwriting_other"]
rejection_class: "other"
threshold: .5
device: cpu
webserver:
host: $SERVER_HOST|"127.0.0.1" # webserver address
port: $SERVER_PORT|5000 # webserver port
@@ -14,3 +7,22 @@ service:
logging_level: $LOGGING_LEVEL_ROOT|DEBUG # Logging level for service logger
batch_size: $BATCH_SIZE|2 # Number of images in memory simultaneously
verbose: $VERBOSE|True # Service prints document processing progress to stdout
run_id: $RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the model from
# These variables control filters that are applied to images, image metadata, or model predictions. The filter
# results are reported in the service responses. For convenience, the response to a request contains a
# "filters.allPassed" field, which is set to false if any filter's returned value did not meet its specified
# requirement.
filters:
image_to_page_quotient: # Image size to page size ratio (ratio of geometric means of areas)
min: $MIN_REL_IMAGE_SIZE|0.05 # Minimum permissible
max: $MAX_REL_IMAGE_SIZE|0.75 # Maximum permissible
image_width_to_height_quotient: # Image width to height ratio
min: $MIN_IMAGE_FORMAT|0.1 # Minimum permissible
max: $MAX_IMAGE_FORMAT|10 # Maximum permissible
min_confidence: $MIN_CONFIDENCE|0.5 # Minimum permissible prediction confidence
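
The "$NAME|default" placeholders above suggest that the config loader substitutes environment variables and falls back to the literal after the pipe. The loader itself is not part of this diff, so the following is only a minimal sketch of that convention; the regex and the resolve function are invented for illustration:

import os
import re

# Hypothetical resolver for the "$NAME|default" convention used in config.yaml;
# the real loader behind image_prediction.config.CONFIG is not shown in this commit.
PLACEHOLDER = re.compile(r"^\$(?P<name>\w+)\|(?P<default>.*)$")

def resolve(value: str) -> str:
    match = PLACEHOLDER.match(value)
    if match is None:
        return value  # plain literal, nothing to substitute
    return os.environ.get(match["name"], match["default"])

# resolve("$MIN_CONFIDENCE|0.5") -> "0.5" unless MIN_CONFIDENCE is set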

data/base_weights.h5.dvc (new file, 4 lines)

@@ -0,0 +1,4 @@
outs:
- md5: 6d0186c1f25e889d531788f168fa6cf0
size: 16727296
path: base_weights.h5

data/mlruns.dvc (new file, 5 lines)

@@ -0,0 +1,5 @@
outs:
- md5: d1c708270bab6fcd344d4a8b05d1103d.dir
size: 150225383
nfiles: 178
path: mlruns


@@ -1,7 +1,14 @@
from pathlib import Path
from os import path
MODULE_DIR = path.dirname(path.abspath(__file__))
PACKAGE_ROOT_DIR = path.dirname(MODULE_DIR)
REPO_ROOT_DIR = path.dirname(path.dirname(PACKAGE_ROOT_DIR))
MODULE_ROOT = Path(__file__).resolve().parents[1]
CONFIG_FILE = MODULE_ROOT / "config.yaml"
DATA_DIR = MODULE_ROOT / "data"
TORCH_HOME = DATA_DIR
DOCKER_COMPOSE_FILE = path.join(REPO_ROOT_DIR, "docker-compose.yaml")
CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml")
LOG_FILE = "/tmp/log.log"
DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data")
MLRUNS_DIR = path.join(DATA_DIR, "mlruns")
BASE_WEIGHTS = path.join(DATA_DIR, "base_weights.h5")
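
The hunk above migrates from chained os.path calls to pathlib. A minimal sketch of the equivalence, assuming locations.py sits one level below the package root, as parents[1] implies:

from os import path
from pathlib import Path

# Both expressions name the same directory: the parent of the module directory.
MODULE_DIR = path.dirname(path.abspath(__file__))   # the module directory
PACKAGE_ROOT_DIR = path.dirname(MODULE_DIR)         # its parent, the package root
MODULE_ROOT = Path(__file__).resolve().parents[1]   # the same directory as a pathlib.Path

assert str(MODULE_ROOT) == PACKAGE_ROOT_DIR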


@@ -1,7 +1,13 @@
import logging
from operator import itemgetter
from typing import List, Dict
import numpy as np
from image_prediction.config import CONFIG
from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
class Predictor:
@@ -21,9 +27,7 @@ class Predictor:
reader = MlflowModelReader(
run_id=CONFIG.service.run_id, mlruns_dir=MLRUNS_DIR
)
# message_queue.put(text="Loading model...", level=logging.DEBUG)
self.model_handle = reader.get_model_handle(BASE_WEIGHTS)
# message_queue.put(text="Model loaded.", level=logging.DEBUG)
else:
self.model_handle = model_handle
@@ -31,12 +35,7 @@ class Predictor:
self.classes_readable = np.array(self.model_handle.classes)
self.classes_readable_aligned = self.classes_readable[self.classes[list(range(len(self.classes)))]]
except Exception as e:
message_queue.put(
text="Service estimator initialization failed.",
exception=e,
level=logging.CRITICAL,
trace=traceback.format_exc(),
)
logging.info(f"Service estimator initialization failed: {e}")
def __make_predictions_human_readable(self, probs: np.ndarray) -> List[Dict[str, float]]:
"""Translates an n x m matrix of probabilities over classes into an n-element list of mappings from classes to


@@ -0,0 +1,71 @@
"""Defines functions for constructing service responses."""
from itertools import starmap
from operator import itemgetter
import numpy as np
from image_prediction.config import CONFIG
def build_response(predictions: list, metadata: list) -> list:
return list(starmap(build_image_info, zip(predictions, metadata)))
def build_image_info(prediction: dict, metadata: dict) -> dict:
def compute_geometric_quotient():
page_area_sqrt = np.sqrt(abs(page_width * page_height))
image_area_sqrt = np.sqrt(abs(x2 - x1) * abs(y2 - y1))
return image_area_sqrt / page_area_sqrt
page_width, page_height, x1, x2, y1, y2, width, height = itemgetter(
"page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height"
)(metadata)
quotient = compute_geometric_quotient()
min_image_to_page_quotient_breached = bool(quotient < CONFIG.filters.image_to_page_quotient.min)
max_image_to_page_quotient_breached = bool(quotient > CONFIG.filters.image_to_page_quotient.max)
min_image_width_to_height_quotient_breached = bool(
width / height < CONFIG.filters.image_width_to_height_quotient.min
)
max_image_width_to_height_quotient_breached = bool(
width / height > CONFIG.filters.image_width_to_height_quotient.max
)
min_confidence_breached = bool(max(prediction["probabilities"].values()) < CONFIG.filters.min_confidence)
prediction["label"] = prediction.pop("class") # "class" as field name causes problem for Java objectmapper
prediction["probabilities"] = {klass: np.round(prob, 6) for klass, prob in prediction["probabilities"].items()}
image_info = {
"classification": prediction,
"position": {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": metadata["page_idx"] + 1},
"geometry": {"width": width, "height": height},
"filters": {
"geometry": {
"imageSize": {
"quotient": quotient,
"tooLarge": max_image_to_page_quotient_breached,
"tooSmall": min_image_to_page_quotient_breached,
},
"imageFormat": {
"quotient": width / height,
"tooTall": min_image_width_to_height_quotient_breached,
"tooWide": max_image_width_to_height_quotient_breached,
},
},
"probability": {"unconfident": min_confidence_breached},
"allPassed": not any(
[
max_image_to_page_quotient_breached,
min_image_to_page_quotient_breached,
min_image_width_to_height_quotient_breached,
max_image_width_to_height_quotient_breached,
min_confidence_breached,
]
),
},
}
return image_info
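
A worked example of the filter arithmetic in build_image_info, with made-up input values and the default filter settings from config.yaml (field names are taken from the code above):

from image_prediction.response import build_image_info

prediction = {"class": "logo", "probabilities": {"logo": 0.91, "other": 0.09}}
metadata = {
    "page_width": 1000, "page_height": 1000,
    "x1": 100, "x2": 300, "y1": 100, "y2": 300,
    "width": 200, "height": 200, "page_idx": 0,
}
info = build_image_info(prediction, metadata)
# Geometric quotient: sqrt(200 * 200) / sqrt(1000 * 1000) = 0.2, inside the
# default 0.05..0.75 band; width/height = 1.0, inside 0.1..10; top confidence
# 0.91 >= 0.5, so info["filters"]["allPassed"] is True.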


@@ -1,32 +0,0 @@
import os
from image_prediction.config import CONFIG
from image_prediction.locations import DATA_DIR, TORCH_HOME
from image_prediction.predictor import Predictor
def suppress_userwarnings():
import warnings
warnings.filterwarnings("ignore")
def load_classes():
classes = CONFIG.estimator.classes
id2class = dict(zip(range(1, len(classes) + 1), classes))
return id2class
def get_checkpoint():
return DATA_DIR / CONFIG.estimator.checkpoint
def set_torch_env():
os.environ["TORCH_HOME"] = str(TORCH_HOME)
def initialize_predictor(resume):
set_torch_env()
checkpoint = get_checkpoint() if not resume else resume
predictor = Predictor(checkpoint, classes=load_classes(), rejection_class=CONFIG.estimator.rejection_class)
return predictor

incl/redai_image (submodule, 1 change)

@@ -0,0 +1 @@
Subproject commit 4c3b26d7673457aaa99e0663dad6950cd36da967


@@ -5,9 +5,9 @@ python3 -m venv build_venv
source build_venv/bin/activate
python3 -m pip install --upgrade pip
#pip install dvc
#pip install 'dvc[ssh]'
#dvc pull
pip install dvc
pip install 'dvc[ssh]'
dvc pull
git submodule update --init --recursive
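
With the model data now DVC-tracked, the build script fetches data/mlruns and data/base_weights.h5 before the Docker build. Assuming the dvc Python package is installed, the same artifacts can also be read programmatically; a hedged sketch using dvc.api:

import dvc.api

# Fetch the tracked weights file from the repository's DVC remote.
weights = dvc.api.read("data/base_weights.h5", mode="rb")
print(len(weights))  # 16727296 bytes, per data/base_weights.h5.dvc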


@@ -1,17 +1,29 @@
import argparse
import json
import logging
from typing import Callable
import tempfile
from itertools import chain
from operator import itemgetter
from typing import Iterable
from flask import Flask, request, jsonify
from waitress import serve
from image_prediction.config import CONFIG
from image_prediction.utils.estimator import suppress_userwarnings, initialize_predictor
from image_prediction.predictor import Predictor
from image_prediction.response import build_response
from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
from incl.redai_image.redai.redai.utils.shared import chunk_iterable
def suppress_userwarnings():
import warnings
warnings.filterwarnings("ignore")
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--resume")
parser.add_argument("--warnings", action="store_true", default=False)
args = parser.parse_args()
@@ -22,16 +34,9 @@ def main(args):
if not args.warnings:
suppress_userwarnings()
predictor = initialize_predictor(args.resume)
predictor = Predictor()
logging.info("Predictor ready.")
prediction_server = make_prediction_server(predictor.predict_pdf)
run_prediction_server(prediction_server, mode=CONFIG.webserver.mode)
def make_prediction_server(predict_fn: Callable):
app = Flask(__name__)
@app.route("/ready", methods=["GET"])
@@ -48,46 +53,66 @@ def make_prediction_server(predict_fn: Callable):
@app.route("/", methods=["POST"])
def predict():
def __predict():
def inner():
pdf = request.data
logging.debug("Running predictor on document...")
predictions = predict_fn(pdf)
logging.debug(f"Found {len(predictions)} images in document.")
response = jsonify(list(predictions))
pdf = request.data
logging.debug("Running predictor on document...")
# extract images from pdfs
with tempfile.NamedTemporaryFile() as tmp_file:
tmp_file.write(pdf)
image_metadata_pairs = extract_image_metadata_pairs(tmp_file.name)
try:
predictions, metadata = classify_images(predictor, image_metadata_pairs)
except Exception as err:
logging.warning("Analysis failed.")
logging.exception(err)
response = jsonify("Analysis failed.")
response.status_code = 500
return response
logging.debug(f"Found images in document.")
logging.info(f"Analyzing...")
result = inner()
logging.info("Analysis completed.")
return result
response = jsonify(build_response(list(predictions), list(metadata)))
try:
return __predict()
except Exception as err:
logging.warning("Analysis failed.")
logging.exception(err)
response = jsonify("Analysis failed.")
response.status_code = 500
return response
logging.info("Analysis completed.")
return response
return app
run_prediction_server(app, mode=CONFIG.webserver.mode)
def run_prediction_server(app, mode="development"):
if mode == "development":
app.run(host=CONFIG.webserver.host, port=CONFIG.webserver.port, debug=True)
elif mode == "production":
serve(app, host=CONFIG.webserver.host, port=CONFIG.webserver.port)
if __name__ == "__main__":
def extract_image_metadata_pairs(pdf_path: str, **kwargs):
def image_is_large_enough(metadata: dict):
x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
return abs(x1 - x2) > 2 and abs(y1 - y2) > 2
yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
def process_chunk(chunk):
images, metadata = zip(*chunk)
predictions = predictor.predict(images, probabilities=True)
return predictions, metadata
def predict(image_metadata_pair_generator):
chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size)
return map(chain.from_iterable, zip(*map(process_chunk, chunks)))
try:
predictions, metadata = predict(image_metadata_pairs)
return predictions, metadata
except ValueError:
return [], []
if __name__ == "__main__":
logging_level = CONFIG.service.logging_level
logging.basicConfig(level=logging_level)
logging.getLogger("flask").setLevel(logging.ERROR)
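
Taken together, the service now accepts a raw PDF body on "/" and reports per-image classifications plus filter results. A hedged client-side sketch against the default host and port from config.yaml (the requests library and the PDF file name are assumptions for illustration):

import requests

BASE = "http://127.0.0.1:5000"
assert requests.get(BASE + "/ready").ok  # readiness probe defined above

with open("example.pdf", "rb") as fh:    # any local PDF
    reply = requests.post(BASE + "/", data=fh.read())
reply.raise_for_status()
for image_info in reply.json():          # one entry per extracted image
    print(image_info["classification"]["label"], image_info["filters"]["allPassed"])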