RED-3501: adapt service-container to image-service-v2
This commit is contained in:
parent
684aca364f
commit
42ae5793e0
2
.gitignore
vendored
2
.gitignore
vendored
@ -171,3 +171,5 @@ fabric.properties
|
||||
.idea/codestream.xml
|
||||
|
||||
# End of https://www.toptal.com/developers/gitignore/api/linux,pycharm
|
||||
/image_prediction/data/mlruns/
|
||||
/data/mlruns/
|
||||
|
||||
2
.gitmodules
vendored
2
.gitmodules
vendored
@ -1,3 +1,3 @@
|
||||
[submodule "incl/redai_image"]
|
||||
path = incl/redai_image
|
||||
url = ssh://git@git.iqser.com:2222/rr/redai_image.git
|
||||
url = ssh://git@git.iqser.com:2222/rr/redai_image.git
|
||||
|
||||
@ -11,13 +11,14 @@ COPY image_prediction ./image_prediction
|
||||
COPY ./setup.py ./setup.py
|
||||
COPY ./requirements.txt ./requirements.txt
|
||||
COPY ./config.yaml ./config.yaml
|
||||
COPY data data
|
||||
|
||||
# Install dependencies differing from base image.
|
||||
RUN python3 -m pip install -r requirements.txt
|
||||
|
||||
RUN python3 -m pip install -e .
|
||||
|
||||
WORKDIR /app/service/incl/redai_image
|
||||
WORKDIR /app/service/incl/redai_image/redai
|
||||
RUN python3 -m pip install -e .
|
||||
WORKDIR /app/service
|
||||
|
||||
|
||||
@ -10,11 +10,13 @@ RUN python -m pip install --upgrade pip
|
||||
# Make a directory for the service files and copy the service repo into the container.
|
||||
WORKDIR /app/service
|
||||
COPY ./requirements.txt ./requirements.txt
|
||||
COPY ./data ./data
|
||||
COPY ./incl/redai_image/redai/requirements_user.txt ./requirements_redai.txt
|
||||
|
||||
# Install dependencies.
|
||||
RUN python3 -m pip install -r requirements.txt
|
||||
|
||||
RUN python3 -m pip install -r requirements_redai.txt
|
||||
|
||||
# Make a new container and copy all relevant files over to filter out temporary files
|
||||
# produced during setup to reduce the final container's size.
|
||||
FROM python:3.8
|
||||
|
||||
26
config.yaml
26
config.yaml
@ -1,10 +1,3 @@
|
||||
estimator:
|
||||
checkpoint: checkpoint.pth
|
||||
classes: ["logo", "other", "formula", "signature", "handwriting_other"]
|
||||
rejection_class: "other"
|
||||
threshold: .5
|
||||
device: cpu
|
||||
|
||||
webserver:
|
||||
host: $SERVER_HOST|"127.0.0.1" # webserver address
|
||||
port: $SERVER_PORT|5000 # webserver port
|
||||
@ -14,3 +7,22 @@ service:
|
||||
logging_level: $LOGGING_LEVEL_ROOT|DEBUG # Logging level for service logger
|
||||
batch_size: $BATCH_SIZE|2 # Number of images in memory simultaneously
|
||||
verbose: $VERBOSE|True # Service prints document processing progress to stdout
|
||||
run_id: $RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the model from
|
||||
|
||||
|
||||
# These variables control filters that are applied to either images, image metadata or model predictions. The filter
|
||||
# result values are reported in the service responses. For convenience, the response to a request contains a
|
||||
# "filters.allPassed" field, which is set to false if any value returned by the filters did not meet its specified
|
||||
# required value.
|
||||
filters:
|
||||
|
||||
image_to_page_quotient: # Image size to page size ratio (ratio of the square roots of the areas)
|
||||
min: $MIN_REL_IMAGE_SIZE|0.05 # Minimum permissible
|
||||
max: $MAX_REL_IMAGE_SIZE|0.75 # Maximum permissible
|
||||
|
||||
image_width_to_height_quotient: # Image width to height ratio
|
||||
min: $MIN_IMAGE_FORMAT|0.1 # Minimum permissible
|
||||
max: $MAX_IMAGE_FORMAT|10 # Maximum permissible
|
||||
|
||||
min_confidence: $MIN_CONFIDENCE|0.5 # Minimum permissible prediction confidence
|
||||
|
||||
|
||||
4
data/base_weights.h5.dvc
Normal file
4
data/base_weights.h5.dvc
Normal file
@ -0,0 +1,4 @@
|
||||
outs:
|
||||
- md5: 6d0186c1f25e889d531788f168fa6cf0
|
||||
size: 16727296
|
||||
path: base_weights.h5
|
||||
5
data/mlruns.dvc
Normal file
5
data/mlruns.dvc
Normal file
@ -0,0 +1,5 @@
|
||||
outs:
|
||||
- md5: d1c708270bab6fcd344d4a8b05d1103d.dir
|
||||
size: 150225383
|
||||
nfiles: 178
|
||||
path: mlruns
|
||||
@ -1,7 +1,14 @@
|
||||
from pathlib import Path
|
||||
from os import path
|
||||
|
||||
MODULE_DIR = path.dirname(path.abspath(__file__))
|
||||
PACKAGE_ROOT_DIR = path.dirname(MODULE_DIR)
|
||||
REPO_ROOT_DIR = path.dirname(path.dirname(PACKAGE_ROOT_DIR))
|
||||
|
||||
MODULE_ROOT = Path(__file__).resolve().parents[1]
|
||||
CONFIG_FILE = MODULE_ROOT / "config.yaml"
|
||||
DATA_DIR = MODULE_ROOT / "data"
|
||||
TORCH_HOME = DATA_DIR
|
||||
DOCKER_COMPOSE_FILE = path.join(REPO_ROOT_DIR, "docker-compose.yaml")
|
||||
|
||||
CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml")
|
||||
LOG_FILE = "/tmp/log.log"
|
||||
|
||||
DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data")
|
||||
MLRUNS_DIR = path.join(DATA_DIR, "mlruns")
|
||||
BASE_WEIGHTS = path.join(DATA_DIR, "base_weights.h5")
|
||||
|
||||
@ -1,7 +1,13 @@
|
||||
import logging
|
||||
from operator import itemgetter
|
||||
from typing import List, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
|
||||
from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
|
||||
from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
|
||||
|
||||
|
||||
class Predictor:
|
||||
@ -21,9 +27,7 @@ class Predictor:
|
||||
reader = MlflowModelReader(
|
||||
run_id=CONFIG.service.run_id, mlruns_dir=MLRUNS_DIR
|
||||
)
|
||||
# message_queue.put(text="Loading model...", level=logging.DEBUG)
|
||||
self.model_handle = reader.get_model_handle(BASE_WEIGHTS)
|
||||
# message_queue.put(text="Model loaded.", level=logging.DEBUG)
|
||||
else:
|
||||
self.model_handle = model_handle
|
||||
|
||||
@ -31,12 +35,7 @@ class Predictor:
|
||||
self.classes_readable = np.array(self.model_handle.classes)
|
||||
self.classes_readable_aligned = self.classes_readable[self.classes[list(range(len(self.classes)))]]
|
||||
except Exception as e:
|
||||
message_queue.put(
|
||||
text="Service estimator initialization failed.",
|
||||
exception=e,
|
||||
level=logging.CRITICAL,
|
||||
trace=traceback.format_exc(),
|
||||
)
|
||||
logging.info(f"Service estimator initialization failed: {e}")
|
||||
|
||||
def __make_predictions_human_readable(self, probs: np.ndarray) -> List[Dict[str, float]]:
|
||||
"""Translates an n x m matrix of probabilities over classes into an n-element list of mappings from classes to
|
||||
|
||||
71
image_prediction/response.py
Normal file
71
image_prediction/response.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""Defines functions for constructing service responses."""
|
||||
|
||||
|
||||
from itertools import starmap
|
||||
from operator import itemgetter
|
||||
|
||||
import numpy as np
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
|
||||
|
||||
def build_response(predictions: list, metadata: list) -> list:
    """Build the service response: one image-info record per (prediction, metadata) pair.

    The two lists are paired positionally with ``zip``, so surplus entries of the longer list are ignored.
    """
    return [build_image_info(prediction, image_metadata) for prediction, image_metadata in zip(predictions, metadata)]
