refactoring
This commit is contained in:
parent
5caa9807e2
commit
268b83a1ff
@ -11,6 +11,8 @@ omit =
|
|||||||
*/env/*
|
*/env/*
|
||||||
*/build_venv/*
|
*/build_venv/*
|
||||||
*/build_env/*
|
*/build_env/*
|
||||||
|
*/utils/banner.py
|
||||||
|
*/utils/logger.py
|
||||||
source =
|
source =
|
||||||
image_prediction
|
image_prediction
|
||||||
src
|
src
|
||||||
@ -44,6 +46,8 @@ omit =
|
|||||||
*/env/*
|
*/env/*
|
||||||
*/build_venv/*
|
*/build_venv/*
|
||||||
*/build_env/*
|
*/build_env/*
|
||||||
|
*/utils/banner.py
|
||||||
|
*/utils/logger.py
|
||||||
|
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
|||||||
15
banner.txt
Normal file
15
banner.txt
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
..... . ... ..
|
||||||
|
.d88888Neu. 'L xH88"`~ .x8X x .d88" oec :
|
||||||
|
F""""*8888888F .. . : :8888 .f"8888Hf 5888R @88888
|
||||||
|
* `"*88*" .888: x888 x888. :8888> X8L ^""` '888R 8"*88%
|
||||||
|
-.... ue=:. ~`8888~'888X`?888f` X8888 X888h 888R 8b.
|
||||||
|
:88N ` X888 888X '888> 88888 !88888. 888R u888888>
|
||||||
|
9888L X888 888X '888> 88888 %88888 888R 8888R
|
||||||
|
uzu. `8888L X888 888X '888> 88888 '> `8888> 888R 8888P
|
||||||
|
,""888i ?8888 X888 888X '888> `8888L % ?888 ! 888R *888>
|
||||||
|
4 9888L %888> "*88%""*88" '888!` `8888 `-*"" / .888B . 4888
|
||||||
|
' '8888 '88% `~ " `"` "888. :" ^*888% '888
|
||||||
|
"*8Nu.z*" `""***~"` "% 88R
|
||||||
|
88>
|
||||||
|
48
|
||||||
|
'8
|
||||||
122
deprecated/predictor.py
Normal file
122
deprecated/predictor.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
from itertools import chain
|
||||||
|
from operator import itemgetter
|
||||||
|
from typing import List, Dict, Iterable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from image_prediction.config import CONFIG
|
||||||
|
from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
|
||||||
|
from image_prediction.utils import temporary_pdf_file, get_logger
|
||||||
|
from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
|
||||||
|
from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
|
||||||
|
from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
|
||||||
|
from incl.redai_image.redai.redai.utils.shared import chunk_iterable
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class Predictor:
|
||||||
|
"""`ModelHandle` wrapper. Forwards to wrapped service_estimator handle for prediction and produces structured output that is
|
||||||
|
interpretable independently of the wrapped service_estimator (e.g. with regard to a .classes_ attribute).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_handle: ModelHandle = None):
|
||||||
|
"""Initializes a ServiceEstimator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_handle: ModelHandle object to forward to for prediction. By default, a service_estimator handle is loaded from the
|
||||||
|
mlflow database via CONFIG.service.run_id.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if model_handle is None:
|
||||||
|
reader = MlflowModelReader(run_id=CONFIG.service.run_id, mlruns_dir=MLRUNS_DIR)
|
||||||
|
self.model_handle = reader.get_model_handle(BASE_WEIGHTS)
|
||||||
|
else:
|
||||||
|
self.model_handle = model_handle
|
||||||
|
|
||||||
|
self.classes = self.model_handle.model.classes_
|
||||||
|
self.classes_readable = np.array(self.model_handle.classes)
|
||||||
|
self.classes_readable_aligned = self.classes_readable[self.classes[list(range(len(self.classes)))]]
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f"Service estimator initialization failed: {e}")
|
||||||
|
|
||||||
|
def __make_predictions_human_readable(self, probs: np.ndarray) -> List[Dict[str, float]]:
|
||||||
|
"""Translates an n x m matrix of probabilities over classes into an n-element list of mappings from classes to
|
||||||
|
probabilities.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
probs: probability matrix (items x classes)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of mappings from classes to probabilities.
|
||||||
|
"""
|
||||||
|
classes = np.argmax(probs, axis=1)
|
||||||
|
classes = self.classes[classes]
|
||||||
|
classes_readable = [self.model_handle.classes[c] for c in classes]
|
||||||
|
return classes_readable
|
||||||
|
|
||||||
|
def predict(self, images: List, probabilities: bool = False, **kwargs):
|
||||||
|
"""Gathers predictions for list of images. Assigns each image a class and optionally a probability distribution
|
||||||
|
over all classes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (List[PIL.Image]) : Images to gather predictions for.
|
||||||
|
probabilities: Whether to return dictionaries of the following form instead of strings:
|
||||||
|
{
|
||||||
|
"class": predicted class,
|
||||||
|
"probabilities": {
|
||||||
|
"class 1" : class 1 probability,
|
||||||
|
"class 2" : class 2 probability,
|
||||||
|
...
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
By default the return value is a list of classes (meaningful class name strings). Alternatively a list of
|
||||||
|
dictionaries with an additional probability field for estimated class probabilities per image can be
|
||||||
|
returned.
|
||||||
|
"""
|
||||||
|
X = self.model_handle.prep_images(list(images))
|
||||||
|
|
||||||
|
probs_per_item = self.model_handle.model.predict_proba(X, **kwargs).astype(float)
|
||||||
|
classes = self.__make_predictions_human_readable(probs_per_item)
|
||||||
|
|
||||||
|
class2prob_per_item = [dict(zip(self.classes_readable_aligned, probs)) for probs in probs_per_item]
|
||||||
|
class2prob_per_item = [
|
||||||
|
dict(sorted(c2p.items(), key=itemgetter(1), reverse=True)) for c2p in class2prob_per_item
|
||||||
|
]
|
||||||
|
|
||||||
|
predictions = [{"class": c, "probabilities": c2p} for c, c2p in zip(classes, class2prob_per_item)]
|
||||||
|
|
||||||
|
return predictions if probabilities else classes
|
||||||
|
|
||||||
|
def predict_pdf(self, pdf, verbose=False):
|
||||||
|
with temporary_pdf_file(pdf) as pdf_path:
|
||||||
|
image_metadata_pairs = self.__extract_image_metadata_pairs(pdf_path, verbose=verbose)
|
||||||
|
return self.__predict_images(image_metadata_pairs)
|
||||||
|
|
||||||
|
def __predict_images(self, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
|
||||||
|
def process_chunk(chunk):
|
||||||
|
images, metadata = zip(*chunk)
|
||||||
|
predictions = self.predict(images, probabilities=True)
|
||||||
|
return predictions, metadata
|
||||||
|
|
||||||
|
def predict(image_metadata_pair_generator):
|
||||||
|
chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size)
|
||||||
|
return map(chain.from_iterable, zip(*map(process_chunk, chunks)))
|
||||||
|
|
||||||
|
try:
|
||||||
|
predictions, metadata = predict(image_metadata_pairs)
|
||||||
|
return predictions, metadata
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __extract_image_metadata_pairs(pdf_path: str, **kwargs):
|
||||||
|
def image_is_large_enough(metadata: dict):
|
||||||
|
x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata)
|
||||||
|
|
||||||
|
return abs(x1 - x2) > 2 and abs(y1 - y2) > 2
|
||||||
|
|
||||||
|
yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
|
||||||
49
deprecated/serve.py
Normal file
49
deprecated/serve.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from waitress import serve
|
||||||
|
|
||||||
|
from image_prediction.config import CONFIG
|
||||||
|
from image_prediction.flask import make_prediction_server
|
||||||
|
from image_prediction.predictor import Predictor
|
||||||
|
from image_prediction.response import build_response
|
||||||
|
from image_prediction.utils import get_logger, show_banner
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
def predict(pdf):
|
||||||
|
# Keras service_estimator.predict stalls when service_estimator was loaded in different process
|
||||||
|
# https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python
|
||||||
|
predictor = Predictor()
|
||||||
|
predictions, metadata = predictor.predict_pdf(pdf, verbose=CONFIG.service.progressbar)
|
||||||
|
response = build_response(predictions, metadata)
|
||||||
|
return response
|
||||||
|
|
||||||
|
logger.info("Predictor ready.")
|
||||||
|
|
||||||
|
prediction_server = make_prediction_server(predict)
|
||||||
|
|
||||||
|
run_prediction_server(prediction_server, mode=CONFIG.webserver.mode)
|
||||||
|
|
||||||
|
|
||||||
|
def run_prediction_server(app, mode="development"):
|
||||||
|
if mode == "development":
|
||||||
|
app.run(host=CONFIG.webserver.host, port=CONFIG.webserver.port, debug=True)
|
||||||
|
elif mode == "production":
|
||||||
|
serve(app, host=CONFIG.webserver.host, port=CONFIG.webserver.port)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging_level = CONFIG.service.logging_level
|
||||||
|
logging.basicConfig(level=logging_level)
|
||||||
|
logging.getLogger("flask").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("werkzeug").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("waitress").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("PIL").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("h5py").setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
show_banner()
|
||||||
|
|
||||||
|
main()
|
||||||
1
doc/tests.drawio
Normal file
1
doc/tests.drawio
Normal file
@ -0,0 +1 @@
|
|||||||
|
<mxfile host="app.diagrams.net" modified="2022-03-17T15:35:10.371Z" agent="5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/99.0.4844.51 Safari/537.36" etag="b-CbBXg6FXQ9T3Px-oLc" version="17.1.1" type="device"><diagram id="tS3WR_Pr6QhNVK3FqSUP" name="Page-1">1ZZRT6QwEMc/DY8mQHdRX93z9JLbmNzGmNxbQ0daLQzpDrL46a/IsCzinneJcd0XaP+dtsN/fkADscg3V06WeokKbBCHahOIb0Ecnydzf22FphPmyXknZM6oTooGYWWegcWQ1cooWI8CCdGSKcdiikUBKY006RzW47B7tONdS5nBRFil0k7VO6NId+rZPBz0azCZ7neOQh7JZR/MwlpLhfWOJC4DsXCI1LXyzQJs613vSzfv+57RbWIOCvqXCZqW9PBref27aZ7xsQ5vTn/cnvAqT9JW/MCwJuNzR8dZU9Nb4bAqFLSrhYG4qLUhWJUybUdrX3uvacqt70W+yeuCI9jsTTja2uDxAcyBXONDeILonWN04hn366EQUR+jd4qQsCa59tl26cEe32CH/sOt+TueoCONGRbS/kQs2YkHIGoYbFkRvuUTqAmFr1zyu2LlUvhLdjG/HtJlQO/VfOq6AyvJPI3z+HAL4wlwpbp/2V0qODxzUTJmLjo4c8nEkxaWFXcLLPzt4ithKI4BQzHBMOc/l8UvAeLrj9/hQTw9NhBnxwDibB+IB+ZvdvZ5/PnucAx6Gds5S4rLPw==</diagram></mxfile>
|
||||||
@ -1,11 +1,17 @@
|
|||||||
from os import path
|
"""Defines constant paths relative to the module root path."""
|
||||||
|
|
||||||
MODULE_DIR = path.dirname(path.abspath(__file__))
|
from pathlib import Path
|
||||||
PACKAGE_ROOT_DIR = path.dirname(MODULE_DIR)
|
|
||||||
|
|
||||||
CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml")
|
MODULE_DIR = Path(__file__).resolve().parents[0]
|
||||||
|
|
||||||
DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data")
|
PACKAGE_ROOT_DIR = MODULE_DIR.parents[0]
|
||||||
MLRUNS_DIR = path.join(DATA_DIR, "mlruns")
|
|
||||||
|
|
||||||
TEST_DATA_DIR = path.join(PACKAGE_ROOT_DIR, "test", "data")
|
CONFIG_FILE = PACKAGE_ROOT_DIR / "config.yaml"
|
||||||
|
|
||||||
|
BANNER_FILE = PACKAGE_ROOT_DIR / "banner.txt"
|
||||||
|
|
||||||
|
DATA_DIR = PACKAGE_ROOT_DIR / "data"
|
||||||
|
|
||||||
|
MLRUNS_DIR = str(DATA_DIR / "mlruns")
|
||||||
|
|
||||||
|
TEST_DATA_DIR = PACKAGE_ROOT_DIR / "test" / "data"
|
||||||
|
|||||||
@ -1,82 +1,3 @@
|
|||||||
import logging
|
|
||||||
import tempfile
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from functools import reduce
|
|
||||||
from itertools import takewhile, starmap, islice, repeat
|
|
||||||
from operator import truth
|
|
||||||
|
|
||||||
from image_prediction.config import CONFIG
|
|
||||||
from redai.utils import export
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def temporary_pdf_file(pdf: bytes):
|
|
||||||
with tempfile.NamedTemporaryFile() as f:
|
|
||||||
f.write(pdf)
|
|
||||||
yield f.name
|
|
||||||
|
|
||||||
|
|
||||||
def make_logger_getter():
|
|
||||||
|
|
||||||
logger = logging.getLogger("imclf")
|
|
||||||
logger.propagate = False
|
|
||||||
|
|
||||||
handler = logging.StreamHandler()
|
|
||||||
handler.setLevel(CONFIG.service.logging_level)
|
|
||||||
|
|
||||||
log_format = "[%(levelname)s]: %(message)s"
|
|
||||||
formatter = logging.Formatter(log_format)
|
|
||||||
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
logger.addHandler(handler)
|
|
||||||
|
|
||||||
def get_logger():
|
|
||||||
return logger
|
|
||||||
|
|
||||||
return get_logger
|
|
||||||
|
|
||||||
|
|
||||||
get_logger = make_logger_getter()
|
|
||||||
|
|
||||||
|
|
||||||
def show_banner():
|
|
||||||
banner = '''
|
|
||||||
..... . ... ..
|
|
||||||
.d88888Neu. 'L xH88"`~ .x8X x .d88" oec :
|
|
||||||
F""""*8888888F .. . : :8888 .f"8888Hf 5888R @88888
|
|
||||||
* `"*88*" .888: x888 x888. :8888> X8L ^""` '888R 8"*88%
|
|
||||||
-.... ue=:. ~`8888~'888X`?888f` X8888 X888h 888R 8b.
|
|
||||||
:88N ` X888 888X '888> 88888 !88888. 888R u888888>
|
|
||||||
9888L X888 888X '888> 88888 %88888 888R 8888R
|
|
||||||
uzu. `8888L X888 888X '888> 88888 '> `8888> 888R 8888P
|
|
||||||
,""888i ?8888 X888 888X '888> `8888L % ?888 ! 888R *888>
|
|
||||||
4 9888L %888> "*88%""*88" '888!` `8888 `-*"" / .888B . 4888
|
|
||||||
' '8888 '88% `~ " `"` "888. :" ^*888% '888
|
|
||||||
"*8Nu.z*" `""***~"` "% 88R
|
|
||||||
88>
|
|
||||||
48
|
|
||||||
'8
|
|
||||||
'''
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logger.propagate = False
|
|
||||||
|
|
||||||
handler = logging.StreamHandler()
|
|
||||||
handler.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
formatter = logging.Formatter("")
|
|
||||||
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
logger.addHandler(handler)
|
|
||||||
|
|
||||||
logger.info(banner)
|
|
||||||
|
|
||||||
|
|
||||||
@export
|
|
||||||
def chunk_iterable(iterable, chunk_size):
|
|
||||||
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
|
|
||||||
|
|
||||||
|
|
||||||
def compose(func, *funcs):
|
|
||||||
funcs = [func, *funcs]
|
|
||||||
return lambda x: reduce(lambda acc, f: f(acc), funcs, x)
|
|
||||||
|
|||||||
8
image_prediction/utils/__init__.py
Normal file
8
image_prediction/utils/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from itertools import takewhile, starmap, islice, repeat
|
||||||
|
from operator import truth
|
||||||
|
|
||||||
|
from .logger import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_iterable(iterable, chunk_size):
|
||||||
|
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
|
||||||
21
image_prediction/utils/banner.py
Normal file
21
image_prediction/utils/banner.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from image_prediction.locations import BANNER_FILE
|
||||||
|
|
||||||
|
|
||||||
|
def show_banner():
|
||||||
|
with open(BANNER_FILE) as f:
|
||||||
|
banner = "\n" + "".join(f.readlines()) + "\n"
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
handler.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
formatter = logging.Formatter("")
|
||||||
|
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
logger.info(banner)
|
||||||
26
image_prediction/utils/logger.py
Normal file
26
image_prediction/utils/logger.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from image_prediction.config import CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
def make_logger_getter():
|
||||||
|
logger = logging.getLogger("imclf")
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
handler.setLevel(CONFIG.service.logging_level)
|
||||||
|
|
||||||
|
log_format = "[%(levelname)s]: %(message)s"
|
||||||
|
formatter = logging.Formatter(log_format)
|
||||||
|
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
def get_logger():
|
||||||
|
return logger
|
||||||
|
|
||||||
|
return get_logger
|
||||||
|
|
||||||
|
|
||||||
|
get_logger = make_logger_getter()
|
||||||
|
1
|
||||||
Loading…
x
Reference in New Issue
Block a user