Pull request #3: Refactoring

Merge in RR/image-prediction from refactoring to master

Squashed commit of the following:

commit fc4e2efac113f2e307fdbc091e0a4f4e3e5729d3
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Wed Mar 16 14:21:05 2022 +0100

    applied black

commit 3baabf5bc0b04347af85dafbb056f134258d9715
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Wed Mar 16 14:20:30 2022 +0100

    added banner

commit 30e871cfdc79d0ff2e0c26d1b858e55ab1b0453f
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Wed Mar 16 14:02:26 2022 +0100

    rename logger

commit d76fefd3ff0c4425defca4db218ce4a84c6053f3
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Wed Mar 16 14:00:39 2022 +0100

    logger refactoring

commit 0e004cbd21ab00b8804901952405fa870bf48e9c
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Wed Mar 16 14:00:08 2022 +0100

    logger refactoring

commit 49e113f8d85d7973b73f664779906a1347d1522d
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Wed Mar 16 13:25:08 2022 +0100

    refactoring

commit 7ec3d52e155cb83bed8804d2fee4f5bdf54fb59b
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Wed Mar 16 13:21:52 2022 +0100

    applied black

commit 06ea0be8aa9344e11b9d92fd526f2b73061bc736
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date:   Wed Mar 16 13:21:20 2022 +0100

    refactoring
This commit is contained in:
Matthias Bisping 2022-03-16 15:07:30 +01:00 committed by Julius Unverfehrt
parent 4d95b84f2f
commit a9d60654f5
7 changed files with 161 additions and 73 deletions

View File

@ -5,6 +5,7 @@ webserver:
service: service:
logging_level: $LOGGING_LEVEL_ROOT|DEBUG # Logging level for service logger logging_level: $LOGGING_LEVEL_ROOT|DEBUG # Logging level for service logger
progressbar: True # Whether a progress bar over the pages of a document is displayed while processing
batch_size: $BATCH_SIZE|32 # Number of images in memory simultaneously batch_size: $BATCH_SIZE|32 # Number of images in memory simultaneously
verbose: $VERBOSE|True # Service prints document processing progress to stdout verbose: $VERBOSE|True # Service prints document processing progress to stdout
run_id: $RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the model from run_id: $RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the model from
@ -25,4 +26,3 @@ filters:
max: $MAX_IMAGE_FORMAT|10 # Maximum permissible max: $MAX_IMAGE_FORMAT|10 # Maximum permissible
min_confidence: $MIN_CONFIDENCE|0.5 # Minimum permissible prediction confidence min_confidence: $MIN_CONFIDENCE|0.5 # Minimum permissible prediction confidence

43
image_prediction/flask.py Normal file
View File

@ -0,0 +1,43 @@
from typing import Callable
from flask import Flask, request, jsonify
from image_prediction.utils import get_logger
logger = get_logger()
def make_prediction_server(predict_fn: Callable):
    """Build a Flask app exposing readiness/health probes and a prediction endpoint.

    Args:
        predict_fn: callable that receives the raw PDF bytes from the request
            body and returns a JSON-serializable prediction structure.

    Returns:
        A configured ``Flask`` application (not yet running).
    """
    app = Flask(__name__)

    @app.route("/ready", methods=["GET"])
    def ready():
        # Readiness probe: OK as soon as the app object exists.
        resp = jsonify("OK")
        resp.status_code = 200
        return resp

    @app.route("/health", methods=["GET"])
    def healthy():
        # Liveness probe, same trivial contract as /ready.
        resp = jsonify("OK")
        resp.status_code = 200
        return resp

    @app.route("/", methods=["POST"])
    def predict():
        pdf = request.data
        logger.debug("Running predictor on document...")
        try:
            predictions = predict_fn(pdf)
            response = jsonify(predictions)
            logger.info("Analysis completed.")
            return response
        except Exception:  # top-level boundary: answer 500 instead of crashing the worker
            # Fix: logger.exception already logs at ERROR level *with* the
            # traceback, so the previous extra logger.error("Analysis failed.")
            # produced a duplicate ERROR record; one call suffices.
            logger.exception("Analysis failed.")
            response = jsonify("Analysis failed.")
            response.status_code = 500
            return response

    return app