|
||||
|
||||
|
||||
def build_image_info(prediction: dict, metadata: dict) -> dict:
    """Assemble the response record for a single image from its prediction and its metadata.

    Args:
        prediction: Mapping with a "class" entry and a "probabilities" mapping from class name to probability.
            Other entries are passed through unchanged. The mapping itself is left unmodified, so the same
            prediction can safely be passed in more than once.
        metadata: Mapping providing "page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height" and
            "page_idx" (presumably zero-based, since it is reported as ``page_idx + 1``).

    Returns:
        A dict with the entries "classification" (the prediction with "class" renamed to "label" and the
        probabilities rounded to 6 decimals), "position", "geometry" and "filters". "filters.allPassed" is
        True only if none of the size, format and confidence limits from ``CONFIG.filters`` is breached.

    Raises:
        KeyError: If ``prediction`` or ``metadata`` lacks one of the entries named above.
    """
    page_width, page_height, x1, x2, y1, y2, width, height = itemgetter(
        "page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height"
    )(metadata)

    # Relative image size: square root of the image area over square root of the page area.
    page_area_sqrt = np.sqrt(abs(page_width * page_height))
    image_area_sqrt = np.sqrt(abs(x2 - x1) * abs(y2 - y1))
    image_to_page_quotient = image_area_sqrt / page_area_sqrt

    # Computed once and reused for both limit checks and the reported value.
    width_to_height_quotient = width / height

    size_limits = CONFIG.filters.image_to_page_quotient
    format_limits = CONFIG.filters.image_width_to_height_quotient

    too_small = bool(image_to_page_quotient < size_limits.min)
    too_large = bool(image_to_page_quotient > size_limits.max)
    too_tall = bool(width_to_height_quotient < format_limits.min)
    too_wide = bool(width_to_height_quotient > format_limits.max)
    unconfident = bool(max(prediction["probabilities"].values()) < CONFIG.filters.min_confidence)

    # Work on a shallow copy so the caller's dict keeps its "class" entry.
    classification = dict(prediction)
    classification["label"] = classification.pop("class")  # "class" as field name causes problems for the Java ObjectMapper
    # np.round keeps NumPy scalar types (e.g. float32), which the JSON encoder cannot serialize; convert to float.
    classification["probabilities"] = {
        klass: float(np.round(prob, 6)) for klass, prob in prediction["probabilities"].items()
    }

    breaches = [too_large, too_small, too_tall, too_wide, unconfident]

    return {
        "classification": classification,
        "position": {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": metadata["page_idx"] + 1},
        "geometry": {"width": width, "height": height},
        "filters": {
            "geometry": {
                "imageSize": {
                    "quotient": image_to_page_quotient,
                    "tooLarge": too_large,
                    "tooSmall": too_small,
                },
                "imageFormat": {
                    "quotient": width_to_height_quotient,
                    "tooTall": too_tall,
                    "tooWide": too_wide,
                },
            },
            "probability": {"unconfident": unconfident},
            "allPassed": not any(breaches),
        },
    }
|
||||
@ -1,32 +0,0 @@
|
||||
import os
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.locations import DATA_DIR, TORCH_HOME
|
||||
from image_prediction.predictor import Predictor
|
||||
|
||||
|
||||
def suppress_userwarnings():
    """Install a blanket "ignore" warnings filter.

    Despite the name, no category is given, so the filter applies to every warning category, not only UserWarning.
    """
    from warnings import filterwarnings

    filterwarnings("ignore")
|
||||
|
||||
|
||||
def load_classes():
    """Return a dict mapping 1-based ids to the class names configured under ``CONFIG.estimator.classes``."""
    class_names = CONFIG.estimator.classes
    return dict(enumerate(class_names, start=1))
|
||||
|
||||
|
||||
def get_checkpoint():
    """Return the location of the configured checkpoint, i.e. ``CONFIG.estimator.checkpoint`` below ``DATA_DIR``."""
    checkpoint_name = CONFIG.estimator.checkpoint
    return DATA_DIR / checkpoint_name
|
||||
|
||||
|
||||
def set_torch_env():
    """Set the ``TORCH_HOME`` environment variable of this process to the string form of ``TORCH_HOME``."""
    torch_home = str(TORCH_HOME)
    os.environ["TORCH_HOME"] = torch_home
|
||||
|
||||
|
||||
def initialize_predictor(resume):
    """Create the service's ``Predictor``.

    Args:
        resume: Checkpoint to load. If falsy, the checkpoint returned by ``get_checkpoint()`` is used instead.

    Returns:
        A ``Predictor`` built from the chosen checkpoint, the configured classes and the configured rejection class.
    """
    set_torch_env()
    checkpoint = resume or get_checkpoint()
    return Predictor(checkpoint, classes=load_classes(), rejection_class=CONFIG.estimator.rejection_class)
|
||||
1
incl/redai_image
Submodule
1
incl/redai_image
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 4c3b26d7673457aaa99e0663dad6950cd36da967
|
||||
@ -5,9 +5,9 @@ python3 -m venv build_venv
|
||||
source build_venv/bin/activate
|
||||
python3 -m pip install --upgrade pip
|
||||
|
||||
#pip install dvc
|
||||
#pip install 'dvc[ssh]'
|
||||
#dvc pull
|
||||
pip install dvc
|
||||
pip install 'dvc[ssh]'
|
||||
dvc pull
|
||||
|
||||
git submodule update --init --recursive
|
||||
|
||||
|
||||
97
src/serve.py
97
src/serve.py
@ -1,17 +1,29 @@
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from typing import Callable
|
||||
import tempfile
|
||||
from itertools import chain
|
||||
from operator import itemgetter
|
||||
from typing import Iterable
|
||||
|
||||
from flask import Flask, request, jsonify
|
||||
from waitress import serve
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.utils.estimator import suppress_userwarnings, initialize_predictor
|
||||
from image_prediction.predictor import Predictor
|
||||
from image_prediction.response import build_response
|
||||
from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
|
||||
from incl.redai_image.redai.redai.utils.shared import chunk_iterable
|
||||
|
||||
|
||||
def suppress_userwarnings():
    """Install a blanket "ignore" warnings filter.

    Despite the name, no category is given, so the filter applies to every warning category, not only UserWarning.
    """
    from warnings import filterwarnings

    filterwarnings("ignore")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--resume")
|
||||
parser.add_argument("--warnings", action="store_true", default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -22,16 +34,9 @@ def main(args):
|
||||
if not args.warnings:
|
||||
suppress_userwarnings()
|
||||
|
||||
predictor = initialize_predictor(args.resume)
|
||||
predictor = Predictor()
|
||||
logging.info("Predictor ready.")
|
||||
|
||||
prediction_server = make_prediction_server(predictor.predict_pdf)
|
||||
|
||||
run_prediction_server(prediction_server, mode=CONFIG.webserver.mode)
|
||||
|
||||
|
||||
def make_prediction_server(predict_fn: Callable):
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/ready", methods=["GET"])
|
||||
@ -48,46 +53,66 @@ def make_prediction_server(predict_fn: Callable):
|
||||
|
||||
@app.route("/", methods=["POST"])
|
||||
def predict():
|
||||
def __predict():
|
||||
|
||||
def inner():
|
||||
|
||||
pdf = request.data
|
||||
|
||||
logging.debug("Running predictor on document...")
|
||||
predictions = predict_fn(pdf)
|
||||
logging.debug(f"Found {len(predictions)} images in document.")
|
||||
response = jsonify(list(predictions))
|
||||
pdf = request.data
|
||||
|
||||
logging.debug("Running predictor on document...")
|
||||
# extract images from pdfs
|
||||
with tempfile.NamedTemporaryFile() as tmp_file:
|
||||
tmp_file.write(pdf)
|
||||
image_metadata_pairs = extract_image_metadata_pairs(tmp_file.name)
|
||||
try:
|
||||
predictions, metadata = classify_images(predictor, image_metadata_pairs)
|
||||
except Exception as err:
|
||||
logging.warning("Analysis failed.")
|
||||
logging.exception(err)
|
||||
response = jsonify("Analysis failed.")
|
||||
response.status_code = 500
|
||||
return response
|
||||
logging.debug(f"Found images in document.")
|
||||
|
||||
logging.info(f"Analyzing...")
|
||||
result = inner()
|
||||
logging.info("Analysis completed.")
|
||||
return result
|
||||
response = jsonify(build_response(list(predictions), list(metadata)))
|
||||
|
||||
try:
|
||||
return __predict()
|
||||
except Exception as err:
|
||||
logging.warning("Analysis failed.")
|
||||
logging.exception(err)
|
||||
response = jsonify("Analysis failed.")
|
||||
response.status_code = 500
|
||||
return response
|
||||
logging.info("Analysis completed.")
|
||||
return response
|
||||
|
||||
return app
|
||||
run_prediction_server(app, mode=CONFIG.webserver.mode)
|
||||
|
||||
|
||||
def run_prediction_server(app, mode="development"):
    """Serve ``app`` on the host and port configured under ``CONFIG.webserver``.

    Args:
        app: The application to serve.
        mode: "development" runs ``app.run`` with debug enabled; "production" hands the app to ``serve``.
            Any other value returns without starting a server.
    """
    if mode not in ("development", "production"):
        return

    host, port = CONFIG.webserver.host, CONFIG.webserver.port
    if mode == "production":
        serve(app, host=host, port=port)
    else:
        app.run(host=host, port=port, debug=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def extract_image_metadata_pairs(pdf_path: str, **kwargs):
    """Yield everything ``extract_and_stitch`` produces for the PDF at ``pdf_path``.

    The extraction is requested with ``convert_to_rgb=True`` and a size predicate as ``filter_fn`` (presumably used
    by ``extract_and_stitch`` to drop tiny images; confirm against its implementation). Extra keyword arguments
    are forwarded unchanged.
    """

    def image_is_large_enough(metadata: dict):
        # The bounding box must span more than 2 units along both axes.
        box_width = abs(metadata["x1"] - metadata["x2"])
        box_height = abs(metadata["y1"] - metadata["y2"])
        return box_width > 2 and box_height > 2

    extraction_options = dict(convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
    yield from extract_and_stitch(pdf_path, **extraction_options)
|
||||
|
||||
|
||||
def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
    """Classify images in batches and return the predictions alongside the matching metadata.

    Args:
        predictor: Object whose ``predict(images, probabilities=True)`` returns one prediction per image.
        image_metadata_pairs: Iterable of ``(image, metadata)`` pairs.
        batch_size: Number of pairs handed to the predictor per call. The default is read from
            ``CONFIG.service.batch_size`` once, when this function is defined.

    Returns:
        A ``(predictions, metadata)`` pair of iterables in input order. If there are no images at all,
        ``([], [])`` is returned.

    Raises:
        Whatever ``predictor.predict`` raises. Unlike a blanket ``except ValueError``, the explicit empty check
        below does not turn a ``ValueError`` from the predictor or from malformed pairs into an empty result.
    """
    results_per_chunk = []
    for chunk in chunk_iterable(image_metadata_pairs, n=batch_size):
        pairs = list(chunk)
        if not pairs:
            # Nothing to classify in this chunk; skip it instead of failing on the unpacking below.
            continue
        images, metadata = zip(*pairs)
        predictions = predictor.predict(images, probabilities=True)
        results_per_chunk.append((predictions, metadata))

    if not results_per_chunk:
        return [], []

    predictions_per_chunk, metadata_per_chunk = zip(*results_per_chunk)
    return chain.from_iterable(predictions_per_chunk), chain.from_iterable(metadata_per_chunk)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging_level = CONFIG.service.logging_level
|
||||
logging.basicConfig(level=logging_level)
|
||||
logging.getLogger("flask").setLevel(logging.ERROR)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user