Pull request #3: Refactoring
Merge in RR/image-prediction from refactoring to master
Squashed commit of the following:
commit fc4e2efac113f2e307fdbc091e0a4f4e3e5729d3
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 14:21:05 2022 +0100
applied black
commit 3baabf5bc0b04347af85dafbb056f134258d9715
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 14:20:30 2022 +0100
added banner
commit 30e871cfdc79d0ff2e0c26d1b858e55ab1b0453f
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 14:02:26 2022 +0100
rename logger
commit d76fefd3ff0c4425defca4db218ce4a84c6053f3
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 14:00:39 2022 +0100
logger refactoring
commit 0e004cbd21ab00b8804901952405fa870bf48e9c
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 14:00:08 2022 +0100
logger refactoring
commit 49e113f8d85d7973b73f664779906a1347d1522d
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 13:25:08 2022 +0100
refactoring
commit 7ec3d52e155cb83bed8804d2fee4f5bdf54fb59b
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 13:21:52 2022 +0100
applied black
commit 06ea0be8aa9344e11b9d92fd526f2b73061bc736
Author: Matthias Bisping <matthias.bisping@iqser.com>
Date: Wed Mar 16 13:21:20 2022 +0100
refactoring
This commit is contained in:
parent
4d95b84f2f
commit
a9d60654f5
@ -5,6 +5,7 @@ webserver:
|
|||||||
|
|
||||||
service:
|
service:
|
||||||
logging_level: $LOGGING_LEVEL_ROOT|DEBUG # Logging level for service logger
|
logging_level: $LOGGING_LEVEL_ROOT|DEBUG # Logging level for service logger
|
||||||
|
progressbar: True # Whether a progress bar over the pages of a document is displayed while processing
|
||||||
batch_size: $BATCH_SIZE|32 # Number of images in memory simultaneously
|
batch_size: $BATCH_SIZE|32 # Number of images in memory simultaneously
|
||||||
verbose: $VERBOSE|True # Service prints document processing progress to stdout
|
verbose: $VERBOSE|True # Service prints document processing progress to stdout
|
||||||
run_id: $RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the model from
|
run_id: $RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the model from
|
||||||
@ -25,4 +26,3 @@ filters:
|
|||||||
max: $MAX_IMAGE_FORMAT|10 # Maximum permissible
|
max: $MAX_IMAGE_FORMAT|10 # Maximum permissible
|
||||||
|
|
||||||
min_confidence: $MIN_CONFIDENCE|0.5 # Minimum permissible prediction confidence
|
min_confidence: $MIN_CONFIDENCE|0.5 # Minimum permissible prediction confidence
|
||||||
|
|
||||||
|
|||||||
43
image_prediction/flask.py
Normal file
43
image_prediction/flask.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from flask import Flask, request, jsonify
|
||||||
|
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def make_prediction_server(predict_fn: Callable):
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
@app.route("/ready", methods=["GET"])
|
||||||
|
def ready():
|
||||||
|
resp = jsonify("OK")
|
||||||
|
resp.status_code = 200
|
||||||
|
return resp
|
||||||
|
|
||||||
|
@app.route("/health", methods=["GET"])
|
||||||
|
def healthy():
|
||||||
|
resp = jsonify("OK")
|
||||||
|
resp.status_code = 200
|
||||||
|
return resp
|
||||||
|
|
||||||
|
@app.route("/", methods=["POST"])
|
||||||
|
def predict():
|
||||||
|
pdf = request.data
|
||||||
|
|
||||||
|
logger.debug("Running predictor on document...")
|
||||||
|
try:
|
||||||
|
predictions = predict_fn(pdf)
|
||||||
|
response = jsonify(predictions)
|
||||||
|
logger.info("Analysis completed.")
|
||||||
|
return response
|
||||||
|
except Exception as err:
|
||||||
|
logger.error("Analysis failed.")
|
||||||
|
logger.exception(err)
|
||||||
|
response = jsonify("Analysis failed.")
|
||||||
|
response.status_code = 500
|
||||||
|
return response
|
||||||
|
|
||||||
|
return app
|
||||||
@ -2,12 +2,8 @@ from os import path
|
|||||||
|
|
||||||
MODULE_DIR = path.dirname(path.abspath(__file__))
|
MODULE_DIR = path.dirname(path.abspath(__file__))
|
||||||
PACKAGE_ROOT_DIR = path.dirname(MODULE_DIR)
|
PACKAGE_ROOT_DIR = path.dirname(MODULE_DIR)
|
||||||
REPO_ROOT_DIR = path.dirname(path.dirname(PACKAGE_ROOT_DIR))
|
|
||||||
|
|
||||||
DOCKER_COMPOSE_FILE = path.join(REPO_ROOT_DIR, "docker-compose.yaml")
|
|
||||||
|
|
||||||
CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml")
|
CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml")
|
||||||
LOG_FILE = "/tmp/log.log"
|
|
||||||
|
|
||||||
DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data")
|
DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data")
|
||||||
MLRUNS_DIR = path.join(DATA_DIR, "mlruns")
|
MLRUNS_DIR = path.join(DATA_DIR, "mlruns")
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import List, Dict, Iterable
|
from typing import List, Dict, Iterable
|
||||||
@ -7,11 +6,14 @@ import numpy as np
|
|||||||
|
|
||||||
from image_prediction.config import CONFIG
|
from image_prediction.config import CONFIG
|
||||||
from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
|
from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
|
||||||
|
from image_prediction.utils import temporary_pdf_file, get_logger
|
||||||
from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
|
from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
|
||||||
from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
|
from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
|
||||||
from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
|
from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
|
||||||
from incl.redai_image.redai.redai.utils.shared import chunk_iterable
|
from incl.redai_image.redai.redai.utils.shared import chunk_iterable
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class Predictor:
|
class Predictor:
|
||||||
"""`ModelHandle` wrapper. Forwards to wrapped model handle for prediction and produces structured output that is
|
"""`ModelHandle` wrapper. Forwards to wrapped model handle for prediction and produces structured output that is
|
||||||
@ -36,7 +38,7 @@ class Predictor:
|
|||||||
self.classes_readable = np.array(self.model_handle.classes)
|
self.classes_readable = np.array(self.model_handle.classes)
|
||||||
self.classes_readable_aligned = self.classes_readable[self.classes[list(range(len(self.classes)))]]
|
self.classes_readable_aligned = self.classes_readable[self.classes[list(range(len(self.classes)))]]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.info(f"Service estimator initialization failed: {e}")
|
logger.info(f"Service estimator initialization failed: {e}")
|
||||||
|
|
||||||
def __make_predictions_human_readable(self, probs: np.ndarray) -> List[Dict[str, float]]:
|
def __make_predictions_human_readable(self, probs: np.ndarray) -> List[Dict[str, float]]:
|
||||||
"""Translates an n x m matrix of probabilities over classes into an n-element list of mappings from classes to
|
"""Translates an n x m matrix of probabilities over classes into an n-element list of mappings from classes to
|
||||||
@ -88,29 +90,33 @@ class Predictor:
|
|||||||
|
|
||||||
return predictions if probabilities else classes
|
return predictions if probabilities else classes
|
||||||
|
|
||||||
|
def predict_pdf(self, pdf, verbose=False):
|
||||||
|
with temporary_pdf_file(pdf) as pdf_path:
|
||||||
|
image_metadata_pairs = self.__extract_image_metadata_pairs(pdf_path, verbose=verbose)
|
||||||
|
return self.__predict_images(image_metadata_pairs)
|
||||||
|
|
||||||
def extract_image_metadata_pairs(pdf_path: str, **kwargs):
|
def __predict_images(self, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
|
||||||
def image_is_large_enough(metadata: dict):
|
def process_chunk(chunk):
|
||||||
x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
|
images, metadata = zip(*chunk)
|
||||||
|
predictions = self.predict(images, probabilities=True)
|
||||||
|
return predictions, metadata
|
||||||
|
|
||||||
return abs(x1 - x2) > 2 and abs(y1 - y2) > 2
|
def predict(image_metadata_pair_generator):
|
||||||
|
chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size)
|
||||||
|
return map(chain.from_iterable, zip(*map(process_chunk, chunks)))
|
||||||
|
|
||||||
yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
|
try:
|
||||||
|
predictions, metadata = predict(image_metadata_pairs)
|
||||||
|
return predictions, metadata
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
return [], []
|
||||||
|
|
||||||
def classify_images(predictor, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
|
@staticmethod
|
||||||
def process_chunk(chunk):
|
def __extract_image_metadata_pairs(pdf_path: str, **kwargs):
|
||||||
images, metadata = zip(*chunk)
|
def image_is_large_enough(metadata: dict):
|
||||||
predictions = predictor.predict(images, probabilities=True)
|
x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
|
||||||
return predictions, metadata
|
|
||||||
|
|
||||||
def predict(image_metadata_pair_generator):
|
return abs(x1 - x2) > 2 and abs(y1 - y2) > 2
|
||||||
chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size)
|
|
||||||
return map(chain.from_iterable, zip(*map(process_chunk, chunks)))
|
|
||||||
|
|
||||||
try:
|
yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
|
||||||
predictions, metadata = predict(image_metadata_pairs)
|
|
||||||
return predictions, metadata
|
|
||||||
|
|
||||||
except ValueError:
|
|
||||||
return [], []
|
|
||||||
|
|||||||
@ -1,11 +1,10 @@
|
|||||||
"""Defines functions for constructing service responses."""
|
"""Defines functions for constructing service responses."""
|
||||||
|
|
||||||
|
|
||||||
|
import math
|
||||||
from itertools import starmap
|
from itertools import starmap
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from image_prediction.config import CONFIG
|
from image_prediction.config import CONFIG
|
||||||
|
|
||||||
|
|
||||||
@ -15,8 +14,8 @@ def build_response(predictions: list, metadata: list) -> list:
|
|||||||
|
|
||||||
def build_image_info(prediction: dict, metadata: dict) -> dict:
|
def build_image_info(prediction: dict, metadata: dict) -> dict:
|
||||||
def compute_geometric_quotient():
|
def compute_geometric_quotient():
|
||||||
page_area_sqrt = np.sqrt(abs(page_width * page_height))
|
page_area_sqrt = math.sqrt(abs(page_width * page_height))
|
||||||
image_area_sqrt = np.sqrt(abs(x2 - x1) * abs(y2 - y1))
|
image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1))
|
||||||
return image_area_sqrt / page_area_sqrt
|
return image_area_sqrt / page_area_sqrt
|
||||||
|
|
||||||
page_width, page_height, x1, x2, y1, y2, width, height = itemgetter(
|
page_width, page_height, x1, x2, y1, y2, width, height = itemgetter(
|
||||||
@ -36,7 +35,7 @@ def build_image_info(prediction: dict, metadata: dict) -> dict:
|
|||||||
|
|
||||||
min_confidence_breached = bool(max(prediction["probabilities"].values()) < CONFIG.filters.min_confidence)
|
min_confidence_breached = bool(max(prediction["probabilities"].values()) < CONFIG.filters.min_confidence)
|
||||||
prediction["label"] = prediction.pop("class") # "class" as field name causes problem for Java objectmapper
|
prediction["label"] = prediction.pop("class") # "class" as field name causes problem for Java objectmapper
|
||||||
prediction["probabilities"] = {klass: np.round(prob, 6) for klass, prob in prediction["probabilities"].items()}
|
prediction["probabilities"] = {klass: round(prob, 6) for klass, prob in prediction["probabilities"].items()}
|
||||||
|
|
||||||
image_info = {
|
image_info = {
|
||||||
"classification": prediction,
|
"classification": prediction,
|
||||||
|
|||||||
68
image_prediction/utils.py
Normal file
68
image_prediction/utils.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from image_prediction.config import CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def temporary_pdf_file(pdf: bytes):
|
||||||
|
with tempfile.NamedTemporaryFile() as f:
|
||||||
|
f.write(pdf)
|
||||||
|
yield f.name
|
||||||
|
|
||||||
|
|
||||||
|
def make_logger_getter():
|
||||||
|
|
||||||
|
logger = logging.getLogger("imclf")
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
handler.setLevel(CONFIG.service.logging_level)
|
||||||
|
|
||||||
|
log_format = "[%(levelname)s]: %(message)s"
|
||||||
|
formatter = logging.Formatter(log_format)
|
||||||
|
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
def get_logger():
|
||||||
|
return logger
|
||||||
|
|
||||||
|
return get_logger
|
||||||
|
|
||||||
|
|
||||||
|
get_logger = make_logger_getter()
|
||||||
|
|
||||||
|
|
||||||
|
def show_banner():
|
||||||
|
banner = '''
|
||||||
|
..... . ... ..
|
||||||
|
.d88888Neu. 'L xH88"`~ .x8X x .d88" oec :
|
||||||
|
F""""*8888888F .. . : :8888 .f"8888Hf 5888R @88888
|
||||||
|
* `"*88*" .888: x888 x888. :8888> X8L ^""` '888R 8"*88%
|
||||||
|
-.... ue=:. ~`8888~'888X`?888f` X8888 X888h 888R 8b.
|
||||||
|
:88N ` X888 888X '888> 88888 !88888. 888R u888888>
|
||||||
|
9888L X888 888X '888> 88888 %88888 888R 8888R
|
||||||
|
uzu. `8888L X888 888X '888> 88888 '> `8888> 888R 8888P
|
||||||
|
,""888i ?8888 X888 888X '888> `8888L % ?888 ! 888R *888>
|
||||||
|
4 9888L %888> "*88%""*88" '888!` `8888 `-*"" / .888B . 4888
|
||||||
|
' '8888 '88% `~ " `"` "888. :" ^*888% '888
|
||||||
|
"*8Nu.z*" `""***~"` "% 88R
|
||||||
|
88>
|
||||||
|
48
|
||||||
|
'8
|
||||||
|
'''
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
handler.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
formatter = logging.Formatter("")
|
||||||
|
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
logger.info(banner)
|
||||||
60
src/serve.py
60
src/serve.py
@ -1,57 +1,29 @@
|
|||||||
import logging
|
import logging
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from flask import Flask, request, jsonify
|
|
||||||
from waitress import serve
|
from waitress import serve
|
||||||
|
|
||||||
from image_prediction.config import CONFIG
|
from image_prediction.config import CONFIG
|
||||||
from image_prediction.predictor import Predictor, extract_image_metadata_pairs, classify_images
|
from image_prediction.flask import make_prediction_server
|
||||||
|
from image_prediction.predictor import Predictor
|
||||||
from image_prediction.response import build_response
|
from image_prediction.response import build_response
|
||||||
|
from image_prediction.utils import get_logger, show_banner
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
||||||
predictor = Predictor()
|
def predict(pdf):
|
||||||
logging.info("Predictor ready.")
|
predictions, metadata = predictor.predict_pdf(pdf, verbose=CONFIG.service.progressbar)
|
||||||
|
response = build_response(predictions, metadata)
|
||||||
app = Flask(__name__)
|
|
||||||
|
|
||||||
@app.route("/ready", methods=["GET"])
|
|
||||||
def ready():
|
|
||||||
resp = jsonify("OK")
|
|
||||||
resp.status_code = 200
|
|
||||||
return resp
|
|
||||||
|
|
||||||
@app.route("/health", methods=["GET"])
|
|
||||||
def healthy():
|
|
||||||
resp = jsonify("OK")
|
|
||||||
resp.status_code = 200
|
|
||||||
return resp
|
|
||||||
|
|
||||||
@app.route("/", methods=["POST"])
|
|
||||||
def predict():
|
|
||||||
pdf = request.data
|
|
||||||
|
|
||||||
logging.debug("Running predictor on document...")
|
|
||||||
with tempfile.NamedTemporaryFile() as tmp_file:
|
|
||||||
tmp_file.write(pdf)
|
|
||||||
image_metadata_pairs = extract_image_metadata_pairs(tmp_file.name)
|
|
||||||
try:
|
|
||||||
predictions, metadata = classify_images(predictor, image_metadata_pairs)
|
|
||||||
except Exception as err:
|
|
||||||
logging.warning("Analysis failed.")
|
|
||||||
logging.exception(err)
|
|
||||||
response = jsonify("Analysis failed.")
|
|
||||||
response.status_code = 500
|
|
||||||
return response
|
|
||||||
logging.debug(f"Found images in document.")
|
|
||||||
|
|
||||||
response = jsonify(build_response(list(predictions), list(metadata)))
|
|
||||||
|
|
||||||
logging.info("Analysis completed.")
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
run_prediction_server(app, mode=CONFIG.webserver.mode)
|
predictor = Predictor()
|
||||||
|
logger.info("Predictor ready.")
|
||||||
|
|
||||||
|
prediction_server = make_prediction_server(predict)
|
||||||
|
|
||||||
|
run_prediction_server(prediction_server, mode=CONFIG.webserver.mode)
|
||||||
|
|
||||||
|
|
||||||
def run_prediction_server(app, mode="development"):
|
def run_prediction_server(app, mode="development"):
|
||||||
@ -68,5 +40,9 @@ if __name__ == "__main__":
|
|||||||
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||||
logging.getLogger("werkzeug").setLevel(logging.ERROR)
|
logging.getLogger("werkzeug").setLevel(logging.ERROR)
|
||||||
logging.getLogger("waitress").setLevel(logging.ERROR)
|
logging.getLogger("waitress").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("PIL").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("h5py").setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
show_banner()
|
||||||
|
|
||||||
main()
|
main()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user