View File

@ -2,12 +2,8 @@ from os import path
MODULE_DIR = path.dirname(path.abspath(__file__)) MODULE_DIR = path.dirname(path.abspath(__file__))
PACKAGE_ROOT_DIR = path.dirname(MODULE_DIR) PACKAGE_ROOT_DIR = path.dirname(MODULE_DIR)
REPO_ROOT_DIR = path.dirname(path.dirname(PACKAGE_ROOT_DIR))
DOCKER_COMPOSE_FILE = path.join(REPO_ROOT_DIR, "docker-compose.yaml")
CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml") CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml")
LOG_FILE = "/tmp/log.log"
DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data") DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data")
MLRUNS_DIR = path.join(DATA_DIR, "mlruns") MLRUNS_DIR = path.join(DATA_DIR, "mlruns")

View File

@ -1,4 +1,3 @@
import logging
from itertools import chain from itertools import chain
from operator import itemgetter from operator import itemgetter
from typing import List, Dict, Iterable from typing import List, Dict, Iterable
@ -7,11 +6,14 @@ import numpy as np
from image_prediction.config import CONFIG from image_prediction.config import CONFIG
from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
from image_prediction.utils import temporary_pdf_file, get_logger
from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
from incl.redai_image.redai.redai.utils.shared import chunk_iterable from incl.redai_image.redai.redai.utils.shared import chunk_iterable
logger = get_logger()
class Predictor: class Predictor:
"""`ModelHandle` wrapper. Forwards to wrapped model handle for prediction and produces structured output that is """`ModelHandle` wrapper. Forwards to wrapped model handle for prediction and produces structured output that is
@ -36,7 +38,7 @@ class Predictor:
self.classes_readable = np.array(self.model_handle.classes) self.classes_readable = np.array(self.model_handle.classes)
self.classes_readable_aligned = self.classes_readable[self.classes[list(range(len(self.classes)))]] self.classes_readable_aligned = self.classes_readable[self.classes[list(range(len(self.classes)))]]
except Exception as e: except Exception as e:
logging.info(f"Service estimator initialization failed: {e}") logger.info(f"Service estimator initialization failed: {e}")
def __make_predictions_human_readable(self, probs: np.ndarray) -> List[Dict[str, float]]: def __make_predictions_human_readable(self, probs: np.ndarray) -> List[Dict[str, float]]:
"""Translates an n x m matrix of probabilities over classes into an n-element list of mappings from classes to """Translates an n x m matrix of probabilities over classes into an n-element list of mappings from classes to
@ -88,20 +90,15 @@ class Predictor:
return predictions if probabilities else classes return predictions if probabilities else classes
def predict_pdf(self, pdf, verbose=False):
with temporary_pdf_file(pdf) as pdf_path:
image_metadata_pairs = self.__extract_image_metadata_pairs(pdf_path, verbose=verbose)
return self.__predict_images(image_metadata_pairs)
def extract_image_metadata_pairs(pdf_path: str, **kwargs): def __predict_images(self, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
def image_is_large_enough(metadata: dict):
x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
return abs(x1 - x2) > 2 and abs(y1 - y2) > 2
yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
def process_chunk(chunk): def process_chunk(chunk):
images, metadata = zip(*chunk) images, metadata = zip(*chunk)
predictions = predictor.predict(images, probabilities=True) predictions = self.predict(images, probabilities=True)
return predictions, metadata return predictions, metadata
def predict(image_metadata_pair_generator): def predict(image_metadata_pair_generator):
@ -114,3 +111,12 @@ def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int =
except ValueError: except ValueError:
return [], [] return [], []
@staticmethod
def __extract_image_metadata_pairs(pdf_path: str, **kwargs):
def image_is_large_enough(metadata: dict):
x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
return abs(x1 - x2) > 2 and abs(y1 - y2) > 2
yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)

View File

