diff --git a/.coveragerc b/.coveragerc index 81a0e9a..77e78ab 100644 --- a/.coveragerc +++ b/.coveragerc @@ -11,6 +11,8 @@ omit = */env/* */build_venv/* */build_env/* + */utils/banner.py + */utils/logger.py source = image_prediction src @@ -44,6 +46,8 @@ omit = */env/* */build_venv/* */build_env/* + */utils/banner.py + */utils/logger.py ignore_errors = True diff --git a/banner.txt b/banner.txt new file mode 100644 index 0000000..b2ae7b3 --- /dev/null +++ b/banner.txt @@ -0,0 +1,15 @@ + ..... . ... .. + .d88888Neu. 'L xH88"`~ .x8X x .d88" oec : + F""""*8888888F .. . : :8888 .f"8888Hf 5888R @88888 + * `"*88*" .888: x888 x888. :8888> X8L ^""` '888R 8"*88% + -.... ue=:. ~`8888~'888X`?888f` X8888 X888h 888R 8b. + :88N ` X888 888X '888> 88888 !88888. 888R u888888> + 9888L X888 888X '888> 88888 %88888 888R 8888R + uzu. `8888L X888 888X '888> 88888 '> `8888> 888R 8888P +,""888i ?8888 X888 888X '888> `8888L % ?888 ! 888R *888> +4 9888L %888> "*88%""*88" '888!` `8888 `-*"" / .888B . 4888 +' '8888 '88% `~ " `"` "888. :" ^*888% '888 + "*8Nu.z*" `""***~"` "% 88R + 88> + 48 + '8 \ No newline at end of file diff --git a/deprecated/predictor.py b/deprecated/predictor.py new file mode 100644 index 0000000..2dad683 --- /dev/null +++ b/deprecated/predictor.py @@ -0,0 +1,122 @@ +from itertools import chain +from operator import itemgetter +from typing import List, Dict, Iterable + +import numpy as np + +from image_prediction.config import CONFIG +from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS +from image_prediction.utils import temporary_pdf_file, get_logger +from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle +from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch +from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader +from incl.redai_image.redai.redai.utils.shared import chunk_iterable + +logger = get_logger() + + +class Predictor: + """`ModelHandle` wrapper. Forwards to wrapped service_estimator handle for prediction and produces structured output that is + interpretable independently of the wrapped service_estimator (e.g. with regard to a .classes_ attribute). + """ + + def __init__(self, model_handle: ModelHandle = None): + """Initializes a ServiceEstimator. + + Args: + model_handle: ModelHandle object to forward to for prediction. By default, a service_estimator handle is loaded from the + mlflow database via CONFIG.service.run_id. + """ + try: + if model_handle is None: + reader = MlflowModelReader(run_id=CONFIG.service.run_id, mlruns_dir=MLRUNS_DIR) + self.model_handle = reader.get_model_handle(BASE_WEIGHTS) + else: + self.model_handle = model_handle + + self.classes = self.model_handle.model.classes_ + self.classes_readable = np.array(self.model_handle.classes) + self.classes_readable_aligned = self.classes_readable[self.classes[list(range(len(self.classes)))]] + except Exception as e: + logger.info(f"Service estimator initialization failed: {e}") + + def __make_predictions_human_readable(self, probs: np.ndarray) -> List[Dict[str, float]]: + """Translates an n x m matrix of probabilities over classes into an n-element list of mappings from classes to + probabilities. + + Args: + probs: probability matrix (items x classes) + + Returns: + list of mappings from classes to probabilities. + """ + classes = np.argmax(probs, axis=1) + classes = self.classes[classes] + classes_readable = [self.model_handle.classes[c] for c in classes] + return classes_readable + + def predict(self, images: List, probabilities: bool = False, **kwargs): + """Gathers predictions for list of images. Assigns each image a class and optionally a probability distribution + over all classes. + + Args: + images (List[PIL.Image]) : Images to gather predictions for. + probabilities: Whether to return dictionaries of the following form instead of strings: + { + "class": predicted class, + "probabilities": { + "class 1" : class 1 probability, + "class 2" : class 2 probability, + ... + } + } + + Returns: + By default the return value is a list of classes (meaningful class name strings). Alternatively a list of + dictionaries with an additional probability field for estimated class probabilities per image can be + returned. + """ + X = self.model_handle.prep_images(list(images)) + + probs_per_item = self.model_handle.model.predict_proba(X, **kwargs).astype(float) + classes = self.__make_predictions_human_readable(probs_per_item) + + class2prob_per_item = [dict(zip(self.classes_readable_aligned, probs)) for probs in probs_per_item] + class2prob_per_item = [ + dict(sorted(c2p.items(), key=itemgetter(1), reverse=True)) for c2p in class2prob_per_item + ] + + predictions = [{"class": c, "probabilities": c2p} for c, c2p in zip(classes, class2prob_per_item)] + + return predictions if probabilities else classes + + def predict_pdf(self, pdf, verbose=False): + with temporary_pdf_file(pdf) as pdf_path: + image_metadata_pairs = self.__extract_image_metadata_pairs(pdf_path, verbose=verbose) + return self.__predict_images(image_metadata_pairs) + + def __predict_images(self, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size): + def process_chunk(chunk): + images, metadata = zip(*chunk) + predictions = self.predict(images, probabilities=True) + return predictions, metadata + + def predict(image_metadata_pair_generator): + chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size) + return map(chain.from_iterable, zip(*map(process_chunk, chunks))) + + try: + predictions, metadata = predict(image_metadata_pairs) + return predictions, metadata + + except ValueError: + return [], [] + + @staticmethod + def __extract_image_metadata_pairs(pdf_path: str, **kwargs): + def image_is_large_enough(metadata: dict): + x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata) + + return abs(x1 - x2) > 2 and abs(y1 - y2) > 2 + + yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs) diff --git a/deprecated/serve.py b/deprecated/serve.py new file mode 100644 index 0000000..989a0da --- /dev/null +++ b/deprecated/serve.py @@ -0,0 +1,49 @@ +import logging + +from waitress import serve + +from image_prediction.config import CONFIG +from image_prediction.flask import make_prediction_server +from image_prediction.predictor import Predictor +from image_prediction.response import build_response +from image_prediction.utils import get_logger, show_banner + +logger = get_logger() + + +def main(): + def predict(pdf): + # Keras service_estimator.predict stalls when service_estimator was loaded in different process + # https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python + predictor = Predictor() + predictions, metadata = predictor.predict_pdf(pdf, verbose=CONFIG.service.progressbar) + response = build_response(predictions, metadata) + return response + + logger.info("Predictor ready.") + + prediction_server = make_prediction_server(predict) + + run_prediction_server(prediction_server, mode=CONFIG.webserver.mode) + + +def run_prediction_server(app, mode="development"): + if mode == "development": + app.run(host=CONFIG.webserver.host, port=CONFIG.webserver.port, debug=True) + elif mode == "production": + serve(app, host=CONFIG.webserver.host, port=CONFIG.webserver.port) + + +if __name__ == "__main__": + logging_level = CONFIG.service.logging_level + logging.basicConfig(level=logging_level) + logging.getLogger("flask").setLevel(logging.ERROR) + logging.getLogger("urllib3").setLevel(logging.ERROR) + logging.getLogger("werkzeug").setLevel(logging.ERROR) + logging.getLogger("waitress").setLevel(logging.ERROR) + logging.getLogger("PIL").setLevel(logging.ERROR) + logging.getLogger("h5py").setLevel(logging.ERROR) + + show_banner() + + main() diff --git a/doc/tests.drawio b/doc/tests.drawio new file mode 100644 index 0000000..c335abc --- /dev/null +++ b/doc/tests.drawio @@ -0,0 +1 @@ +1ZZRT6QwEMc/DY8mQHdRX93z9JLbmNzGmNxbQ0daLQzpDrL46a/IsCzinneJcd0XaP+dtsN/fkADscg3V06WeokKbBCHahOIb0Ecnydzf22FphPmyXknZM6oTooGYWWegcWQ1cooWI8CCdGSKcdiikUBKY006RzW47B7tONdS5nBRFil0k7VO6NId+rZPBz0azCZ7neOQh7JZR/MwlpLhfWOJC4DsXCI1LXyzQJs613vSzfv+57RbWIOCvqXCZqW9PBref27aZ7xsQ5vTn/cnvAqT9JW/MCwJuNzR8dZU9Nb4bAqFLSrhYG4qLUhWJUybUdrX3uvacqt70W+yeuCI9jsTTja2uDxAcyBXONDeILonWN04hn366EQUR+jd4qQsCa59tl26cEe32CH/sOt+TueoCONGRbS/kQs2YkHIGoYbFkRvuUTqAmFr1zyu2LlUvhLdjG/HtJlQO/VfOq6AyvJPI3z+HAL4wlwpbp/2V0qODxzUTJmLjo4c8nEkxaWFXcLLPzt4ithKI4BQzHBMOc/l8UvAeLrj9/hQTw9NhBnxwDibB+IB+ZvdvZ5/PnucAx6Gds5S4rLPw== \ No newline at end of file diff --git a/image_prediction/locations.py b/image_prediction/locations.py index 8d62079..1f14c1a 100644 --- a/image_prediction/locations.py +++ b/image_prediction/locations.py @@ -1,11 +1,17 @@ -from os import path +"""Defines constant paths relative to the module root path.""" -MODULE_DIR = path.dirname(path.abspath(__file__)) -PACKAGE_ROOT_DIR = path.dirname(MODULE_DIR) +from pathlib import Path -CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml") +MODULE_DIR = Path(__file__).resolve().parents[0] -DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data") -MLRUNS_DIR = path.join(DATA_DIR, "mlruns") +PACKAGE_ROOT_DIR = MODULE_DIR.parents[0] -TEST_DATA_DIR = path.join(PACKAGE_ROOT_DIR, "test", "data") +CONFIG_FILE = PACKAGE_ROOT_DIR / "config.yaml" + +BANNER_FILE = PACKAGE_ROOT_DIR / "banner.txt" + +DATA_DIR = PACKAGE_ROOT_DIR / "data" + +MLRUNS_DIR = str(DATA_DIR / "mlruns") + +TEST_DATA_DIR = PACKAGE_ROOT_DIR / "test" / "data" diff --git a/image_prediction/utils.py b/image_prediction/utils.py index d138381..b28b04f 100644 --- a/image_prediction/utils.py +++ b/image_prediction/utils.py @@ -1,82 +1,3 @@ -import logging -import tempfile -from contextlib import contextmanager -from functools import reduce -from itertools import takewhile, starmap, islice, repeat -from operator import truth - -from image_prediction.config import CONFIG -from redai.utils import export -@contextmanager -def temporary_pdf_file(pdf: bytes): - with tempfile.NamedTemporaryFile() as f: - f.write(pdf) - yield f.name - -def make_logger_getter(): - - logger = logging.getLogger("imclf") - logger.propagate = False - - handler = logging.StreamHandler() - handler.setLevel(CONFIG.service.logging_level) - - log_format = "[%(levelname)s]: %(message)s" - formatter = logging.Formatter(log_format) - - handler.setFormatter(formatter) - logger.addHandler(handler) - - def get_logger(): - return logger - - return get_logger - - -get_logger = make_logger_getter() - - -def show_banner(): - banner = ''' - ..... . ... .. - .d88888Neu. 'L xH88"`~ .x8X x .d88" oec : - F""""*8888888F .. . : :8888 .f"8888Hf 5888R @88888 - * `"*88*" .888: x888 x888. :8888> X8L ^""` '888R 8"*88% - -.... ue=:. ~`8888~'888X`?888f` X8888 X888h 888R 8b. - :88N ` X888 888X '888> 88888 !88888. 888R u888888> - 9888L X888 888X '888> 88888 %88888 888R 8888R - uzu. `8888L X888 888X '888> 88888 '> `8888> 888R 8888P -,""888i ?8888 X888 888X '888> `8888L % ?888 ! 888R *888> -4 9888L %888> "*88%""*88" '888!` `8888 `-*"" / .888B . 4888 -' '8888 '88% `~ " `"` "888. :" ^*888% '888 - "*8Nu.z*" `""***~"` "% 88R - 88> - 48 - '8 - ''' - - logger = logging.getLogger(__name__) - logger.propagate = False - - handler = logging.StreamHandler() - handler.setLevel(logging.INFO) - - formatter = logging.Formatter("") - - handler.setFormatter(formatter) - logger.addHandler(handler) - - logger.info(banner) - - -@export -def chunk_iterable(iterable, chunk_size): - return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size))))) - - -def compose(func, *funcs): - funcs = [func, *funcs] - return lambda x: reduce(lambda acc, f: f(acc), funcs, x) diff --git a/image_prediction/utils/__init__.py b/image_prediction/utils/__init__.py new file mode 100644 index 0000000..d8ef2e6 --- /dev/null +++ b/image_prediction/utils/__init__.py @@ -0,0 +1,8 @@ +from itertools import takewhile, starmap, islice, repeat +from operator import truth + +from .logger import get_logger + + +def chunk_iterable(iterable, chunk_size): + return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size))))) diff --git a/image_prediction/utils/banner.py b/image_prediction/utils/banner.py new file mode 100644 index 0000000..6a17d93 --- /dev/null +++ b/image_prediction/utils/banner.py @@ -0,0 +1,21 @@ +import logging + +from image_prediction.locations import BANNER_FILE + + +def show_banner(): + with open(BANNER_FILE) as f: + banner = "\n" + "".join(f.readlines()) + "\n" + + logger = logging.getLogger(__name__) + logger.propagate = False + + handler = logging.StreamHandler() + handler.setLevel(logging.INFO) + + formatter = logging.Formatter("") + + handler.setFormatter(formatter) + logger.addHandler(handler) + + logger.info(banner) diff --git a/image_prediction/utils/logger.py b/image_prediction/utils/logger.py new file mode 100644 index 0000000..4f5186f --- /dev/null +++ b/image_prediction/utils/logger.py @@ -0,0 +1,26 @@ +import logging + +from image_prediction.config import CONFIG + + +def make_logger_getter(): + logger = logging.getLogger("imclf") + logger.propagate = False + + handler = logging.StreamHandler() + handler.setLevel(CONFIG.service.logging_level) + + log_format = "[%(levelname)s]: %(message)s" + formatter = logging.Formatter(log_format) + + handler.setFormatter(formatter) + logger.addHandler(handler) + + def get_logger(): + return logger + + return get_logger + + +get_logger = make_logger_getter() +1 \ No newline at end of file