@ -1,11 +1,10 @@
"""Defines functions for constructing service responses.""" """Defines functions for constructing service responses."""
import math
from itertools import starmap from itertools import starmap
from operator import itemgetter from operator import itemgetter
import numpy as np
from image_prediction.config import CONFIG from image_prediction.config import CONFIG
@ -15,8 +14,8 @@ def build_response(predictions: list, metadata: list) -> list:
def build_image_info(prediction: dict, metadata: dict) -> dict: def build_image_info(prediction: dict, metadata: dict) -> dict:
def compute_geometric_quotient(): def compute_geometric_quotient():
page_area_sqrt = np.sqrt(abs(page_width * page_height)) page_area_sqrt = math.sqrt(abs(page_width * page_height))
image_area_sqrt = np.sqrt(abs(x2 - x1) * abs(y2 - y1)) image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1))
return image_area_sqrt / page_area_sqrt return image_area_sqrt / page_area_sqrt
page_width, page_height, x1, x2, y1, y2, width, height = itemgetter( page_width, page_height, x1, x2, y1, y2, width, height = itemgetter(
@ -36,7 +35,7 @@ def build_image_info(prediction: dict, metadata: dict) -> dict:
min_confidence_breached = bool(max(prediction["probabilities"].values()) < CONFIG.filters.min_confidence) min_confidence_breached = bool(max(prediction["probabilities"].values()) < CONFIG.filters.min_confidence)
prediction["label"] = prediction.pop("class") # "class" as field name causes problem for Java objectmapper prediction["label"] = prediction.pop("class") # "class" as field name causes problem for Java objectmapper
prediction["probabilities"] = {klass: np.round(prob, 6) for klass, prob in prediction["probabilities"].items()} prediction["probabilities"] = {klass: round(prob, 6) for klass, prob in prediction["probabilities"].items()}
image_info = { image_info = {
"classification": prediction, "classification": prediction,

68
image_prediction/utils.py Normal file
View File

@ -0,0 +1,68 @@
import logging
import tempfile
from contextlib import contextmanager
from image_prediction.config import CONFIG
@contextmanager
def temporary_pdf_file(pdf: bytes):
    """Write *pdf* bytes to a named temporary file and yield its filesystem path.

    The file is removed automatically when the context exits.

    Args:
        pdf: raw PDF content to persist for the duration of the context.

    Yields:
        The path of the temporary file containing *pdf*.
    """
    with tempfile.NamedTemporaryFile() as f:
        f.write(pdf)
        # Bug fix: without flush() the bytes can remain in the userspace
        # buffer, so anyone opening f.name by path would read a truncated
        # (possibly empty) file.
        f.flush()
        yield f.name
def make_logger_getter():
    """Configure the shared "imclf" service logger once and return an accessor.

    Returns:
        A zero-argument function that returns the configured
        ``logging.Logger`` instance.
    """
    logger = logging.getLogger("imclf")
    logger.propagate = False
    # Bug fix: the level must be set on the logger itself, not only on the
    # handler.  A logger left at NOTSET inherits the root logger's effective
    # level (WARNING by default), which silently drops DEBUG/INFO records
    # before the handler's own level is ever consulted.
    logger.setLevel(CONFIG.service.logging_level)

    handler = logging.StreamHandler()
    handler.setLevel(CONFIG.service.logging_level)
    log_format = "[%(levelname)s]: %(message)s"
    formatter = logging.Formatter(log_format)
    handler.setFormatter(formatter)

    # Guard against duplicate handlers should this factory ever run twice
    # (each extra handler would re-emit every record).
    if not logger.handlers:
        logger.addHandler(handler)

    def get_logger():
        return logger

    return get_logger


get_logger = make_logger_getter()
def show_banner():
    """Log the ASCII-art service banner at INFO level.

    Uses a dedicated, non-propagating logger with an empty format string so
    the banner text is emitted verbatim, without level/name prefixes.
    """
    banner = '''
 ..... . ... ..
.d88888Neu. 'L xH88"`~ .x8X x .d88" oec :
F""""*8888888F .. . : :8888 .f"8888Hf 5888R @88888
* `"*88*" .888: x888 x888. :8888> X8L ^""` '888R 8"*88%
 -.... ue=:. ~`8888~'888X`?888f` X8888 X888h 888R 8b.
 :88N ` X888 888X '888> 88888 !88888. 888R u888888>
 9888L X888 888X '888> 88888 %88888 888R 8888R
 uzu. `8888L X888 888X '888> 88888 '> `8888> 888R 8888P
,""888i ?8888 X888 888X '888> `8888L % ?888 ! 888R *888>
4 9888L %888> "*88%""*88" '888!` `8888 `-*"" / .888B . 4888
' '8888 '88% `~ " `"` "888. :" ^*888% '888
 "*8Nu.z*" `""***~"` "% 88R
 88>
 48
 '8
'''
    # Separate logger (module __name__) so banner formatting does not leak
    # into the "imclf" service logger configured in make_logger_getter.
    logger = logging.getLogger(__name__)
    logger.propagate = False
    handler = logging.StreamHandler()
    handler.setLevel(logging.INFO)
    # Empty format string: emit the banner text exactly as-is.
    formatter = logging.Formatter("")
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    # NOTE(review): the logger's own level is never set, so it inherits the
    # root logger's effective level (WARNING by default) and this INFO record
    # may be suppressed unless the root level is configured elsewhere —
    # confirm.  Repeated calls also stack handlers, duplicating output.
    logger.info(banner)

View File

@ -1,57 +1,29 @@
import logging import logging
import tempfile
from flask import Flask, request, jsonify
from waitress import serve from waitress import serve
from image_prediction.config import CONFIG from image_prediction.config import CONFIG
from image_prediction.predictor import Predictor, extract_image_metadata_pairs, classify_images from image_prediction.flask import make_prediction_server
from image_prediction.predictor import Predictor
from image_prediction.response import build_response from image_prediction.response import build_response
from image_prediction.utils import get_logger, show_banner
logger = get_logger()
def main(): def main():
def predict(pdf):
predictions, metadata = predictor.predict_pdf(pdf, verbose=CONFIG.service.progressbar)
response = build_response(predictions, metadata)
return response
predictor = Predictor() predictor = Predictor()
logging.info("Predictor ready.") logger.info("Predictor ready.")
app = Flask(__name__) prediction_server = make_prediction_server(predict)
@app.route("/ready", methods=["GET"]) run_prediction_server(prediction_server, mode=CONFIG.webserver.mode)
def ready():
resp = jsonify("OK")
resp.status_code = 200
return resp
@app.route("/health", methods=["GET"])
def healthy():
resp = jsonify("OK")
resp.status_code = 200
return resp
@app.route("/", methods=["POST"])
def predict():
pdf = request.data
logging.debug("Running predictor on document...")
with tempfile.NamedTemporaryFile() as tmp_file:
tmp_file.write(pdf)
image_metadata_pairs = extract_image_metadata_pairs(tmp_file.name)
try:
predictions, metadata = classify_images(predictor, image_metadata_pairs)
except Exception as err:
logging.warning("Analysis failed.")
logging.exception(err)
response = jsonify("Analysis failed.")
response.status_code = 500
return response
logging.debug(f"Found images in document.")
response = jsonify(build_response(list(predictions), list(metadata)))
logging.info("Analysis completed.")
return response
run_prediction_server(app, mode=CONFIG.webserver.mode)
def run_prediction_server(app, mode="development"): def run_prediction_server(app, mode="development"):
@ -68,5 +40,9 @@ if __name__ == "__main__":
logging.getLogger("urllib3").setLevel(logging.ERROR) logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("werkzeug").setLevel(logging.ERROR) logging.getLogger("werkzeug").setLevel(logging.ERROR)
logging.getLogger("waitress").setLevel(logging.ERROR) logging.getLogger("waitress").setLevel(logging.ERROR)
logging.getLogger("PIL").setLevel(logging.ERROR)
logging.getLogger("h5py").setLevel(logging.ERROR)
show_banner()
main() main()