Compare commits
227 Commits
master
...
fuzzy_stit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
03e7b00cfd | ||
|
|
7aee00cb49 | ||
|
|
2cc52c4630 | ||
|
|
daa1da3a50 | ||
|
|
6a7debde14 | ||
|
|
b4f279c549 | ||
|
|
f5881f2229 | ||
|
|
62bfedfea8 | ||
|
|
1d88876ab1 | ||
|
|
bbafad5561 | ||
|
|
f17a232009 | ||
|
|
88a46ae7cd | ||
|
|
e82a81f5c8 | ||
|
|
35c5b15e32 | ||
|
|
698e647c6f | ||
|
|
d8f86d14a5 | ||
|
|
bb7c1be630 | ||
|
|
79cd31850d | ||
|
|
3d335783dc | ||
|
|
bb79f9dd55 | ||
|
|
585cdf5c70 | ||
|
|
04cf0245ed | ||
|
|
3530ef72c5 | ||
|
|
d80af336eb | ||
|
|
bcf6dc5c47 | ||
|
|
f4c0547405 | ||
|
|
1bea5fb9a8 | ||
|
|
57440f5106 | ||
|
|
710783a2f8 | ||
|
|
887b8339a2 | ||
|
|
43cb0fffed | ||
|
|
6e7645e319 | ||
|
|
3b18fc6158 | ||
|
|
1b10445f91 | ||
|
|
5967149c49 | ||
|
|
303970db51 | ||
|
|
51793d19e9 | ||
|
|
e276a5ec27 | ||
|
|
7e6fe7cf11 | ||
|
|
bb5db1b4ef | ||
|
|
8ac9fcb19f | ||
|
|
160973e2be | ||
|
|
803cc57155 | ||
|
|
50b4d239cb | ||
|
|
9bb07f95fb | ||
|
|
29028cc1a5 | ||
|
|
2fcb0bd149 | ||
|
|
3e882dc247 | ||
|
|
2b1e7cbb08 | ||
|
|
5e8b55ef10 | ||
|
|
3266e0af58 | ||
|
|
7e2696d5c5 | ||
|
|
302613bf2b | ||
|
|
66fd103d1b | ||
|
|
6e5d6912ed | ||
|
|
b1efb5ed09 | ||
|
|
ef70e11352 | ||
|
|
315679468b | ||
|
|
64e3350dee | ||
|
|
6a7e0e1000 | ||
|
|
11fc63035d | ||
|
|
4bc295b212 | ||
|
|
4c46be4abc | ||
|
|
37ee086b5d | ||
|
|
1fd30e68b6 | ||
|
|
2c908162f1 | ||
|
|
4756b8c9bd | ||
|
|
e0885c545a | ||
|
|
fdb7ebe618 | ||
|
|
ce69f7d160 | ||
|
|
8f61c4cba2 | ||
|
|
f3e2b2335f | ||
|
|
9cda65ad41 | ||
|
|
692e72b3b2 | ||
|
|
38869d52c6 | ||
|
|
e01b5c9acd | ||
|
|
6a6fc19958 | ||
|
|
1b1f1aafef | ||
|
|
caef37376b | ||
|
|
16aa951c96 | ||
|
|
89afb8f920 | ||
|
|
1ffc9dcc68 | ||
|
|
0976971117 | ||
|
|
b4b0058475 | ||
|
|
2ee36dcb54 | ||
|
|
ab382646b7 | ||
|
|
8c916a79c3 | ||
|
|
3ff6dac2e0 | ||
|
|
d134884553 | ||
|
|
2d0545c928 | ||
|
|
65a4a8e34e | ||
|
|
39c111fd42 | ||
|
|
0376223c9d | ||
|
|
bf85ef357c | ||
|
|
f6a7a14a20 | ||
|
|
41f783dc5d | ||
|
|
32397256c8 | ||
|
|
f44e6f4fd7 | ||
|
|
3d2c97bc10 | ||
|
|
9663cec12d | ||
|
|
c1c3f541d4 | ||
|
|
4d86e78307 | ||
|
|
1cf6ab256c | ||
|
|
a89e374c67 | ||
|
|
0861e22542 | ||
|
|
7827869af4 | ||
|
|
613bba8cfc | ||
|
|
5c23898280 | ||
|
|
e8d0299e46 | ||
|
|
cb00aed62c | ||
|
|
1501653673 | ||
|
|
b4b929b65f | ||
|
|
3d1c251e10 | ||
|
|
c80549d5d3 | ||
|
|
070749880e | ||
|
|
94783c54f2 | ||
|
|
2b48c6108b | ||
|
|
da9b3d0cb9 | ||
|
|
c372529ee5 | ||
|
|
1a1ece1f95 | ||
|
|
426061e5ea | ||
|
|
7c2cf44ad0 | ||
|
|
c125e1ff6c | ||
|
|
dd007891c7 | ||
|
|
d3257fdeda | ||
|
|
1581880ec6 | ||
|
|
268b83a1ff | ||
|
|
5caa9807e2 | ||
|
|
82added50a | ||
|
|
b6ccfbcf8f | ||
|
|
e17912caa9 | ||
|
|
3eaf9dc0e1 | ||
|
|
0cefef4e15 | ||
|
|
4f94cbd68d | ||
|
|
2517b45d44 | ||
|
|
2a62ad7aba | ||
|
|
20c980dbe6 | ||
|
|
726298b155 | ||
|
|
479afbcd34 | ||
|
|
4ab9f0d89b | ||
|
|
d4604a2cb5 | ||
|
|
4ebb36247e | ||
|
|
7ec7390e90 | ||
|
|
dc1cdde458 | ||
|
|
0921ef9a4f | ||
|
|
91dd467142 | ||
|
|
b3e1604ecc | ||
|
|
20718996bd | ||
|
|
cc8d87338c | ||
|
|
258c1ab02d | ||
|
|
ce3d33955e | ||
|
|
a95cc4e06b | ||
|
|
6d1ace473b | ||
|
|
0a22a35912 | ||
|
|
a5d3232dd0 | ||
|
|
49f9847d9a | ||
|
|
1c6f5749dd | ||
|
|
8bccec277f | ||
|
|
7f37f841dd | ||
|
|
8c7e3e29f5 | ||
|
|
99d8e921db | ||
|
|
6835394d30 | ||
|
|
ad6bb80900 | ||
|
|
95209a5c9d | ||
|
|
45a07c620a | ||
|
|
81ab9a5f53 | ||
|
|
8b15ac6df4 | ||
|
|
e9489287bd | ||
|
|
15c0b73034 | ||
|
|
7a64af156b | ||
|
|
60617fd622 | ||
|
|
ade318c7b7 | ||
|
|
3339ed2eab | ||
|
|
7340fb6dda | ||
|
|
358d7ecd91 | ||
|
|
d33a882d65 | ||
|
|
06adedac57 | ||
|
|
edbc5c3f84 | ||
|
|
f60bafd007 | ||
|
|
a1c7dd4a8d | ||
|
|
6b58756103 | ||
|
|
3b4c2a40b2 | ||
|
|
c06905625d | ||
|
|
d44622dddc | ||
|
|
3c6dfed508 | ||
|
|
f18e183ab0 | ||
|
|
86f2abc553 | ||
|
|
f0a8f2224c | ||
|
|
9bf1dcbe1d | ||
|
|
9ce7b6e6da | ||
|
|
e818b05472 | ||
|
|
b818ee4724 | ||
|
|
9461be29d5 | ||
|
|
2631eb5c0f | ||
|
|
643ab99bd3 | ||
|
|
e0ab365bb9 | ||
|
|
48737d9439 | ||
|
|
a5147c9a58 | ||
|
|
4c939464b0 | ||
|
|
334dc79f7e | ||
|
|
9d58ae714f | ||
|
|
0f811bdc56 | ||
|
|
d11333981f | ||
|
|
4fcd1e79d3 | ||
|
|
5c5d132d7f | ||
|
|
0f9510906d | ||
|
|
6343229c1e | ||
|
|
7d21b0a585 | ||
|
|
364111db89 | ||
|
|
ea298dacfa | ||
|
|
373c619b0c | ||
|
|
8aa0717007 | ||
|
|
a3215e0bc3 | ||
|
|
c64bff0843 | ||
|
|
dd18087261 | ||
|
|
d97b477208 | ||
|
|
981d7816a0 | ||
|
|
2e36a9d46d | ||
|
|
03f269c2d7 | ||
|
|
6853d862ed | ||
|
|
31591bef0f | ||
|
|
7834a65ff5 | ||
|
|
8b7293be09 | ||
|
|
9c9070e8bf | ||
|
|
e8fb01b4b7 | ||
|
|
41f0cc8a41 | ||
|
|
ee959346b7 |
11
.coveragerc
11
.coveragerc
@ -1,6 +1,9 @@
|
|||||||
# .coveragerc to control coverage.py
|
# .coveragerc to control coverage.py
|
||||||
[run]
|
[run]
|
||||||
branch = True
|
branch = True
|
||||||
|
parallel = True
|
||||||
|
command_line = -m pytest
|
||||||
|
concurrency = multiprocessing
|
||||||
omit =
|
omit =
|
||||||
*/site-packages/*
|
*/site-packages/*
|
||||||
*/distutils/*
|
*/distutils/*
|
||||||
@ -11,9 +14,11 @@ omit =
|
|||||||
*/env/*
|
*/env/*
|
||||||
*/build_venv/*
|
*/build_venv/*
|
||||||
*/build_env/*
|
*/build_env/*
|
||||||
|
*/utils/banner.py
|
||||||
|
*/utils/logger.py
|
||||||
|
*/src/*
|
||||||
source =
|
source =
|
||||||
image_prediction
|
image_prediction
|
||||||
src
|
|
||||||
relative_files = True
|
relative_files = True
|
||||||
data_file = .coverage
|
data_file = .coverage
|
||||||
|
|
||||||
@ -44,6 +49,10 @@ omit =
|
|||||||
*/env/*
|
*/env/*
|
||||||
*/build_venv/*
|
*/build_venv/*
|
||||||
*/build_env/*
|
*/build_env/*
|
||||||
|
*/utils/banner.py
|
||||||
|
*/utils/logger.py
|
||||||
|
*/src/*
|
||||||
|
*/pdf_annotation.py
|
||||||
|
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
[core]
|
[core]
|
||||||
remote = vector
|
remote = vector
|
||||||
|
autostage = true
|
||||||
['remote "vector"']
|
['remote "vector"']
|
||||||
url = ssh://vector.iqser.com/research/image_service/
|
url = ssh://vector.iqser.com/research/image-prediction/
|
||||||
port = 22
|
port = 22
|
||||||
|
|||||||
6
.gitignore
vendored
6
.gitignore
vendored
@ -32,6 +32,8 @@
|
|||||||
**/classpath-data.json
|
**/classpath-data.json
|
||||||
**/dependencies-and-licenses-overview.txt
|
**/dependencies-and-licenses-overview.txt
|
||||||
|
|
||||||
|
.coverage
|
||||||
|
|
||||||
|
|
||||||
*__pycache__
|
*__pycache__
|
||||||
*.egg-info*
|
*.egg-info*
|
||||||
@ -44,7 +46,7 @@
|
|||||||
*misc
|
*misc
|
||||||
|
|
||||||
/coverage_html_report/
|
/coverage_html_report/
|
||||||
.coverage
|
.coverage\.*
|
||||||
|
|
||||||
# Created by https://www.toptal.com/developers/gitignore/api/linux,pycharm
|
# Created by https://www.toptal.com/developers/gitignore/api/linux,pycharm
|
||||||
# Edit at https://www.toptal.com/developers/gitignore?templates=linux,pycharm
|
# Edit at https://www.toptal.com/developers/gitignore?templates=linux,pycharm
|
||||||
@ -172,4 +174,4 @@ fabric.properties
|
|||||||
|
|
||||||
# End of https://www.toptal.com/developers/gitignore/api/linux,pycharm
|
# End of https://www.toptal.com/developers/gitignore/api/linux,pycharm
|
||||||
/image_prediction/data/mlruns/
|
/image_prediction/data/mlruns/
|
||||||
/data/mlruns/
|
#/data/mlruns/
|
||||||
|
|||||||
11
banner.txt
Normal file
11
banner.txt
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
+----------------------------------------------------+
|
||||||
|
| ___ |
|
||||||
|
| __/_ `. .-"""-. |
|
||||||
|
|_._ _,-'""`-._ \_,` | \-' / )`-')|
|
||||||
|
|(,-.`._,'( |\`-/| "") `"` \ ((`"` |
|
||||||
|
| `-.-' \ )-`( , o o) ___Y , .'7 /| |
|
||||||
|
| `- \`_`"'- (_,___/...-` (_/_/ |
|
||||||
|
| |
|
||||||
|
+----------------------------------------------------+
|
||||||
|
| Image Classification Service |
|
||||||
|
+----------------------------------------------------+
|
||||||
@ -4,14 +4,14 @@ webserver:
|
|||||||
mode: $SERVER_MODE|production # webserver mode: {development, production}
|
mode: $SERVER_MODE|production # webserver mode: {development, production}
|
||||||
|
|
||||||
service:
|
service:
|
||||||
logging_level: $LOGGING_LEVEL_ROOT|DEBUG # Logging level for service logger
|
logging_level: INFO # Logging level for service logger
|
||||||
progressbar: True # Whether a progress bar over the pages of a document is displayed while processing
|
progressbar: True # Whether a progress bar over the pages of a document is displayed while processing
|
||||||
batch_size: $BATCH_SIZE|32 # Number of images in memory simultaneously
|
batch_size: $BATCH_SIZE|32 # Number of images in memory simultaneously
|
||||||
verbose: $VERBOSE|True # Service prints document processing progress to stdout
|
verbose: $VERBOSE|True # Service prints document processing progress to stdout
|
||||||
run_id: $RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the model from
|
run_id: $RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the service_estimator from
|
||||||
|
|
||||||
|
|
||||||
# These variables control filters that are applied to either images, image metadata or model predictions. The filter
|
# These variables control filters that are applied to either images, image metadata or service_estimator predictions. The filter
|
||||||
# result values are reported in the service responses. For convenience the response to a request contains a
|
# result values are reported in the service responses. For convenience the response to a request contains a
|
||||||
# "filters.allPassed" field, which is set to false if any of the filters returned values did not meet its specified
|
# "filters.allPassed" field, which is set to false if any of the filters returned values did not meet its specified
|
||||||
# required value.
|
# required value.
|
||||||
|
|||||||
1
data/.gitignore
vendored
Normal file
1
data/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
/mlruns
|
||||||
@ -1,4 +0,0 @@
|
|||||||
outs:
|
|
||||||
- md5: 6d0186c1f25e889d531788f168fa6cf0
|
|
||||||
size: 16727296
|
|
||||||
path: base_weights.h5
|
|
||||||
@ -1,5 +1,5 @@
|
|||||||
outs:
|
outs:
|
||||||
- md5: d1c708270bab6fcd344d4a8b05d1103d.dir
|
- md5: 4219c52caf5f87f5a94f1ae00c60fb91.dir
|
||||||
size: 150225383
|
size: 166952679
|
||||||
nfiles: 178
|
nfiles: 179
|
||||||
path: mlruns
|
path: mlruns
|
||||||
|
|||||||
1
doc/tests.drawio
Normal file
1
doc/tests.drawio
Normal file
@ -0,0 +1 @@
|
|||||||
|
<mxfile host="app.diagrams.net" modified="2022-03-17T15:35:10.371Z" agent="5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/99.0.4844.51 Safari/537.36" etag="b-CbBXg6FXQ9T3Px-oLc" version="17.1.1" type="device"><diagram id="tS3WR_Pr6QhNVK3FqSUP" name="Page-1">1ZZRT6QwEMc/DY8mQHdRX93z9JLbmNzGmNxbQ0daLQzpDrL46a/IsCzinneJcd0XaP+dtsN/fkADscg3V06WeokKbBCHahOIb0Ecnydzf22FphPmyXknZM6oTooGYWWegcWQ1cooWI8CCdGSKcdiikUBKY006RzW47B7tONdS5nBRFil0k7VO6NId+rZPBz0azCZ7neOQh7JZR/MwlpLhfWOJC4DsXCI1LXyzQJs613vSzfv+57RbWIOCvqXCZqW9PBref27aZ7xsQ5vTn/cnvAqT9JW/MCwJuNzR8dZU9Nb4bAqFLSrhYG4qLUhWJUybUdrX3uvacqt70W+yeuCI9jsTTja2uDxAcyBXONDeILonWN04hn366EQUR+jd4qQsCa59tl26cEe32CH/sOt+TueoCONGRbS/kQs2YkHIGoYbFkRvuUTqAmFr1zyu2LlUvhLdjG/HtJlQO/VfOq6AyvJPI3z+HAL4wlwpbp/2V0qODxzUTJmLjo4c8nEkxaWFXcLLPzt4ithKI4BQzHBMOc/l8UvAeLrj9/hQTw9NhBnxwDibB+IB+ZvdvZ5/PnucAx6Gds5S4rLPw==</diagram></mxfile>
|
||||||
34
image_prediction/classifier/classifier.py
Normal file
34
image_prediction/classifier/classifier.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from typing import List, Union, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL.Image import Image
|
||||||
|
from funcy import rcompose
|
||||||
|
|
||||||
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||||
|
from image_prediction.label_mapper.mapper import LabelMapper
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class Classifier:
    """Turns batches into labels by chaining an estimator adapter with a label mapper."""

    def __init__(self, estimator_adapter: EstimatorAdapter, label_mapper: LabelMapper):
        """Abstraction layer over different estimator backends (e.g. keras or scikit-learn). For each backend to be used
        an EstimatorAdapter must be implemented.

        Args:
            estimator_adapter: adapter for a given estimator backend
            label_mapper: callable applied to the adapter's output to produce the returned labels
        """
        self.__estimator_adapter = estimator_adapter
        self.__label_mapper = label_mapper
        # rcompose applies left to right: adapter first, then the label mapper.
        self.__pipe = rcompose(self.__estimator_adapter, self.__label_mapper)

    def predict(self, batch: Union[np.ndarray, Tuple[Image]]) -> List[str]:
        """Run the batch through adapter and label mapper and return the results as a list.

        Args:
            batch: an array (first axis indexes the items) or a tuple of images

        Returns:
            The mapped labels; an empty list for an empty batch, without calling the estimator.
        """
        # len() is the first-axis size for arrays and the item count for tuples, so a single
        # check covers both input kinds (an empty tuple used to slip past the shape check).
        if len(batch) == 0:
            return []

        return list(self.__pipe(batch))

    def __call__(self, batch: Union[np.ndarray, Tuple[Image]]) -> List[str]:
        """Call-syntax alias for predict() that also emits a debug log line."""
        logger.debug("Classifier.predict")
        return self.predict(batch)
|
||||||
32
image_prediction/classifier/image_classifier.py
Normal file
32
image_prediction/classifier/image_classifier.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from itertools import chain
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from PIL.Image import Image
|
||||||
|
from funcy import rcompose, chunks
|
||||||
|
|
||||||
|
from image_prediction.classifier.classifier import Classifier
|
||||||
|
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
||||||
|
from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class ImageClassifier:
    """Pairs a classifier with a preprocessing step and feeds it images batch by batch.

    Images are split into chunks of ``batch_size``; every chunk goes through the preprocessor and then the
    classifier, and the per-chunk predictions are flattened into a single stream.
    """

    def __init__(self, classifier: Classifier, preprocessor: Preprocessor = None):
        # Without a (truthy) preprocessor the images are handed to the classifier unchanged.
        if not preprocessor:
            preprocessor = IdentityPreprocessor()

        self.estimator = classifier
        self.preprocessor = preprocessor
        self.pipe = rcompose(self.preprocessor, self.estimator)

    def predict(self, images: Iterable[Image], batch_size=16):
        """Return an iterator over the predictions for all images, computed chunk by chunk."""
        per_batch_predictions = map(self.pipe, chunks(batch_size, images))
        return chain.from_iterable(per_batch_predictions)

    def __call__(self, images: Iterable[Image], batch_size=16):
        """Generator form of predict(); the debug line is logged once iteration starts."""
        logger.debug("ImageClassifier.predict")
        yield from self.predict(images, batch_size=batch_size)
|
||||||
0
image_prediction/compositor/__init__.py
Normal file
0
image_prediction/compositor/__init__.py
Normal file
16
image_prediction/compositor/compositor.py
Normal file
16
image_prediction/compositor/compositor.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from funcy import rcompose
|
||||||
|
|
||||||
|
from image_prediction.transformer.transformer import Transformer
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerCompositor(Transformer):
    """Transformer that chains several transformers into one pipeline."""

    def __init__(self, formatter: Transformer, *formatters: Transformer):
        # The first stage is a required positional argument, so the pipeline is never empty.
        stages = (formatter,) + formatters
        self.pipe = rcompose(*stages)

    def transform(self, obj):
        """Pass obj through the composed pipeline and return the final result."""
        logger.debug("TransformerCompositor.transform")
        transformed = self.pipe(obj)
        return transformed
|
||||||
@ -18,12 +18,12 @@ class DotIndexable:
|
|||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
return _get_item_and_maybe_make_dotindexable(self.x, item)
|
return _get_item_and_maybe_make_dotindexable(self.x, item)
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
|
||||||
self.x[key] = value
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return self.x.__repr__()
|
return self.x.__repr__()
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.__getattr__(item)
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
def __init__(self, config_path):
|
def __init__(self, config_path):
|
||||||
|
|||||||
47
image_prediction/default_objects.py
Normal file
47
image_prediction/default_objects.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
from funcy import juxt
|
||||||
|
|
||||||
|
from image_prediction.classifier.classifier import Classifier
|
||||||
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||||
|
from image_prediction.compositor.compositor import TransformerCompositor
|
||||||
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||||
|
from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier
|
||||||
|
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
|
||||||
|
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||||
|
from image_prediction.transformer.transformers.response import ResponseTransformer
|
||||||
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||||
|
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper
|
||||||
|
from image_prediction.model_loader.loader import ModelLoader
|
||||||
|
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||||
|
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
||||||
|
from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer
|
||||||
|
|
||||||
|
|
||||||
|
def get_mlflow_model_loader(mlruns_dir):
    """Build a ModelLoader on top of an MlflowConnector whose reader is created from mlruns_dir."""
    reader = MlflowModelReader(mlruns_dir)
    connector = MlflowConnector(reader)
    return ModelLoader(connector)
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_classifier(model_loader, model_identifier):
    """Load model and classes for model_identifier and wire them into an ImageClassifier."""
    # juxt calls both loader methods with the same identifier and yields their results in order.
    load_model_and_classes = juxt(model_loader.load_model, model_loader.load_classes)
    model, classes = load_model_and_classes(model_identifier)

    estimator_adapter = EstimatorAdapter(model)
    label_mapper = ProbabilityMapper(classes)
    return ImageClassifier(Classifier(estimator_adapter, label_mapper))
|
||||||
|
|
||||||
|
|
||||||
|
def get_extractor(**kwargs):
    """Create a ParsablePDFImageExtractor, forwarding all keyword arguments to its constructor."""
    return ParsablePDFImageExtractor(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_extractor_classifier(model_loader, model_identifier, **kwargs):
    """Assemble an ExtractorClassifier from a default extractor and a default image classifier.

    Keyword arguments go to get_extractor(); the loader and identifier go to get_image_classifier().
    """
    # Extractor first, classifier second -- same construction order as the inline call it replaces.
    image_extractor = get_extractor(**kwargs)
    image_classifier = get_image_classifier(model_loader, model_identifier)
    return ExtractorClassifier(image_extractor, image_classifier)
|
||||||
|
|
||||||
|
|
||||||
|
def get_formatter():
    """Return the default response pipeline as a single TransformerCompositor."""
    stages = (
        PDFNetCoordinateTransformer(),
        EnumFormatter(),
        ResponseTransformer(),
        Snake2CamelCaseKeyFormatter(),
    )
    return TransformerCompositor(*stages)
|
||||||
0
image_prediction/estimator/__init__.py
Normal file
0
image_prediction/estimator/__init__.py
Normal file
0
image_prediction/estimator/adapter/__init__.py
Normal file
0
image_prediction/estimator/adapter/__init__.py
Normal file
15
image_prediction/estimator/adapter/adapter.py
Normal file
15
image_prediction/estimator/adapter/adapter.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class EstimatorAdapter:
    """Thin wrapper exposing a callable estimator through predict() and call syntax."""

    def __init__(self, estimator):
        self.estimator = estimator

    def predict(self, batch):
        """Call the wrapped estimator on batch and return its output unchanged."""
        estimate = self.estimator
        return estimate(batch)

    def __call__(self, batch):
        """Call-syntax alias for predict() that also emits a debug log line."""
        logger.debug("EstimatorAdapter.predict")
        output = self.predict(batch)
        return output
|
||||||
0
image_prediction/estimator/preprocessor/__init__.py
Normal file
0
image_prediction/estimator/preprocessor/__init__.py
Normal file
10
image_prediction/estimator/preprocessor/preprocessor.py
Normal file
10
image_prediction/estimator/preprocessor/preprocessor.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
import abc
|
||||||
|
|
||||||
|
|
||||||
|
class Preprocessor(abc.ABC):
    """Abstract base for callables that prepare a batch before it is estimated."""

    @abc.abstractmethod
    def preprocess(self, batch):
        """Return the preprocessed batch; concrete subclasses must override this."""
        raise NotImplementedError

    def __call__(self, batch):
        """Make instances usable as plain functions by delegating to preprocess()."""
        preprocessed = self.preprocess(batch)
        return preprocessed
|
||||||
@ -0,0 +1,10 @@
|
|||||||
|
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
||||||
|
from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor
|
||||||
|
|
||||||
|
|
||||||
|
class BasicPreprocessor(Preprocessor):
    """Preprocessor whose single step is converting images into a batch tensor."""

    @staticmethod
    def preprocess(images):
        batch_tensor = images_to_batch_tensor(images)
        return batch_tensor
|
||||||
@ -0,0 +1,10 @@
|
|||||||
|
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityPreprocessor(Preprocessor):
    """No-op preprocessor: hands the images back untouched."""

    @staticmethod
    def preprocess(images):
        return images

    def __call__(self, images):
        # Explicit override kept; it simply delegates to preprocess().
        untouched = self.preprocess(images)
        return untouched
|
||||||
10
image_prediction/estimator/preprocessor/utils.py
Normal file
10
image_prediction/estimator/preprocessor/utils.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
import numpy as np
|
||||||
|
from PIL.Image import Image
|
||||||
|
|
||||||
|
|
||||||
|
def image_to_normalized_tensor(image: Image) -> np.ndarray:
    """Convert image to a numpy array and divide every value by 255."""
    # NOTE(review): dividing by 255 presumably assumes 8-bit channels -- confirm against callers.
    pixels = np.array(image)
    return pixels / 255
|
||||||
|
|
||||||
|
|
||||||
|
def images_to_batch_tensor(images) -> np.ndarray:
    """Normalize each image and stack the results into one numpy array."""
    tensors = [image_to_normalized_tensor(image) for image in images]
    return np.array(tensors)
|
||||||
34
image_prediction/exceptions.py
Normal file
34
image_prediction/exceptions.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
class UnknownEstimatorAdapter(ValueError):
    """ValueError subclass for an unrecognised estimator adapter (raise sites are not visible here)."""
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownImageExtractor(ValueError):
    """ValueError subclass for an unrecognised image extractor (raise sites are not visible here)."""
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownModelLoader(ValueError):
    """ValueError subclass for an unrecognised model loader (raise sites are not visible here)."""
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownDatabaseType(ValueError):
    """ValueError subclass for an unrecognised database type (raise sites are not visible here)."""
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownLabelFormat(ValueError):
    """ValueError subclass for an unrecognised label format (raise sites are not visible here)."""
|
||||||
|
|
||||||
|
|
||||||
|
class UnexpectedLabelFormat(ValueError):
    """ValueError subclass for a label format that was not expected (raise sites are not visible here)."""
|
||||||
|
|
||||||
|
|
||||||
|
class IncorrectInstantiation(RuntimeError):
    """RuntimeError subclass for an object that was instantiated incorrectly (raise sites are not visible here)."""
|
||||||
|
|
||||||
|
|
||||||
|
class IntentionalTestException(RuntimeError):
    """RuntimeError subclass; judging by its name it is raised on purpose in tests (not visible here)."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidBox(Exception):
    """Exception subclass for an invalid box (raise sites are not visible here)."""
|
||||||
13
image_prediction/extraction.py
Normal file
13
image_prediction/extraction.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||||
|
|
||||||
|
|
||||||
|
def extract_images_from_pdf(pdf, extractor=None):
    """Extract images and their metadata from a PDF.

    Args:
        pdf: the document; passed unchanged to the extractor
        extractor: callable returning an iterable of (image, metadata) pairs; a
            ParsablePDFImageExtractor is created when none is given

    Returns:
        A pair (images, metadata): two equally long tuples, or two empty lists when the
        extractor produced nothing.

    Raises:
        Whatever the extractor raises. Previously every ValueError, including genuine
        extraction errors, was silently reported as "no images".
    """
    if not extractor:
        extractor = ParsablePDFImageExtractor()

    # Materialise first so the empty case is detected explicitly, instead of relying on the
    # ValueError that unpacking an empty zip() would raise.
    pairs = list(extractor(pdf))
    if not pairs:
        return [], []

    images_extracted, metadata_extracted = zip(*pairs)
    return images_extracted, metadata_extracted
|
||||||
0
image_prediction/extractor_classifier/__init__.py
Normal file
0
image_prediction/extractor_classifier/__init__.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from itertools import chain
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from funcy import chunks
|
||||||
|
|
||||||
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||||
|
from image_prediction.image_extractor.extractor import ImageExtractor
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractorClassifier:
    """This class is responsible for orchestrating the pairing of classifications and image metadata. It extracts images
    from an object and classifies them. Then it ties the classification together with the metadata. It returns an
    iterable of dictionaries, where each dictionary has a field 'classification' for the classification and possibly
    additional fields for metadata -- metadata could be void.
    """

    def __init__(self, image_extractor: ImageExtractor, image_classifier: ImageClassifier, batch_size: int = 16):
        """
        Args:
            image_extractor: callable producing (image, metadata) pairs for an object
            image_classifier: callable producing one prediction per image
            batch_size: number of (image, metadata) pairs processed per chunk (default 16)
        """
        self.classifier = image_classifier
        self.extractor = image_extractor
        self.batch_size = batch_size

    def __process_batch(self, batch):
        """Classify one chunk of (image, metadata) pairs and merge each prediction with its metadata."""
        images, metadata = zip(*batch)

        predictions = self.classifier(images)
        # Metadata is unpacked last, so a metadata key named "classification" would win.
        responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata))
        return responses

    def __call__(self, obj, **kwargs) -> Iterable[dict]:
        """Extract images from obj (kwargs go to the extractor) and lazily yield one dict per image."""
        image_metadata_pairs = self.extractor(obj, **kwargs)
        batches = chunks(self.batch_size, image_metadata_pairs)
        predictions = chain.from_iterable(map(self.__process_batch, batches))
        return predictions
|
||||||
@ -1,4 +1,5 @@
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import traceback
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from flask import Flask, request, jsonify
|
from flask import Flask, request, jsonify
|
||||||
@ -8,8 +9,30 @@ from image_prediction.utils import get_logger
|
|||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
def make_prediction_server(predict_fn: Callable):
|
def run_in_process(func):
    """Run func in a child process and block until that process has finished."""
    worker = multiprocessing.Process(target=func)
    worker.start()
    worker.join()
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_in_process(func_to_wrap):
    """Return a version of func_to_wrap that executes in a child process.

    The wrapper forwards its arguments, runs the call via run_in_process() and returns the
    call's result, passed back through a manager dict. If the call raises, the traceback is
    logged and the wrapper returns None.
    """

    def build_function_and_run_in_process(*args, **kwargs):
        def func():
            try:
                return_dict["result"] = func_to_wrap(*args, **kwargs)
            # Exception rather than a bare except, so SystemExit/KeyboardInterrupt are not swallowed.
            except Exception:
                logger.error(traceback.format_exc())

        # The context manager shuts the manager's server process down again; without it each
        # call left a manager behind. The result is read before leaving the block because the
        # dict proxy is unusable after shutdown.
        with multiprocessing.Manager() as manager:
            return_dict = manager.dict()
            run_in_process(func)
            return return_dict.get("result", None)

    return build_function_and_run_in_process
|
||||||
|
|
||||||
|
|
||||||
|
def make_prediction_server(predict_fn: Callable):
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
@app.route("/ready", methods=["GET"])
|
@app.route("/ready", methods=["GET"])
|
||||||
@ -24,42 +47,27 @@ def make_prediction_server(predict_fn: Callable):
|
|||||||
resp.status_code = 200
|
resp.status_code = 200
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
@app.route("/", methods=["POST"])
|
def __failure():
|
||||||
def predict():
|
response = jsonify("Analysis failed")
|
||||||
def predict_fn_wrapper(pdf, return_dict):
|
|
||||||
return_dict["result"] = predict_fn(pdf)
|
|
||||||
|
|
||||||
def process():
|
|
||||||
# Tensorflow does not free RAM. Workaround is running model in process.
|
|
||||||
# https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution
|
|
||||||
pdf = request.data
|
|
||||||
manager = multiprocessing.Manager()
|
|
||||||
return_dict = manager.dict()
|
|
||||||
p = multiprocessing.Process(
|
|
||||||
target=predict_fn_wrapper,
|
|
||||||
args=(
|
|
||||||
pdf,
|
|
||||||
return_dict,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
p.start()
|
|
||||||
p.join()
|
|
||||||
try:
|
|
||||||
return dict(return_dict)["result"]
|
|
||||||
except KeyError:
|
|
||||||
raise
|
|
||||||
|
|
||||||
logger.debug("Running predictor on document...")
|
|
||||||
try:
|
|
||||||
predictions = process()
|
|
||||||
response = jsonify(predictions)
|
|
||||||
logger.info("Analysis completed.")
|
|
||||||
return response
|
|
||||||
except Exception as err:
|
|
||||||
logger.error("Analysis failed.")
|
|
||||||
logger.exception(err)
|
|
||||||
response = jsonify("Analysis failed.")
|
|
||||||
response.status_code = 500
|
response.status_code = 500
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@app.route("/predict", methods=["POST"])
|
||||||
|
def predict():
|
||||||
|
|
||||||
|
# Tensorflow does not free RAM. Workaround: Run prediction function (which instantiates a model) in sub-process.
|
||||||
|
# See: https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution
|
||||||
|
predict_fn_wrapped = wrap_in_process(predict_fn)
|
||||||
|
|
||||||
|
logger.info("Analysing...")
|
||||||
|
predictions = predict_fn_wrapped(request.data)
|
||||||
|
|
||||||
|
if predictions:
|
||||||
|
response = jsonify(predictions)
|
||||||
|
logger.info("Analysis completed.")
|
||||||
|
return response
|
||||||
|
else:
|
||||||
|
logger.error("Analysis failed.")
|
||||||
|
return __failure()
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|||||||
0
image_prediction/formatter/__init__.py
Normal file
0
image_prediction/formatter/__init__.py
Normal file
15
image_prediction/formatter/formatter.py
Normal file
15
image_prediction/formatter/formatter.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import abc
|
||||||
|
|
||||||
|
from image_prediction.transformer.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
|
class Formatter(Transformer):
    """Transformer flavour whose call syntax maps to format() instead of transform()."""

    @abc.abstractmethod
    def format(self, obj):
        """Return the formatted obj; concrete subclasses must override this."""
        raise NotImplementedError

    def transform(self, obj):
        # Deliberately unsupported at this level: formatters are used through format() / call syntax.
        raise NotImplementedError()

    def __call__(self, obj):
        formatted = self.format(obj)
        return formatted
|
||||||
0
image_prediction/formatter/formatters/__init__.py
Normal file
0
image_prediction/formatter/formatters/__init__.py
Normal file
11
image_prediction/formatter/formatters/camel_case.py
Normal file
11
image_prediction/formatter/formatters/camel_case.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from image_prediction.formatter.formatters.key_formatter import KeyFormatter
|
||||||
|
|
||||||
|
|
||||||
|
class Snake2CamelCaseKeyFormatter(KeyFormatter):
    """Key formatter that rewrites snake_case string keys as camelCase."""

    def format_key(self, key):
        # Non-string keys pass through untouched.
        if not isinstance(key, str):
            return key

        # The first segment is kept as is; every following segment is title-cased and appended.
        first, *rest = key.split("_")
        return first + "".join(segment.title() for segment in rest)
|
||||||
23
image_prediction/formatter/formatters/enum.py
Normal file
23
image_prediction/formatter/formatters/enum.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from image_prediction.formatter.formatters.key_formatter import KeyFormatter
|
||||||
|
|
||||||
|
|
||||||
|
class EnumFormatter(KeyFormatter):
    """Key formatter that replaces Enum members used as keys with their values."""

    def format_key(self, key):
        if isinstance(key, Enum):
            return key.value
        return key

    def transform(self, obj):
        # Not supported; instances are used through format() / call syntax.
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class ReverseEnumFormatter(KeyFormatter):
    """Key formatter that maps keys equal to an enum value back to the enum member."""

    def __init__(self, enum):
        self.enum = enum
        # value -> member lookup table, built once per formatter.
        self.reverse_enum = dict((member.value, member) for member in enum)

    def format_key(self, key):
        # Keys that are not a value of the enum are returned unchanged.
        lookup = self.reverse_enum
        return lookup.get(key, key)

    def transform(self, obj):
        # Not supported; instances are used through format() / call syntax.
        raise NotImplementedError
|
||||||
6
image_prediction/formatter/formatters/identity.py
Normal file
6
image_prediction/formatter/formatters/identity.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from image_prediction.formatter.formatter import Formatter
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityFormatter(Formatter):
    """Formatter that returns its input unchanged."""

    def format(self, obj):
        unchanged = obj
        return unchanged
|
||||||
28
image_prediction/formatter/formatters/key_formatter.py
Normal file
28
image_prediction/formatter/formatters/key_formatter.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import abc
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from image_prediction.formatter.formatter import Formatter
|
||||||
|
|
||||||
|
|
||||||
|
class KeyFormatter(Formatter):
    """Formatter that rewrites dictionary keys, recursing through nested dicts and iterables."""

    @abc.abstractmethod
    def format_key(self, key):
        """Return the formatted version of a single key; concrete subclasses must override this."""
        raise NotImplementedError

    def __format(self, data):
        """Recursively apply format_key to every dict key found in data.

        Dicts get formatted keys and recursively formatted values. Strings and non-iterables
        are returned unchanged. Iterators stay lazy; other iterables are rebuilt with their
        own type.
        """
        # If we wanted to do this properly, we would need handlers for all expected types and dispatch based
        # on a type comparison. This is too much engineering for the limited use-case of this class though.
        if isinstance(data, dict):
            return {self.format_key(key): self.__format(value) for key, value in data.items()}

        if isinstance(data, str) or not isinstance(data, Iterable):
            return data

        formatted = map(self.__format, data)

        # Iterators (map objects, generators, ...) return themselves from iter(). They cannot be
        # rebuilt via type(data)(...) -- that raised TypeError for generators -- so stay lazy.
        if iter(data) is data:
            return formatted

        return type(data)(formatted)

    def format(self, data):
        """Public entry point; see the recursive helper above for the rules."""
        return self.__format(data)
|
||||||
0
image_prediction/image_extractor/__init__.py
Normal file
0
image_prediction/image_extractor/__init__.py
Normal file
19
image_prediction/image_extractor/extractor.py
Normal file
19
image_prediction/image_extractor/extractor.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import abc
|
||||||
|
from collections import namedtuple
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
# Lightweight record that couples an image with its metadata.
ImageMetadataPair = namedtuple("ImageMetadataPair", ["image", "metadata"])
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class ImageExtractor(abc.ABC):
    """Abstract base class for objects that extract image/metadata pairs from some input."""

    @abc.abstractmethod
    def extract(self, obj) -> Iterable[ImageMetadataPair]:
        """Extract image/metadata pairs from `obj`; must be provided by subclasses."""
        raise NotImplementedError

    def __call__(self, obj, **kwargs):
        """Log a debug message and forward `obj` and all keyword arguments to `extract`."""
        logger.debug("ImageExtractor.extract")
        extracted = self.extract(obj, **kwargs)
        return extracted
|
||||||
7
image_prediction/image_extractor/extractors/mock.py
Normal file
7
image_prediction/image_extractor/extractors/mock.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
||||||
|
|
||||||
|
|
||||||
|
class ImageExtractorMock(ImageExtractor):
    """Extractor stand-in that pairs each given image with a running index."""

    def extract(self, image_container):
        """Yield an `ImageMetadataPair` per image, with metadata `{"image_id": <position>}`."""
        yield from (
            ImageMetadataPair(image, {"image_id": position}) for position, image in enumerate(image_container)
        )
|
||||||
181
image_prediction/image_extractor/extractors/parsable.py
Normal file
181
image_prediction/image_extractor/extractors/parsable.py
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
import atexit
|
||||||
|
import io
|
||||||
|
from functools import partial, lru_cache
|
||||||
|
from itertools import chain, starmap, filterfalse, repeat
|
||||||
|
from operator import itemgetter
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import fitz
|
||||||
|
from PIL import Image
|
||||||
|
from funcy import rcompose, merge, zipdict
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
||||||
|
from image_prediction.info import Info
|
||||||
|
from image_prediction.stitching.stitching import stitch_pairs
|
||||||
|
from image_prediction.stitching.utils import validate_box_coords, validate_box_size
|
||||||
|
|
||||||
|
|
||||||
|
class ParsablePDFImageExtractor(ImageExtractor):
    """Extracts embedded images and their metadata from a PDF and stitches pairs per page."""

    def __init__(self, verbose=False, tolerance=0):
        """

        Args:
            verbose: Whether to show progressbar
            tolerance: The tolerance in pixels for the distance between images beyond which they will not be
                stitched together
        """
        self.doc: fitz.fitz.Document = None
        self.verbose = verbose
        self.tolerance = tolerance

    def extract(self, pdf: bytes, page_range: range = None):
        """Yield the (stitched) image/metadata pairs found in `pdf`.

        Args:
            pdf: The PDF document as bytes.
            page_range: Optional range of pages to process; when omitted (or empty) all pages are processed.
        """
        self.doc = fitz.Document(stream=pdf)

        pages = extract_pages(self.doc, page_range) if page_range else self.doc

        image_metadata_pairs = chain.from_iterable(
            map(
                self.__process_images_on_page,
                tqdm(pages, desc="Extracting", disable=not self.verbose, total=len(page_range) if page_range else None),
            )
        )

        yield from image_metadata_pairs

    def __process_images_on_page(self, page: fitz.fitz.Page):
        """Yield the stitched image/metadata pairs of a single page."""
        # Materialize the pairs first so that the caches are cleared only after they were actually used.
        image_metadata_pairs = list(_iter_image_metadata_pairs(self.doc, page))
        clear_caches()

        yield from stitch_pairs(image_metadata_pairs, tolerance=self.tolerance)


def _iter_image_metadata_pairs(doc, page):
    """Yield an ImageMetadataPair for every usable image on `page`.

    Image and metadata are derived from the same image info entry, so dropping tiny or
    unloadable images can never shift an image onto the metadata of a different image.
    """
    page_metadata = get_page_metadata(page)

    for image_info in get_image_infos(page):
        metadata = validate_box_coords(get_image_metadata(image_info))
        if tiny(metadata):
            continue
        metadata = validate_box_size(metadata)

        xref = image_info["xref"]
        image = xref_to_image(doc, xref)
        if not image:
            continue

        # Same merge order as the module-level helpers: alpha info, then page metadata, then box metadata.
        metadata = merge({Info.ALPHA: has_alpha_channel(doc, xref)}, page_metadata, metadata)

        yield ImageMetadataPair(image, metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_pages(doc, page_range):
    """Return a lazy iterator over the pages of `doc` selected by `page_range`.

    Both bounds of `page_range` are shifted up by one before the pages are loaded, and any
    step of the range is ignored.
    """
    # NOTE(review): the +1 shift presumably converts between page numbering conventions -- confirm
    # against the callers, since it skips the page at index `page_range.start`.
    shifted = range(page_range.start + 1, page_range.stop + 1)
    return map(doc.load_page, shifted)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
def get_images_on_page(doc, page: fitz.Page):
    """Return the images referenced on `page`, one entry per image info (in order).

    Entries are whatever `xref_to_image` returns for the respective xref. The result is a
    tuple rather than a lazy iterator: the function is cached, and a cached one-shot iterator
    would be exhausted (and therefore look empty) on every cache hit.
    """
    image_infos = get_image_infos(page)
    xrefs = map(itemgetter("xref"), image_infos)

    return tuple(map(partial(xref_to_image, doc), xrefs))
|
||||||
|
|
||||||
|
|
||||||
|
def get_metadata_for_images_on_page(doc, page: fitz.Page):
    """Yield the metadata of every non-tiny image on `page`.

    Each yielded dict combines alpha channel info, page metadata and the image's box
    metadata. All three are derived from the same image info entry, so the alpha flag always
    belongs to the image described by the box metadata, even when tiny images are dropped.
    """
    page_metadata = get_page_metadata(page)

    for image_info in get_image_infos(page):
        metadata = validate_box_coords(get_image_metadata(image_info))

        if tiny(metadata):
            continue

        metadata = validate_box_size(metadata)

        alpha = {Info.ALPHA: has_alpha_channel(doc, image_info["xref"])}

        # Later arguments win on key clashes: box metadata over page metadata over alpha info.
        yield merge(alpha, page_metadata, metadata)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
def get_image_infos(page: fitz.Page) -> List[dict]:
    """Return (and cache per page) the image infos of `page`, including their xrefs."""
    image_infos = page.get_image_info(xrefs=True)
    return image_infos
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
def xref_to_image(doc, xref) -> Image:
    """Open the image stored under `xref` in `doc`; return None if no image handle could be loaded."""
    image_handle = load_image_handle_from_xref(doc, xref)
    if not image_handle:
        return None

    # The raw image bytes live under the "image" key of the handle.
    return Image.open(io.BytesIO(image_handle["image"]))
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_metadata(image_info):
    """Build the box metadata (size and corner coordinates) from an image info's "bbox" entry."""
    left, top, right, bottom = (rounder(coordinate) for coordinate in image_info["bbox"])

    return {
        Info.WIDTH: abs(right - left),
        Info.HEIGHT: abs(bottom - top),
        Info.X1: left,
        Info.X2: right,
        Info.Y1: top,
        Info.Y2: bottom,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def validate_coords_and_passthrough(metadata):
    """Lazily run `validate_box_coords` on every metadata item and yield its results."""
    for item in metadata:
        yield validate_box_coords(item)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_out_tiny_images(metadata):
    """Return a lazy iterator over the metadata items that are not `tiny`."""
    return (item for item in metadata if not tiny(item))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_size_and_passthrough(metadata):
    """Lazily run `validate_box_size` on every metadata item and yield its results."""
    for item in metadata:
        yield validate_box_size(item)
|
||||||
|
|
||||||
|
|
||||||
|
def add_page_metadata(page, metadata):
    """Return a lazy iterator of metadata items merged on top of the metadata of `page`."""
    page_metadata = get_page_metadata(page)
    # The item comes last in the merge, so its entries win over the page metadata on key clashes.
    return (merge(page_metadata, item) for item in metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def add_alpha_channel_info(doc, page, metadata):
    """Return a lazy iterator of metadata items enriched with an `Info.ALPHA` entry.

    The alpha flags are computed for the image infos of `page` in order and paired with the
    metadata items by position.
    """
    # NOTE(review): pairing is purely positional -- if `metadata` was filtered beforehand, the flags
    # no longer line up with the images they were computed for. Verify against the callers.
    alpha_flags = ({Info.ALPHA: has_alpha_channel(doc, image_info["xref"])} for image_info in get_image_infos(page))

    return (merge(alpha, item) for alpha, item in zip(alpha_flags, metadata))
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
def load_image_handle_from_xref(doc, xref):
    """Return (and cache per document/xref) the result of `doc.extract_image(xref)`."""
    image_handle = doc.extract_image(xref)
    return image_handle
|
||||||
|
|
||||||
|
|
||||||
|
# Rounds a number and then converts the result to int (rcompose applies the functions left to right).
rounder = rcompose(round, int)
|
||||||
|
|
||||||
|
|
||||||
|
def get_page_metadata(page):
    """Return the rounded media box size and the page number of `page` as a metadata dict."""
    rounded_width, rounded_height = (rounder(extent) for extent in page.mediabox_size)

    page_metadata = {
        Info.PAGE_WIDTH: rounded_width,
        Info.PAGE_HEIGHT: rounded_height,
        Info.PAGE_IDX: page.number,
    }
    return page_metadata
|
||||||
|
|
||||||
|
|
||||||
|
def has_alpha_channel(doc, xref):
    """Tell whether the image stored under `xref` in `doc` has an alpha channel.

    If the image handle references a soft mask ("smask"), the image counts as having an alpha
    channel when the mask can be extracted or the mask's pixmap has alpha; otherwise the alpha
    flag of the image's own pixmap decides.
    """
    image_handle = load_image_handle_from_xref(doc, xref)
    smask_xref = image_handle["smask"] if image_handle else None

    if not smask_xref:
        return bool(fitz.Pixmap(doc, xref).alpha)

    # Both checks are always evaluated, in this order.
    smask_is_extractable = doc.extract_image(smask_xref) is not None
    smask_has_alpha = bool(fitz.Pixmap(doc, smask_xref).alpha)
    return smask_is_extractable or smask_has_alpha
|
||||||
|
|
||||||
|
|
||||||
|
def tiny(metadata):
    """Return True if the box described by `metadata` covers an area of at most 4."""
    area = metadata[Info.WIDTH] * metadata[Info.HEIGHT]
    return area <= 4
|
||||||
|
|
||||||
|
|
||||||
|
def clear_caches():
    """Empty the lru caches of all cached extraction helpers of this module."""
    cached_functions = (
        get_image_infos,
        load_image_handle_from_xref,
        get_images_on_page,
        xref_to_image,
    )
    for cached_function in cached_functions:
        cached_function.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
|
# Empty the extraction caches when the interpreter exits.
atexit.register(clear_caches)
|
||||||
14
image_prediction/info.py
Normal file
14
image_prediction/info.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class Info(Enum):
    """Keys used in image metadata dictionaries."""

    # Page-level entries
    PAGE_WIDTH = "page_width"
    PAGE_HEIGHT = "page_height"
    PAGE_IDX = "page_idx"

    # Size of the image box
    WIDTH = "width"
    HEIGHT = "height"

    # Corner coordinates of the image box
    X1 = "x1"
    X2 = "x2"
    Y1 = "y1"
    Y2 = "y2"

    # Alpha channel flag
    ALPHA = "alpha"
|
||||||
0
image_prediction/label_mapper/__init__.py
Normal file
0
image_prediction/label_mapper/__init__.py
Normal file
10
image_prediction/label_mapper/mapper.py
Normal file
10
image_prediction/label_mapper/mapper.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
import abc
|
||||||
|
|
||||||
|
|
||||||
|
class LabelMapper(abc.ABC):
    """Abstract base class for callables that map raw labels to another representation."""

    @abc.abstractmethod
    def map_labels(self, items):
        """Map `items` to their target representation; must be provided by subclasses."""
        raise NotImplementedError

    def __call__(self, items):
        """Shorthand for `map_labels(items)`."""
        mapped = self.map_labels(items)
        return mapped
|
||||||
0
image_prediction/label_mapper/mappers/__init__.py
Normal file
0
image_prediction/label_mapper/mappers/__init__.py
Normal file
20
image_prediction/label_mapper/mappers/numeric.py
Normal file
20
image_prediction/label_mapper/mappers/numeric.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from typing import Mapping, Iterable
|
||||||
|
|
||||||
|
from image_prediction.exceptions import UnexpectedLabelFormat
|
||||||
|
from image_prediction.label_mapper.mapper import LabelMapper
|
||||||
|
|
||||||
|
|
||||||
|
class IndexMapper(LabelMapper):
    """Maps integer index labels to the string labels stored under those indices."""

    def __init__(self, labels: Mapping[int, str]):
        """Store the index -> string label lookup."""
        self.__labels = labels

    def __validate_index_label_format(self, index_label: int) -> None:
        """Raise `UnexpectedLabelFormat` unless `0 <= index_label < number of labels`."""
        if 0 <= index_label < len(self.__labels):
            return
        raise UnexpectedLabelFormat(f"Received index label '{index_label}' that has no associated string label.")

    def __map_label(self, index_label: int) -> str:
        """Validate a single index label and return its string label."""
        self.__validate_index_label_format(index_label)
        label = self.__labels[index_label]
        return label

    def map_labels(self, index_labels: Iterable[int]) -> Iterable[str]:
        """Lazily map every index label to its string label."""
        return map(self.__map_label, index_labels)
|
||||||
39
image_prediction/label_mapper/mappers/probability.py
Normal file
39
image_prediction/label_mapper/mappers/probability.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from operator import itemgetter
|
||||||
|
from typing import Mapping, Iterable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from funcy import rcompose, rpartial
|
||||||
|
|
||||||
|
from image_prediction.exceptions import UnexpectedLabelFormat
|
||||||
|
from image_prediction.label_mapper.mapper import LabelMapper
|
||||||
|
|
||||||
|
|
||||||
|
class ProbabilityMapperKeys(Enum):
    """Keys of the dictionaries produced by `ProbabilityMapper`."""

    # The most likely label
    LABEL = "label"
    # The label -> probability mapping
    PROBABILITIES = "probabilities"
|
||||||
|
|
||||||
|
|
||||||
|
class ProbabilityMapper(LabelMapper):
    """Maps probability arrays to the most likely label plus a label -> probability mapping."""

    def __init__(self, labels: Mapping[int, str]):
        """Store the labels and set up the rounding function for probabilities."""
        self.__labels = labels
        # String conversion in the middle due to floating point precision issues.
        # See: https://stackoverflow.com/questions/56820/round-doesnt-seem-to-be-rounding-properly
        self.__rounder = rcompose(rpartial(round, 4), str, float)

    def __validate_array_label_format(self, probabilities: np.ndarray) -> None:
        """Raise `UnexpectedLabelFormat` unless there is exactly one probability per label."""
        if len(probabilities) != len(self.__labels):
            # The check rejects both too few and too many probabilities, so the message must not claim "fewer".
            raise UnexpectedLabelFormat(
                f"Received a different number of probabilities ({len(probabilities)}) "
                f"than labels were passed ({len(self.__labels)})."
            )

    def __map_array(self, probabilities: np.ndarray) -> dict:
        """Return the most likely label and all rounded probabilities, sorted in descending order."""
        self.__validate_array_label_format(probabilities)
        cls2prob = dict(
            sorted(zip(self.__labels, list(map(self.__rounder, probabilities))), key=itemgetter(1), reverse=True)
        )
        # The mapping is sorted by probability, so its first key is the most likely class.
        most_likely = [*cls2prob][0]
        return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: cls2prob}

    def map_labels(self, probabilities: Iterable[np.ndarray]) -> Iterable[dict]:
        """Lazily map every probability array to its result dictionary."""
        return map(self.__map_array, probabilities)
|
||||||
@ -1,10 +1,17 @@
|
|||||||
from os import path
|
"""Defines constant paths relative to the module root path."""
|
||||||
|
|
||||||
MODULE_DIR = path.dirname(path.abspath(__file__))
|
from pathlib import Path
|
||||||
PACKAGE_ROOT_DIR = path.dirname(MODULE_DIR)
|
|
||||||
|
|
||||||
CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml")
|
MODULE_DIR = Path(__file__).resolve().parents[0]
|
||||||
|
|
||||||
DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data")
|
PACKAGE_ROOT_DIR = MODULE_DIR.parents[0]
|
||||||
MLRUNS_DIR = path.join(DATA_DIR, "mlruns")
|
|
||||||
BASE_WEIGHTS = path.join(DATA_DIR, "base_weights.h5")
|
CONFIG_FILE = PACKAGE_ROOT_DIR / "config.yaml"
|
||||||
|
|
||||||
|
BANNER_FILE = PACKAGE_ROOT_DIR / "banner.txt"
|
||||||
|
|
||||||
|
DATA_DIR = PACKAGE_ROOT_DIR / "data"
|
||||||
|
|
||||||
|
MLRUNS_DIR = str(DATA_DIR / "mlruns")
|
||||||
|
|
||||||
|
TEST_DATA_DIR = PACKAGE_ROOT_DIR / "test" / "data"
|
||||||
|
|||||||
0
image_prediction/model_loader/__init__.py
Normal file
0
image_prediction/model_loader/__init__.py
Normal file
0
image_prediction/model_loader/database/__init__.py
Normal file
0
image_prediction/model_loader/database/__init__.py
Normal file
7
image_prediction/model_loader/database/connector.py
Normal file
7
image_prediction/model_loader/database/connector.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
import abc
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseConnector(abc.ABC):
    """Abstract interface for backends that return stored objects by identifier."""

    @abc.abstractmethod
    def get_object(self, identifier):
        """Return the object stored under `identifier`; must be provided by subclasses."""
        raise NotImplementedError
|
||||||
@ -0,0 +1,9 @@
|
|||||||
|
from image_prediction.model_loader.database.connector import DatabaseConnector
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseConnectorMock(DatabaseConnector):
    """Database connector backed by a plain dictionary."""

    def __init__(self, store: dict):
        """Use `store` as the identifier -> object lookup."""
        self.store = store

    def get_object(self, identifier):
        """Return `store[identifier]`; a missing identifier raises `KeyError`."""
        stored = self.store[identifier]
        return stored
|
||||||
18
image_prediction/model_loader/loader.py
Normal file
18
image_prediction/model_loader/loader.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from image_prediction.model_loader.database.connector import DatabaseConnector
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLoader:
    """Loads models and their classes through a database connector, caching fetched objects."""

    def __init__(self, database_connector: DatabaseConnector):
        """Store the connector and set up an empty per-instance object cache."""
        self.database_connector = database_connector
        # Per-instance cache: unlike `lru_cache` on a method, it does not keep the loader (and
        # the objects it has loaded) alive after the loader itself is no longer referenced.
        self.__objects = {}

    def __get_object(self, identifier):
        """Return the object stored under `identifier`, fetching it from the connector only once."""
        try:
            return self.__objects[identifier]
        except KeyError:
            fetched = self.__objects[identifier] = self.database_connector.get_object(identifier)
            return fetched

    def load_model(self, identifier):
        """Return the "model" entry of the object stored under `identifier`."""
        return self.__get_object(identifier)["model"]

    def load_classes(self, identifier):
        """Return the "classes" entry of the object stored under `identifier`."""
        return self.__get_object(identifier)["classes"]
|
||||||
0
image_prediction/model_loader/loaders/__init__.py
Normal file
0
image_prediction/model_loader/loaders/__init__.py
Normal file
10
image_prediction/model_loader/loaders/mlflow.py
Normal file
10
image_prediction/model_loader/loaders/mlflow.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from image_prediction.model_loader.database.connector import DatabaseConnector
|
||||||
|
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
||||||
|
|
||||||
|
|
||||||
|
class MlflowConnector(DatabaseConnector):
    """Database connector that looks objects up in an `MlflowModelReader` by run id."""

    def __init__(self, mlflow_reader: MlflowModelReader):
        """Store the reader that is indexed on lookups."""
        self.mlflow_reader = mlflow_reader

    def get_object(self, run_id):
        """Return `mlflow_reader[run_id]`."""
        run_object = self.mlflow_reader[run_id]
        return run_object
|
||||||
26
image_prediction/pipeline.py
Normal file
26
image_prediction/pipeline.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from funcy import rcompose
|
||||||
|
|
||||||
|
from image_prediction.config import CONFIG
|
||||||
|
from image_prediction.default_objects import get_extractor_classifier, get_formatter, get_mlflow_model_loader
|
||||||
|
from image_prediction.locations import MLRUNS_DIR
|
||||||
|
|
||||||
|
# Set TensorFlow's native (C++) minimum log level via its environment variable.
# NOTE(review): presumably "3" silences everything below fatal -- confirm against the TensorFlow docs.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
|
|
||||||
|
|
||||||
|
def load_pipeline(**kwargs):
    """Build a `Pipeline` for the configured run id, using a model loader rooted at `MLRUNS_DIR`.

    All keyword arguments are forwarded to `Pipeline`.
    """
    return Pipeline(
        get_mlflow_model_loader(MLRUNS_DIR),
        CONFIG.service.run_id,
        **kwargs,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class Pipeline:
    """Chains an extractor/classifier with a formatter into a single callable."""

    def __init__(self, model_loader, model_identifier, **kwargs):
        """Compose the extractor/classifier (built from the given arguments) with the default formatter."""
        extractor_classifier = get_extractor_classifier(model_loader, model_identifier, **kwargs)
        formatter = get_formatter()
        # rcompose applies left to right: extract and classify first, then format.
        self.pipe = rcompose(extractor_classifier, formatter)

    def __call__(self, pdf: bytes, page_range: range = None):
        """Yield the pipeline results for `pdf`, optionally restricted to `page_range`."""
        results = self.pipe(pdf, page_range=page_range)
        yield from results
|
||||||
@ -1,122 +0,0 @@
|
|||||||
from itertools import chain
|
|
||||||
from operator import itemgetter
|
|
||||||
from typing import List, Dict, Iterable
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from image_prediction.config import CONFIG
|
|
||||||
from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS
|
|
||||||
from image_prediction.utils import temporary_pdf_file, get_logger
|
|
||||||
from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle
|
|
||||||
from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch
|
|
||||||
from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader
|
|
||||||
from incl.redai_image.redai.redai.utils.shared import chunk_iterable
|
|
||||||
|
|
||||||
logger = get_logger()
|
|
||||||
|
|
||||||
|
|
||||||
class Predictor:
    """`ModelHandle` wrapper. Forwards to the wrapped model handle for prediction and produces structured output
    that is interpretable independently of the wrapped model (e.g. with regard to a .classes_ attribute).
    """

    def __init__(self, model_handle: ModelHandle = None):
        """Initializes a Predictor.

        Any exception raised during initialization is logged and swallowed.

        Args:
            model_handle: ModelHandle object to forward to for prediction. By default, a model handle is loaded from
                the mlflow database via CONFIG.service.run_id.
        """
        try:
            if model_handle is None:
                reader = MlflowModelReader(run_id=CONFIG.service.run_id, mlruns_dir=MLRUNS_DIR)
                model_handle = reader.get_model_handle(BASE_WEIGHTS)
            self.model_handle = model_handle

            self.classes = self.model_handle.model.classes_
            self.classes_readable = np.array(self.model_handle.classes)
            all_positions = list(range(len(self.classes)))
            self.classes_readable_aligned = self.classes_readable[self.classes[all_positions]]
        except Exception as e:
            logger.info(f"Service estimator initialization failed: {e}")

    def __make_predictions_human_readable(self, probs: np.ndarray) -> List[Dict[str, float]]:
        """Translates an n x m matrix of probabilities over classes into an n-element list holding the readable
        name of the most probable class per item.

        Args:
            probs: probability matrix (items x classes)

        Returns:
            list with one readable class name per item.
        """
        winners = self.classes[np.argmax(probs, axis=1)]
        return [self.model_handle.classes[winner] for winner in winners]

    def predict(self, images: List, probabilities: bool = False, **kwargs):
        """Gathers predictions for list of images. Assigns each image a class and optionally a probability
        distribution over all classes.

        Args:
            images (List[PIL.Image]) : Images to gather predictions for.
            probabilities: Whether to return dictionaries of the following form instead of strings:
                {
                    "class": predicted class,
                    "probabilities": {
                        "class 1" : class 1 probability,
                        "class 2" : class 2 probability,
                        ...
                    }
                }

        Returns:
            By default the return value is a list of classes (meaningful class name strings). Alternatively a list of
            dictionaries with an additional probability field for estimated class probabilities per image can be
            returned.
        """
        model_input = self.model_handle.prep_images(list(images))
        probs_per_item = self.model_handle.model.predict_proba(model_input, **kwargs).astype(float)
        classes = self.__make_predictions_human_readable(probs_per_item)

        class2prob_per_item = []
        for probs in probs_per_item:
            class2prob = dict(zip(self.classes_readable_aligned, probs))
            # Re-insert the entries ordered by descending probability.
            class2prob_per_item.append(dict(sorted(class2prob.items(), key=itemgetter(1), reverse=True)))

        predictions = [
            {"class": predicted, "probabilities": class2prob}
            for predicted, class2prob in zip(classes, class2prob_per_item)
        ]

        if probabilities:
            return predictions
        return classes

    def predict_pdf(self, pdf, verbose=False):
        """Extract the images of `pdf` via a temporary file and return predictions and metadata for them."""
        with temporary_pdf_file(pdf) as pdf_path:
            image_metadata_pairs = self.__extract_image_metadata_pairs(pdf_path, verbose=verbose)
            return self.__predict_images(image_metadata_pairs)

    def __predict_images(self, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size):
        """Predict the images of the given pairs chunk-wise; returns `(predictions, metadata)`.

        Returns `([], [])` if a `ValueError` is raised while gathering the results.
        """

        def predict_chunk(chunk):
            images, metadata = zip(*chunk)
            return self.predict(images, probabilities=True), metadata

        def predict_in_chunks(pairs):
            results_per_chunk = map(predict_chunk, chunk_iterable(pairs, n=batch_size))
            # Transpose the per-chunk (predictions, metadata) tuples and flatten each side.
            return map(chain.from_iterable, zip(*results_per_chunk))

        try:
            predictions, metadata = predict_in_chunks(image_metadata_pairs)
        except ValueError:
            return [], []

        return predictions, metadata

    @staticmethod
    def __extract_image_metadata_pairs(pdf_path: str, **kwargs):
        """Yield the stitched image/metadata pairs of the PDF at `pdf_path`, skipping images that are too small."""

        def image_is_large_enough(metadata: dict):
            x1, x2, y1, y2 = metadata["x1"], metadata["x2"], metadata["y1"], metadata["y2"]
            wide_enough = abs(x1 - x2) > 2
            return wide_enough and abs(y1 - y2) > 2

        yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs)
|
|
||||||
0
image_prediction/redai_adapter/__init__.py
Normal file
0
image_prediction/redai_adapter/__init__.py
Normal file
45
image_prediction/redai_adapter/efficient_net_wrapper.py
Normal file
45
image_prediction/redai_adapter/efficient_net_wrapper.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from image_prediction.redai_adapter.model_wrapper import ModelWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class EfficientNetWrapper(ModelWrapper):
    """Model wrapper around a frozen EfficientNetB0 base with a small dense classification head."""

    def __init__(self, classes, base_weights_path=None, weights_path=None):
        """Set the input shape and delegate model construction and weight loading to the base class."""
        # Must be set before the base initializer runs, since building the model reads `input_shape`.
        self.__input_shape = (224, 224, 3)
        super().__init__(classes=classes, base_weights_path=base_weights_path, weights_path=weights_path)

    @property
    def input_shape(self):
        """The input shape the model is built for."""
        return self.__input_shape

    def _ModelWrapper__preprocess_tensor(self, tensor):
        """Apply the EfficientNet input preprocessing to `tensor`."""
        return tf.keras.applications.efficientnet.preprocess_input(tensor)

    def _ModelWrapper__build(self, base_weights=None) -> tf.keras.models.Model:
        """Build and compile the model: frozen EfficientNetB0 followed by dense/dropout layers and a softmax."""
        layers = tf.keras.layers
        input_img = layers.Input(shape=self.input_shape)

        base = tf.keras.applications.efficientnet.EfficientNetB0(
            include_top=False, input_tensor=layers.Input(shape=self.input_shape), weights=base_weights
        )

        # Freeze the base model as a whole as well as each of its layers.
        base.trainable = False
        for layer in base.layers:
            layer.trainable = False

        features = layers.Flatten()(base(input_img))

        for units in (512, 128, 32):
            features = layers.Dense(units, activation="relu")(features)
            features = layers.Dropout(0.2)(features)

        outputs = layers.Dense(len(self.classes), activation="softmax")(features)

        model = tf.keras.models.Model(inputs=input_img, outputs=outputs)
        model.compile()

        return model
|
||||||
72
image_prediction/redai_adapter/mlflow.py
Normal file
72
image_prediction/redai_adapter/mlflow.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
import mlflow
|
||||||
|
|
||||||
|
from image_prediction.redai_adapter.model import PredictionModelHandle
|
||||||
|
|
||||||
|
|
||||||
|
class MlflowModelReader:
    """Reads models and their classes from mlflow runs; indexable by run id."""

    def __init__(self, mlruns_dir=None):
        """Store `mlruns_dir` and make it the mlflow tracking URI."""
        self.mlruns_dir = mlruns_dir
        # Per-instance run cache: `lru_cache` on a method would keep every reader instance alive.
        self.__runs = {}
        mlflow.set_tracking_uri(self.mlruns_dir)

    @staticmethod
    def __correct_artifact_uri(run_artifact_uri, base_path):
        """Re-root the part of `run_artifact_uri` that follows "mlruns/" onto `base_path`."""
        # NOTE(review): raises ValueError unless "mlruns/" occurs exactly once in the URI.
        _, suffix = run_artifact_uri.split("mlruns/")
        return os.path.join(base_path, suffix)

    def __get_weights_path(self, run_id, prefix="tt"):
        """Return the paths of the base weights and the weights file of the run's estimator."""
        run = self.__get_run(run_id)

        artifact_uri = self.__correct_artifact_uri(run.info.to_proto().artifact_uri, self.mlruns_dir)
        path = os.path.join(artifact_uri, prefix, "train_dev", "estimator")

        base_path = os.path.join(path, "base_weights.h5")
        weights_path = os.path.join(path, "weights.h5")

        return base_path, weights_path

    def __get_run(self, run_id):
        """Return the mlflow run for `run_id`, querying mlflow only once per run id."""
        try:
            return self.__runs[run_id]
        except KeyError:
            run = self.__runs[run_id] = mlflow.get_run(run_id)
            return run

    def __get_classes(self, run_id, prefix="tt"):
        """Return the classes stored in the run's parameters."""
        run = self.__get_run(run_id)

        # The parameter is parsed as JSON after replacing single quotes with double quotes.
        classes = json.loads(run.data.params[os.path.join(prefix, "train_dev/estimator/classes")].replace("'", '"'))

        return classes

    def __get_model_handle(self, run_id):
        """Instantiate the model handle named by the run's "model_handle_builder" parameter."""
        run = self.__get_run(run_id)

        model_handle_builder = load_object(run.data.params["model_handle_builder"].strip())

        base_weights_path, weights_path = self.__get_weights_path(run_id)

        model_handle = model_handle_builder(
            self.__get_classes(run_id), base_weights_path=base_weights_path, weights_path=weights_path
        )

        return model_handle

    def __get_model(self, run_id) -> PredictionModelHandle:
        """Return the run's model handle wrapped in a `PredictionModelHandle`."""
        model_handle = self.__get_model_handle(run_id)
        model = PredictionModelHandle(model_handle)
        return model

    def __getitem__(self, run_id):
        """Return `{"model": ..., "classes": ...}` for the run with id `run_id`."""
        return {"model": self.__get_model(run_id), "classes": self.__get_classes(run_id)}
|
||||||
|
|
||||||
|
|
||||||
|
def load_object(object_path):
    """Import and return the object addressed by a dotted path such as "package.module.name".

    Everything before the last dot is imported as a module; the part after it is looked up
    as an attribute of that module.
    """
    module_path, _, object_name = object_path.rpartition(".")
    return getattr(importlib.import_module(module_path), object_name)
|
||||||
19
image_prediction/redai_adapter/model.py
Normal file
19
image_prediction/redai_adapter/model.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from funcy import rcompose
|
||||||
|
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class PredictionModelHandle:
    """Simplifies usage of ModelHandle instances for prediction purposes."""

    def __init__(self, model_handle):
        """Capture the handle's image preparation and the model's predict function."""
        self.__prepare = model_handle.prep_images
        self.__predict_prepared = model_handle.model.predict

    def predict(self, *args, **kwargs):
        """Prepare the images with the given arguments and return the model's prediction for the result."""
        prepared = self.__prepare(*args, **kwargs)
        return self.__predict_prepared(prepared)

    def __call__(self, *args, **kwargs):
        """Log a debug message and forward all arguments to `predict`."""
        logger.debug("PredictionModelHandle.predict")
        return self.predict(*args, **kwargs)
|
||||||
42
image_prediction/redai_adapter/model_wrapper.py
Normal file
42
image_prediction/redai_adapter/model_wrapper.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import abc
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class ModelWrapper(abc.ABC):
    """Abstract base for wrappers that build a keras model, load its weights and prepare images for it.

    Subclasses provide `input_shape` and the name-mangled hooks `_ModelWrapper__build` and
    `_ModelWrapper__preprocess_tensor`.
    """

    def __init__(self, classes, base_weights_path=None, weights_path=None):
        """Store the classes, build the model from the base weights and load the weights into it."""
        self.__classes = classes
        self.model = self.__build(base_weights_path)
        # NOTE(review): `weights_path` defaults to None but is passed to `load_weights` unconditionally.
        self.model.load_weights(weights_path)

    @property
    @abc.abstractmethod
    def input_shape(self):
        """The input shape of the model; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def classes(self):
        """The classes passed at construction."""
        return self.__classes

    @abc.abstractmethod
    def __preprocess_tensor(self, tensor):
        """Apply model-specific preprocessing to `tensor`; must be provided by subclasses."""
        raise NotImplementedError

    @staticmethod
    def __images_to_tensor(images):
        """Convert every image to an array and stack the arrays into one numpy array."""
        arrays = [tf.keras.preprocessing.image.img_to_array(image) for image in images]
        return np.array(arrays)

    def __resize_and_convert(self, image):
        """Resize `image` to the leading dimensions of `input_shape` and convert it to RGB."""
        target_size = self.input_shape[:-1]
        return image.resize(target_size).convert("RGB")

    def prep_images(self, images):
        """Return the preprocessed tensor for `images`: resize/convert, stack, then preprocess."""
        resized = map(self.__resize_and_convert, images)
        return self.__preprocess_tensor(self.__images_to_tensor(resized))

    @abc.abstractmethod
    def __build(self, base_weights=None) -> tf.keras.models.Model:
        """Build the keras model; must be provided by subclasses."""
        raise NotImplementedError
|
||||||
0
image_prediction/stitching/__init__.py
Normal file
0
image_prediction/stitching/__init__.py
Normal file
63
image_prediction/stitching/grouping.py
Normal file
63
image_prediction/stitching/grouping.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
from itertools import groupby
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from funcy import compose, second
|
||||||
|
|
||||||
|
from image_prediction.stitching.utils import make_coord_getter
|
||||||
|
|
||||||
|
|
||||||
|
class CoordGrouper:
    """Groups pairs by the coordinates of the axis orthogonal to the given one."""

    def __init__(self, axis, tolerance=0):
        """Create the getters for coordinates 1 and 2 of the axis other than `axis`, and store the tolerance."""
        orthogonal = other_axis(axis)
        self.c1_getter = make_coord_getter(f"{orthogonal}1")
        self.c2_getter = make_coord_getter(f"{orthogonal}2")
        self.tolerance = tolerance

    def group_pairs_by_lesser_coordinate(self, pairs):
        """Group `pairs` by the coordinate returned by `c1_getter`, within the tolerance."""
        return group_by_coordinate(pairs, self.c1_getter, self.tolerance)

    def group_pairs_by_greater_coordinate(self, pairs):
        """Group `pairs` by the coordinate returned by `c2_getter`, within the tolerance."""
        return group_by_coordinate(pairs, self.c2_getter, self.tolerance)
|
||||||
|
|
||||||
|
|
||||||
|
def other_axis(axis):
    """Return ``"y"`` for ``"x"``; any other input yields ``"x"``."""
    if axis == "x":
        return "y"
    return "x"
|
||||||
|
|
||||||
|
|
||||||
|
def fuzzify(func, tolerance):
    """Wrap ``func`` so that values lying close together collapse onto one representative.

    The returned callable is stateful: it remembers every value that did not fit an
    existing representative. A later value ``v`` is replaced by the first remembered
    representative ``m`` with ``m - tolerance <= v <= m + tolerance``; if there is none,
    ``v`` becomes a new representative and is returned unchanged.

    Args:
        func: Callable extracting a numeric value from an item.
        tolerance: Maximum distance from a representative that still counts as a match.

    Returns:
        A one-argument callable suitable as a ``key`` function.
    """
    mid_points = []

    def inner(item):
        value = func(item)
        # First-fit scan in insertion order, so repeated calls for the same value are stable.
        for mid_point in mid_points:
            if mid_point - tolerance <= value <= mid_point + tolerance:
                return mid_point
        mid_points.append(value)
        return value

    return inner
|
||||||
|
|
||||||
|
|
||||||
|
def group_by_coordinate(pairs, coord_getter, tolerance=0):
    """Sort ``pairs`` by a fuzzified coordinate and lazily yield each run of equal keys as a list."""
    fuzzy_key = fuzzify(coord_getter, tolerance)
    ordered = sorted(pairs, key=fuzzy_key)
    keyed_groups = groupby(ordered, fuzzy_key)
    # Each groupby item is (key, members); keep only the members, materialised.
    return map(compose(list, second), keyed_groups)
|
||||||
174
image_prediction/stitching/merging.py
Normal file
174
image_prediction/stitching/merging.py
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
from functools import reduce
|
||||||
|
from typing import Iterable, Callable, List
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from funcy import juxt, first, rest, rcompose, rpartial
|
||||||
|
|
||||||
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
|
from image_prediction.info import Info
|
||||||
|
from image_prediction.stitching.grouping import CoordGrouper
|
||||||
|
from image_prediction.stitching.split_mapper import HorizontalSplitMapper, VerticalSplitMapper
|
||||||
|
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once, validate_box
|
||||||
|
from image_prediction.utils.generic import until
|
||||||
|
|
||||||
|
|
||||||
|
def no_new_merges(pairs1, pairs2):
    """Return True when both collections have the same length."""
    count_before, count_after = len(pairs1), len(pairs2)
    return count_before == count_after
|
||||||
|
|
||||||
|
|
||||||
|
def merge_along_both_axes(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]:
    """Run one merge pass along "x", then one along "y", and return the result as a list."""
    merged_along_x = merge_along_axis(pairs, "x", tolerance=tolerance)
    merged_along_y = merge_along_axis(merged_along_x, "y", tolerance=tolerance)
    return list(merged_along_y)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis, tolerance=0) -> Iterable[ImageMetadataPair]:
    """Run one partial merge pass over image-metadata pairs along ``axis``.

    One pass does not necessarily merge everything; it has to be repeated with
    alternating axes until no more merges happen.

    Steps (a dot is a pair, a bracket a group, a colon a merged pair):
        1) Start with pairs:               (........)
        2) Group on lesser coordinate:     ([....] [....])
        3) Sub-group on greater:           ([[..] [..]] [[....]])
        4) Flatten once:                   ([..] [..] [....])
        5) Merge within each group:        ([:] [..] [:..])
        6) Flatten once:                   (:..:..)
    """
    grouper = CoordGrouper(axis, tolerance=tolerance)

    coarse_groups = grouper.group_pairs_by_lesser_coordinate(pairs)
    nested_groups = map(grouper.group_pairs_by_greater_coordinate, coarse_groups)
    aligned_groups = flatten_groups_once(nested_groups)

    merge_one_group = make_group_merger(axis)
    merged_groups = map(lambda group: merge_one_group(group, tolerance), aligned_groups)
    return flatten_groups_once(merged_groups)
|
||||||
|
|
||||||
|
|
||||||
|
def make_group_merger(axis):
    """Look up the group merger for ``axis`` ("x" or "y"); other values raise KeyError."""
    mergers = {"x": merge_group_horizontally, "y": merge_group_vertically}
    return mergers[axis]
|
||||||
|
|
||||||
|
|
||||||
|
def merge_group_vertically(group: Iterable[ImageMetadataPair], tolerance=0):
    """Merge the pairs of ``group`` in the "y" direction."""
    return merge_group(group, direction="y", tolerance=tolerance)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_group_horizontally(group: Iterable[ImageMetadataPair], tolerance=0):
    """Merge the pairs of ``group`` in the "x" direction."""
    return merge_group(group, direction="x", tolerance=tolerance)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_group(group: Iterable[ImageMetadataPair], direction, tolerance=0):
    """Repeatedly apply the merger-aggregator for ``direction`` until a round leaves the length unchanged."""
    return until(no_new_merges, make_merger_aggregator(direction, tolerance=tolerance), group)
|
||||||
|
|
||||||
|
|
||||||
|
def make_merger_aggregator(axis, tolerance=0) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
    """Build a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl].

    The function merges adjacent image-metadata pairs onto the head H and collects
    the non-adjacent ones in the tail.

    Note:
        When tolerance > 0, the bounding box of the merged image no longer matches the
        bounding box of the merged metadata. This is intended behaviour, but might not
        be expected by the caller.
    """
    assert tolerance >= 0

    c1_getter = make_coord_getter(f"{axis}1")
    c2_getter = make_coord_getter(f"{axis}2")
    pair_merger = make_pair_merger(axis)

    def absorb_or_defer(state, candidate):
        """Merge ``candidate`` into the head if it starts where the head ends, else defer it."""
        head, *deferred = state
        gap = abs(c2_getter(head) - c1_getter(candidate))
        if gap <= tolerance:
            return (pair_merger(head, candidate), *deferred)
        return (head, candidate, *deferred)

    def merger_aggregator(pairs: Iterable[ImageMetadataPair]):
        # The head must be the least element by c1, because merging only proceeds
        # in the c1 -> c2 direction.
        ordered = sorted(pairs, key=c1_getter)
        head_pair = ordered[0] if ordered else None
        return list(reduce(absorb_or_defer, ordered[1:], [head_pair]))

    return merger_aggregator
|
||||||
|
|
||||||
|
|
||||||
|
def make_pair_merger(axis):
    """Look up the pair merger for ``axis`` ("x" or "y"); other values raise KeyError."""
    mergers = {"x": merge_pair_horizontally, "y": merge_pair_vertically}
    return mergers[axis]
|
||||||
|
|
||||||
|
|
||||||
|
def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge the metadata of both pairs vertically and concatenate their images accordingly."""
    merged_metadata = merge_metadata_vertically(p1.metadata, p2.metadata)
    stacked_image = concat_images_vertically(p1.image, p2.image, merged_metadata)
    return ImageMetadataPair(stacked_image, merged_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge the metadata of both pairs horizontally and concatenate their images accordingly."""
    merged_metadata = merge_metadata_horizontally(p1.metadata, p2.metadata)
    joined_image = concat_images_horizontally(p1.image, p2.image, merged_metadata)
    return ImageMetadataPair(joined_image, merged_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_metadata_vertically(m1: dict, m2: dict):
    """Wrap both metadata dicts in a ``VerticalSplitMapper`` and merge them."""
    return merge_metadata(VerticalSplitMapper(m1), VerticalSplitMapper(m2))
|
||||||
|
|
||||||
|
|
||||||
|
def merge_metadata_horizontally(m1: dict, m2: dict):
    """Wrap both metadata dicts in a ``HorizontalSplitMapper`` and merge them."""
    return merge_metadata(HorizontalSplitMapper(m1), HorizontalSplitMapper(m2))
|
||||||
|
|
||||||
|
|
||||||
|
def merge_metadata(m1: dict, m2: dict):
    """Merge two mappers into the smallest span covering both and return the validated dict.

    The result is a deep copy of ``m1`` whose ``c1``/``c2``/``dim`` are replaced by the
    minimum ``c1``, the maximum ``c2`` and their absolute difference.

    NOTE(review): both arguments are read through ``.c1``/``.c2`` and the result through
    ``.wrapped``, so they look like SplitMapper instances rather than plain dicts -- confirm.
    """
    lower = min(m1.c1, m2.c1)
    upper = max(m1.c2, m2.c2)

    merged = deepcopy(m1)
    merged.dim = abs(upper - lower)
    merged.c1 = lower
    merged.c2 = upper

    validate_box(merged.wrapped)
    return merged.wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def concat_images_vertically(im1: Image, im2: Image, metadata: dict):
    """Concatenate two images along size index 1."""
    return concat_images(im1, im2, metadata, axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
def concat_images_horizontally(im1: Image, im2: Image, metadata: dict):
    """Concatenate two images along size index 0."""
    return concat_images(im1, im2, metadata, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
def concat_images(im1: Image, im2: Image, metadata: dict, axis):
    """Paste ``im1`` and ``im2`` one after the other onto a new canvas.

    The canvas takes its mode from ``im1`` and its size from the WIDTH/HEIGHT entries
    of ``metadata``. ``im1`` is pasted at the origin; ``im2`` is pasted directly after
    ``im1`` along ``axis`` (0 shifts the first box coordinate, 1 the second).
    """
    canvas = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))

    second_offset = im1.size[axis]
    second_box = (0, second_offset) if axis else (second_offset, 0)

    canvas.paste(im1, box=(0, 0))
    canvas.paste(im2, box=second_box)

    return canvas
|
||||||
40
image_prediction/stitching/split_mapper.py
Normal file
40
image_prediction/stitching/split_mapper.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
from dataclasses import field, dataclass
|
||||||
|
from operator import attrgetter
|
||||||
|
|
||||||
|
from image_prediction.info import Info
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SplitMapper:
    """Manages access into a mapping M by indirection through a specified access mapping to achieve a common
    interface between various M_i.

    The access mapping goes from attribute names on this object to keys of the wrapped mapping.
    """

    # Attribute name -> key in the wrapped mapping.
    __access_mapping: dict
    # NOTE(review): `wrapped` is declared as a field here and redefined as a property below;
    # presumably so that the value passed to the constructor is routed through the setter -- confirm.
    wrapped: dict
    # Backing storage written by the `wrapped` setter; not a constructor argument.
    __wrapped: dict = field(init=False)

    def __post_init__(self):
        # Expose every mapped entry of the wrapped mapping as an attribute of this object.
        for k, v in self.__access_mapping.items():
            setattr(self, k, self.__wrapped[v])

    @property
    def wrapped(self):
        # Return a deep copy of the stored mapping in which the mapped keys carry the
        # current attribute values, so attribute assignments show up in the result.
        ret = deepcopy(self.__wrapped)
        ret.update(dict(zip(self.__access_mapping.values(), attrgetter(*self.__access_mapping.keys())(self))))
        return ret

    @wrapped.setter
    def wrapped(self, wrapped):
        # Stores the mapping as given (no copy is made here).
        self.__wrapped = wrapped
|
||||||
|
|
||||||
|
|
||||||
|
class HorizontalSplitMapper(SplitMapper):
    """SplitMapper whose ``dim``/``c1``/``c2`` map to width and the x-coordinates."""

    def __init__(self, wrapped: dict):
        access_mapping = {"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2}
        super().__init__(access_mapping, wrapped)
|
||||||
|
|
||||||
|
|
||||||
|
class VerticalSplitMapper(SplitMapper):
    """SplitMapper whose ``dim``/``c1``/``c2`` map to height and the y-coordinates."""

    def __init__(self, wrapped: dict):
        access_mapping = {"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2}
        super().__init__(access_mapping, wrapped)
|
||||||
13
image_prediction/stitching/stitching.py
Normal file
13
image_prediction/stitching/stitching.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from typing import Iterable, List
|
||||||
|
|
||||||
|
from funcy import rpartial
|
||||||
|
|
||||||
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
|
from image_prediction.stitching.merging import merge_along_both_axes, no_new_merges
|
||||||
|
from image_prediction.utils.generic import until
|
||||||
|
|
||||||
|
|
||||||
|
def stitch_pairs(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]:
    """Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent
    images.

    Args:
        pairs: Image-metadata pairs; any iterable, including a generator.
        tolerance: Passed through to ``merge_along_both_axes``.

    Returns:
        The pairs after merge rounds stopped reducing their number.
    """
    # The stop condition compares lengths, so a lazy iterable has to be materialised first.
    return until(no_new_merges, rpartial(merge_along_both_axes, tolerance), list(pairs))
|
||||||
67
image_prediction/stitching/utils.py
Normal file
67
image_prediction/stitching/utils.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import json
|
||||||
|
from itertools import chain
|
||||||
|
|
||||||
|
from image_prediction.exceptions import InvalidBox
|
||||||
|
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||||
|
from image_prediction.info import Info
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_groups_once(groups):
    """Lazily yield the members of every group, removing exactly one nesting level."""
    return (member for group in groups for member in group)
|
||||||
|
|
||||||
|
|
||||||
|
def make_coord_getter(c):
    """Return a getter for the coordinate named ``c`` ("x1", "x2", "y1" or "y2"); else KeyError."""
    info_keys = {
        "x1": Info.X1,
        "x2": Info.X2,
        "y1": Info.Y1,
        "y2": Info.Y2,
    }
    return make_getter(info_keys[c])
|
||||||
|
|
||||||
|
|
||||||
|
def make_getter(key):
    """Return a function that reads ``key`` from the ``metadata`` mapping of a pair."""

    def read_from_metadata(pair):
        metadata = pair.metadata
        return metadata[key]

    return read_from_metadata
|
||||||
|
|
||||||
|
|
||||||
|
def make_length_getter(dim):
    """Return a getter for the length named ``dim`` ("width" or "height"); else KeyError."""
    info_keys = {
        "width": Info.WIDTH,
        "height": Info.HEIGHT,
    }
    return make_getter(info_keys[dim])
|
||||||
|
|
||||||
|
|
||||||
|
def validate_box(box):
    """Run the coordinate check, then the size check, and hand ``box`` back."""
    for check in (validate_box_coords, validate_box_size):
        check(box)
    return box
|
||||||
|
|
||||||
|
|
||||||
|
def validate_box_coords(box):
    """Check that the stored width/height agree with the stored coordinates.

    Args:
        box: Mapping holding the WIDTH, HEIGHT, X1, X2, Y1 and Y2 entries.

    Returns:
        ``box`` unchanged when both dimensions are consistent.

    Raises:
        InvalidBox: If width differs from ``x2 - x1`` or height differs from ``y2 - y1``.
    """
    x_diff = box[Info.WIDTH] - (box[Info.X2] - box[Info.X1])
    y_diff = box[Info.HEIGHT] - (box[Info.Y2] - box[Info.Y1])

    if x_diff:
        raise InvalidBox(f"Width and x-coordinates differ by {x_diff} units: {format_box(box)}")
    if y_diff:
        # The y-axis mismatch concerns the height, not the width.
        raise InvalidBox(f"Height and y-coordinates differ by {y_diff} units: {format_box(box)}")

    return box
|
||||||
|
|
||||||
|
|
||||||
|
def validate_box_size(box):
    """Raise InvalidBox when the width or the height of ``box`` is falsy; otherwise return ``box``."""

    def is_degenerate(key):
        return not box[key]

    # Width is checked first, so a box with both dimensions missing reports the width.
    if is_degenerate(Info.WIDTH):
        raise InvalidBox(f"Zero width box: {format_box(box)}")
    if is_degenerate(Info.HEIGHT):
        raise InvalidBox(f"Zero height box: {format_box(box)}")

    return box
|
||||||
|
|
||||||
|
|
||||||
|
def format_box(box):
    """Pass ``box`` through an ``EnumFormatter`` and dump the result as indented JSON."""
    formatted = EnumFormatter()(box)
    return json.dumps(formatted, indent=2)
|
||||||
0
image_prediction/transformer/__init__.py
Normal file
0
image_prediction/transformer/__init__.py
Normal file
20
image_prediction/transformer/transformer.py
Normal file
20
image_prediction/transformer/transformer.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import abc
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from funcy import curry, identity
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(abc.ABC):
    """Base class for callables that transform one object or lazily map over many."""

    @abc.abstractmethod
    def transform(self, obj):
        """Transform a single object; implemented by subclasses."""
        raise NotImplementedError

    def __call__(self, obj):
        """Apply ``transform`` to ``obj``, or to each of its elements if it must be mapped over."""
        return self._apply(self.transform, obj)

    @staticmethod
    def _must_be_mapped_over(obj):
        """True for every iterable except dicts (which are treated as single objects)."""
        if isinstance(obj, dict):
            return False
        return isinstance(obj, Iterable)

    def _apply(self, func, obj):
        """Return ``map(func, obj)`` for mappable objects and ``func(obj)`` otherwise."""
        if self._must_be_mapped_over(obj):
            return map(func, obj)
        return func(obj)
|
||||||
@ -0,0 +1,22 @@
|
|||||||
|
import abc
|
||||||
|
|
||||||
|
from image_prediction.transformer.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
|
class CoordinateTransformer(Transformer):
    """Transformer with a forward and a backward direction for metadata."""

    @abc.abstractmethod
    def _forward(self, metadata):
        """Forward transformation of a single metadata object; implemented by subclasses."""
        raise NotImplementedError

    @abc.abstractmethod
    def _backward(self, metadata):
        """Backward transformation of a single metadata object; implemented by subclasses."""
        raise NotImplementedError

    def forward(self, metadata):
        """Apply ``_forward`` via ``_apply``."""
        return self._apply(self._forward, metadata)

    def backward(self, metadata):
        """Apply ``_backward`` via ``_apply``."""
        return self._apply(self._backward, metadata)

    def transform(self, metadata):
        """The default transformation is the forward direction."""
        return self.forward(metadata)
|
||||||
10
image_prediction/transformer/transformers/coordinate/fitz.py
Normal file
10
image_prediction/transformer/transformers/coordinate/fitz.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from image_prediction.transformer.transformers.coordinate.coordinate_transformer import CoordinateTransformer
|
||||||
|
|
||||||
|
|
||||||
|
class FitzCoordinateTransformer(CoordinateTransformer):
    """Coordinate transformer for Fitz metadata; returns the metadata unchanged."""

    def _forward(self, metadata: dict):
        """Fitz uses top left corner as origin; we take this as the reference coordinate system."""
        return metadata

    def _backward(self, metadata: dict):
        """Delegate to ``forward``; the forward step leaves the metadata as it is."""
        return self.forward(metadata)
|
||||||
10
image_prediction/transformer/transformers/coordinate/fpdf.py
Normal file
10
image_prediction/transformer/transformers/coordinate/fpdf.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from image_prediction.transformer.transformers.coordinate.coordinate_transformer import CoordinateTransformer
|
||||||
|
|
||||||
|
|
||||||
|
class FPDFCoordinateTransformer(CoordinateTransformer):
    """Coordinate transformer for FPDF metadata; returns the metadata unchanged."""

    def _forward(self, metadata: dict):
        """FPDF uses top left corner as origin; we take this as the reference coordinate system."""
        return metadata

    def _backward(self, metadata: dict):
        """Delegate to ``forward``; the forward step leaves the metadata as it is."""
        return self.forward(metadata)
|
||||||
@ -0,0 +1,18 @@
|
|||||||
|
from operator import itemgetter
|
||||||
|
|
||||||
|
from funcy import omit
|
||||||
|
|
||||||
|
from image_prediction.info import Info
|
||||||
|
from image_prediction.transformer.transformers.coordinate.coordinate_transformer import CoordinateTransformer
|
||||||
|
|
||||||
|
|
||||||
|
class PDFNetCoordinateTransformer(CoordinateTransformer):
    """Coordinate transformer that flips the y-coordinates against the page height."""

    def _forward(self, metadata: dict):
        """PDFNet coordinate system origin is in the bottom left corner.

        Returns a new dict in which Y1 becomes ``page_height - y2`` and Y2 becomes
        ``page_height - y1``; all other entries are carried over.
        """
        y1 = metadata[Info.Y1]
        y2 = metadata[Info.Y2]
        page_height = metadata[Info.PAGE_HEIGHT]

        flipped = {Info.Y1: page_height - y2, Info.Y2: page_height - y1}
        untouched = omit(metadata, [Info.Y1, Info.Y2])
        return {**untouched, **flipped}

    def _backward(self, metadata: dict):
        """Delegate to ``forward``; flipping twice restores the original coordinates."""
        return self.forward(metadata)
|
||||||
@ -1,18 +1,20 @@
|
|||||||
"""Defines functions for constructing service responses."""
|
|
||||||
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from itertools import starmap
|
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
|
|
||||||
from image_prediction.config import CONFIG
|
from image_prediction.config import CONFIG
|
||||||
|
from image_prediction.transformer.transformer import Transformer
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
def build_response(predictions: list, metadata: list) -> list:
|
class ResponseTransformer(Transformer):
|
||||||
return list(starmap(build_image_info, zip(predictions, metadata)))
|
def transform(self, data):
|
||||||
|
logger.debug("ResponseTransformer.transform")
|
||||||
|
return build_image_info(data)
|
||||||
|
|
||||||
|
|
||||||
def build_image_info(prediction: dict, metadata: dict) -> dict:
|
def build_image_info(data: dict) -> dict:
|
||||||
def compute_geometric_quotient():
|
def compute_geometric_quotient():
|
||||||
page_area_sqrt = math.sqrt(abs(page_width * page_height))
|
page_area_sqrt = math.sqrt(abs(page_width * page_height))
|
||||||
image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1))
|
image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1))
|
||||||
@ -20,9 +22,9 @@ def build_image_info(prediction: dict, metadata: dict) -> dict:
|
|||||||
|
|
||||||
page_width, page_height, x1, x2, y1, y2, width, height = itemgetter(
|
page_width, page_height, x1, x2, y1, y2, width, height = itemgetter(
|
||||||
"page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height"
|
"page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height"
|
||||||
)(metadata)
|
)(data)
|
||||||
|
|
||||||
quotient = compute_geometric_quotient()
|
quotient = round(compute_geometric_quotient(), 4)
|
||||||
|
|
||||||
min_image_to_page_quotient_breached = bool(quotient < CONFIG.filters.image_to_page_quotient.min)
|
min_image_to_page_quotient_breached = bool(quotient < CONFIG.filters.image_to_page_quotient.min)
|
||||||
max_image_to_page_quotient_breached = bool(quotient > CONFIG.filters.image_to_page_quotient.max)
|
max_image_to_page_quotient_breached = bool(quotient > CONFIG.filters.image_to_page_quotient.max)
|
||||||
@ -33,13 +35,13 @@ def build_image_info(prediction: dict, metadata: dict) -> dict:
|
|||||||
width / height > CONFIG.filters.image_width_to_height_quotient.max
|
width / height > CONFIG.filters.image_width_to_height_quotient.max
|
||||||
)
|
)
|
||||||
|
|
||||||
min_confidence_breached = bool(max(prediction["probabilities"].values()) < CONFIG.filters.min_confidence)
|
classification = data["classification"]
|
||||||
prediction["label"] = prediction.pop("class") # "class" as field name causes problem for Java objectmapper
|
|
||||||
prediction["probabilities"] = {klass: round(prob, 6) for klass, prob in prediction["probabilities"].items()}
|
min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence)
|
||||||
|
|
||||||
image_info = {
|
image_info = {
|
||||||
"classification": prediction,
|
"classification": classification,
|
||||||
"position": {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": metadata["page_idx"] + 1},
|
"position": {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": data["page_idx"] + 1},
|
||||||
"geometry": {"width": width, "height": height},
|
"geometry": {"width": width, "height": height},
|
||||||
"filters": {
|
"filters": {
|
||||||
"geometry": {
|
"geometry": {
|
||||||
@ -49,7 +51,7 @@ def build_image_info(prediction: dict, metadata: dict) -> dict:
|
|||||||
"tooSmall": min_image_to_page_quotient_breached,
|
"tooSmall": min_image_to_page_quotient_breached,
|
||||||
},
|
},
|
||||||
"imageFormat": {
|
"imageFormat": {
|
||||||
"quotient": width / height,
|
"quotient": round(width / height, 4),
|
||||||
"tooTall": min_image_width_to_height_quotient_breached,
|
"tooTall": min_image_width_to_height_quotient_breached,
|
||||||
"tooWide": max_image_width_to_height_quotient_breached,
|
"tooWide": max_image_width_to_height_quotient_breached,
|
||||||
},
|
},
|
||||||
@ -1,68 +1,3 @@
|
|||||||
import logging
|
|
||||||
import tempfile
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
from image_prediction.config import CONFIG
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def temporary_pdf_file(pdf: bytes):
|
|
||||||
with tempfile.NamedTemporaryFile() as f:
|
|
||||||
f.write(pdf)
|
|
||||||
yield f.name
|
|
||||||
|
|
||||||
|
|
||||||
def make_logger_getter():
|
|
||||||
|
|
||||||
logger = logging.getLogger("imclf")
|
|
||||||
logger.propagate = False
|
|
||||||
|
|
||||||
handler = logging.StreamHandler()
|
|
||||||
handler.setLevel(CONFIG.service.logging_level)
|
|
||||||
|
|
||||||
log_format = "[%(levelname)s]: %(message)s"
|
|
||||||
formatter = logging.Formatter(log_format)
|
|
||||||
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
logger.addHandler(handler)
|
|
||||||
|
|
||||||
def get_logger():
|
|
||||||
return logger
|
|
||||||
|
|
||||||
return get_logger
|
|
||||||
|
|
||||||
|
|
||||||
get_logger = make_logger_getter()
|
|
||||||
|
|
||||||
|
|
||||||
def show_banner():
|
|
||||||
banner = '''
|
|
||||||
..... . ... ..
|
|
||||||
.d88888Neu. 'L xH88"`~ .x8X x .d88" oec :
|
|
||||||
F""""*8888888F .. . : :8888 .f"8888Hf 5888R @88888
|
|
||||||
* `"*88*" .888: x888 x888. :8888> X8L ^""` '888R 8"*88%
|
|
||||||
-.... ue=:. ~`8888~'888X`?888f` X8888 X888h 888R 8b.
|
|
||||||
:88N ` X888 888X '888> 88888 !88888. 888R u888888>
|
|
||||||
9888L X888 888X '888> 88888 %88888 888R 8888R
|
|
||||||
uzu. `8888L X888 888X '888> 88888 '> `8888> 888R 8888P
|
|
||||||
,""888i ?8888 X888 888X '888> `8888L % ?888 ! 888R *888>
|
|
||||||
4 9888L %888> "*88%""*88" '888!` `8888 `-*"" / .888B . 4888
|
|
||||||
' '8888 '88% `~ " `"` "888. :" ^*888% '888
|
|
||||||
"*8Nu.z*" `""***~"` "% 88R
|
|
||||||
88>
|
|
||||||
48
|
|
||||||
'8
|
|
||||||
'''
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logger.propagate = False
|
|
||||||
|
|
||||||
handler = logging.StreamHandler()
|
|
||||||
handler.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
formatter = logging.Formatter("")
|
|
||||||
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
logger.addHandler(handler)
|
|
||||||
|
|
||||||
logger.info(banner)
|
|
||||||
|
|||||||
1
image_prediction/utils/__init__.py
Normal file
1
image_prediction/utils/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .logger import get_logger
|
||||||
21
image_prediction/utils/banner.py
Normal file
21
image_prediction/utils/banner.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from image_prediction.locations import BANNER_FILE
|
||||||
|
|
||||||
|
|
||||||
|
def show_banner():
    """Read the banner file and emit it through a dedicated, non-propagating logger.

    The stream handler is attached only once, so calling this function repeatedly
    does not print the banner several times per call.

    NOTE(review): the logger's own level is never set, so whether the INFO record is
    emitted depends on the level inherited from its ancestors -- confirm.
    """
    with open(BANNER_FILE) as f:
        banner = "\n" + f.read() + "\n"

    logger = logging.getLogger(__name__)
    logger.propagate = False

    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setLevel(logging.INFO)
        handler.setFormatter(logging.Formatter(""))
        logger.addHandler(handler)

    logger.info(banner)
|
||||||
7
image_prediction/utils/generic.py
Normal file
7
image_prediction/utils/generic.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from funcy import iterate, chunks
|
||||||
|
|
||||||
|
|
||||||
|
def until(cond, func, *args, **kwargs):
    """Iterate ``func`` from a seed and return the first state ``a`` for which ``cond(a, b)`` holds.

    States are drawn from ``iterate(func, ...)`` two at a time, so ``cond`` is evaluated on
    consecutive, non-overlapping pairs: (x0, x1), (x2, x3), ...
    """
    states = iterate(func, *args, **kwargs)
    while True:
        first_state = next(states)
        second_state = next(states)
        if cond(first_state, second_state):
            return first_state
|
||||||
29
image_prediction/utils/logger.py
Normal file
29
image_prediction/utils/logger.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from image_prediction.config import CONFIG
|
||||||
|
|
||||||
|
logging.basicConfig()
|
||||||
|
|
||||||
|
|
||||||
|
def make_logger_getter():
    """Configure the "imclf" logger once and return a zero-argument accessor for it.

    The logger does not propagate, gets one stream handler with a timestamped format,
    and both handler and logger use ``CONFIG.service.logging_level``.
    """
    logger = logging.getLogger("imclf")
    logger.propagate = False

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(CONFIG.service.logging_level)
    stream_handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    )
    logger.addHandler(stream_handler)

    logger.setLevel(CONFIG.service.logging_level)

    def get_logger():
        return logger

    return get_logger
|
||||||
|
|
||||||
|
|
||||||
|
get_logger = make_logger_getter()
|
||||||
99
image_prediction/utils/pdf_annotation.py
Normal file
99
image_prediction/utils/pdf_annotation.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
"""Defines utilities for PDF processing."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from operator import itemgetter
|
||||||
|
|
||||||
|
from PDFNetPython3.PDFNetPython import (
|
||||||
|
PDFDoc,
|
||||||
|
PDFNet,
|
||||||
|
Square,
|
||||||
|
Rect,
|
||||||
|
ColorPt,
|
||||||
|
BorderStyle,
|
||||||
|
SDFDoc,
|
||||||
|
Point,
|
||||||
|
Text,
|
||||||
|
)
|
||||||
|
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def annotate_image(doc, image_info):
    """Add a dashed box and a text note for one image to its page in ``doc``.

    The colour is green when all filters passed, red when the probability filter
    reports "unconfident", and blue otherwise. The note contains ``image_info`` as JSON.
    """
    red = (1, 0, 0)
    green = (0, 1, 0)
    blue = (0, 0, 1)

    filters = image_info["filters"]
    if filters["allPassed"]:
        color = green
    elif filters["probability"]["unconfident"]:
        color = red
    else:
        color = blue

    position = image_info["position"]
    page = doc.GetPage(position["pageNumber"])
    coords = itemgetter("x1", "y1", "x2", "y2")(position)

    # Box around the image.
    box = Square.Create(doc.GetSDFDoc(), Rect(*coords))
    box.SetColor(ColorPt(*color), 3)
    box.SetBorderStyle(BorderStyle(BorderStyle.e_dashed, 2, 0, 0, [4, 2]))
    box.SetPadding(4)
    box.RefreshAppearance()
    page.AnnotPushBack(box)

    # Note anchored at the first two coordinates.
    note = Text.Create(doc.GetSDFDoc(), Point(*coords[:2]))
    note.SetContents(json.dumps(image_info, indent=2, ensure_ascii=False))
    note.SetColor(ColorPt(*color))
    page.AnnotPushBack(note)
    note.RefreshAppearance()
|
||||||
|
|
||||||
|
|
||||||
|
def init():
    """Initialize PDFNet with the license key.

    NOTE(review): the license key is hard-coded in source; consider loading it from
    configuration or the environment instead of committing it.
    """
    PDFNet.Initialize(
        "Knecon AG(en.knecon.swiss):OEM:DDA-R::WL+:AMS(20211029):BECC974307DAB4F34B513BC9B2531B24496F6FCB83CD8AC574358A959730B622FABEF5C7"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def draw_metadata_box(pdf_path, metadata, store_path):
    """Draw a red dashed box for ``metadata`` on the first page of a PDF and save the result.

    Args:
        pdf_path: Path of the PDF to open.
        metadata: Mapping providing the "x1", "y1", "x2" and "y2" entries of the box.
        store_path: Path the annotated PDF is saved to.

    NOTE(review): the box is always drawn on page 1 -- confirm this is intended.
    """
    init()

    doc = PDFDoc(pdf_path)

    color = (1, 0, 0)

    # Diagnostic output goes through the module logger rather than stdout.
    logger.debug("Drawing metadata box: %s", metadata)

    coords = itemgetter("x1", "y1", "x2", "y2")(metadata)
    page = doc.GetPage(1)

    sq = Square.Create(doc.GetSDFDoc(), Rect(*coords))
    sq.SetColor(ColorPt(*color), 3)
    sq.SetBorderStyle(BorderStyle(BorderStyle.e_dashed, 2, 0, 0, [4, 2]))
    sq.SetPadding(4)
    sq.RefreshAppearance()
    page.AnnotPushBack(sq)

    doc.Save(store_path, SDFDoc.e_linearized)

    logger.info(f"Saved annotated PDF to {store_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def annotate_pdf(pdf_path, responses, store_path):
    """Open a PDF, annotate every image described in ``responses`` and save it to ``store_path``."""
    init()

    document = PDFDoc(pdf_path)
    for response in responses:
        annotate_image(document, response)
    document.Save(store_path, SDFDoc.e_linearized)

    logger.info(f"Saved annotated PDF to {store_path}")
|
||||||
@ -1,2 +1,5 @@
|
|||||||
[pytest]
|
[pytest]
|
||||||
norecursedirs = incl
|
norecursedirs = incl
|
||||||
|
filterwarnings =
|
||||||
|
ignore:.*:DeprecationWarning
|
||||||
|
ignore:.*:DeprecationWarning
|
||||||
|
|||||||
@ -1,23 +1,22 @@
|
|||||||
Flask==2.0.2
|
Flask==2.1.1
|
||||||
requests==2.27.1
|
requests==2.27.1
|
||||||
iteration-utilities==0.11.0
|
iteration-utilities==0.11.0
|
||||||
dvc==2.9.3
|
dvc==2.10.0
|
||||||
dvc[ssh]
|
dvc[ssh]
|
||||||
frozendict==2.3.0
|
waitress==2.1.1
|
||||||
waitress==2.0.0
|
envyaml==1.10.211231
|
||||||
envyaml~=1.8.210417
|
|
||||||
dependency-check==0.6.*
|
dependency-check==0.6.*
|
||||||
envyaml~=1.8.210417
|
mlflow==1.24.0
|
||||||
mlflow~=1.20.2
|
numpy==1.22.3
|
||||||
numpy~=1.19.3
|
tqdm==4.64.0
|
||||||
PDFNetPython3~=9.1.0
|
pandas==1.4.2
|
||||||
tqdm~=4.62.2
|
tensorflow==2.8.0
|
||||||
pandas~=1.3.1
|
PyYAML==6.0
|
||||||
mlflow~=1.20.2
|
|
||||||
tensorflow~=2.5.0
|
|
||||||
PDFNetPython3~=9.1.0
|
|
||||||
Pillow~=8.3.2
|
|
||||||
PyYAML~=5.4.1
|
|
||||||
scikit_learn~=0.24.2
|
|
||||||
|
|
||||||
pytest~=7.1.0
|
pytest~=7.1.0
|
||||||
|
funcy==1.17
|
||||||
|
PyMuPDF==1.19.6
|
||||||
|
fpdf==1.7.2
|
||||||
|
coverage==6.3.2
|
||||||
|
Pillow==9.1.0
|
||||||
|
PDFNetPython3==9.1.0
|
||||||
|
pdf2image==1.16.0
|
||||||
@ -40,7 +40,7 @@ def make_predict_fn():
|
|||||||
model = make_model()
|
model = make_model()
|
||||||
|
|
||||||
def predict(*args):
|
def predict(*args):
|
||||||
# model = make_model()
|
# service_estimator = make_model()
|
||||||
return model.predict(np.random.random(size=(1, 784)))
|
return model.predict(np.random.random(size=(1, 784)))
|
||||||
|
|
||||||
return predict
|
return predict
|
||||||
|
|||||||
55
scripts/run_pipeline.py
Normal file
55
scripts/run_pipeline.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
|
from image_prediction.pipeline import load_pipeline
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
from image_prediction.utils.pdf_annotation import annotate_pdf
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """Parse the command line of the pipeline runner.

    Returns:
        The argparse namespace with ``input`` (pdf file or directory), ``print``
        (bool flag) and ``page_interval`` (two ints, or None when not given).
    """
    cli = argparse.ArgumentParser()

    cli.add_argument("input", help="pdf file or directory")
    cli.add_argument("--print", "-p", action="store_true", default=False, help="print output to terminal")
    cli.add_argument("--page_interval", "-i", nargs=2, type=int, help="page interval [i, j), min index = 0")

    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def process_pdf(pipeline, pdf_path, page_range=None):
    """Run the pipeline on one PDF and store an annotated copy under /tmp.

    Args:
        pipeline: Callable invoked with the PDF bytes and a ``page_range`` keyword;
            its result is materialized into a list.
        pdf_path: Path of the PDF file to process.
        page_range: Optional page selection, forwarded unchanged to the pipeline.

    Returns:
        The list of predictions produced by the pipeline.
    """
    with open(pdf_path, "rb") as f:
        logger.info(f"Processing {pdf_path}")
        predictions = list(pipeline(f.read(), page_range=page_range))

    # Derive the output name from the file name and its real extension. A plain
    # str.replace(".pdf", ...) touches every ".pdf" occurrence and leaves names such as
    # "scan.PDF" unchanged, which would make the annotated copy overwrite an input
    # that already lives in /tmp.
    stem, extension = os.path.splitext(os.path.basename(pdf_path))
    annotated_path = os.path.join("/tmp", f"{stem}_annotated{extension}")

    annotate_pdf(pdf_path, predictions, annotated_path)

    return predictions
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """Load the pipeline and process the PDF file, or every ``*.pdf`` in the directory, given by ``args.input``.

    Args:
        args: Namespace providing ``input``, ``print`` and ``page_interval``.
    """
    pipeline = load_pipeline(verbose=False, tolerance=3)

    source = args.input
    pdf_paths = [source] if os.path.isfile(source) else glob(os.path.join(source, "*.pdf"))

    interval = args.page_interval
    page_range = range(*interval) if interval else None

    for pdf_path in pdf_paths:
        predictions = process_pdf(pipeline, pdf_path, page_range=page_range)
        if not args.print:
            continue
        print(pdf_path)
        print(json.dumps(predictions, indent=2))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: parse the command line and hand it to main().
    main(parse_args())
|
||||||
38
src/serve.py
38
src/serve.py
@ -4,45 +4,29 @@ from waitress import serve
|
|||||||
|
|
||||||
from image_prediction.config import CONFIG
|
from image_prediction.config import CONFIG
|
||||||
from image_prediction.flask import make_prediction_server
|
from image_prediction.flask import make_prediction_server
|
||||||
from image_prediction.predictor import Predictor
|
from image_prediction.pipeline import load_pipeline
|
||||||
from image_prediction.response import build_response
|
from image_prediction.utils import get_logger
|
||||||
from image_prediction.utils import get_logger, show_banner
|
from image_prediction.utils.banner import show_banner
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
def predict(pdf):
|
def predict(pdf):
|
||||||
# Keras model.predict stalls when model was loaded in different process
|
# Keras model.predict stalls when the model was loaded in a different process;
|
||||||
|
# therefore, we re-load the model (part of the pipeline) every time we process a new document.
|
||||||
# https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python
|
# https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python
|
||||||
predictor = Predictor()
|
logger.debug("Loading pipeline...")
|
||||||
predictions, metadata = predictor.predict_pdf(pdf, verbose=CONFIG.service.progressbar)
|
pipeline = load_pipeline(verbose=CONFIG.service.verbose)
|
||||||
response = build_response(predictions, metadata)
|
logger.debug("Running pipeline...")
|
||||||
return response
|
return list(pipeline(pdf))
|
||||||
|
|
||||||
logger.info("Predictor ready.")
|
|
||||||
|
|
||||||
prediction_server = make_prediction_server(predict)
|
prediction_server = make_prediction_server(predict)
|
||||||
|
serve(prediction_server, host=CONFIG.webserver.host, port=CONFIG.webserver.port, _quiet=False)
|
||||||
run_prediction_server(prediction_server, mode=CONFIG.webserver.mode)
|
|
||||||
|
|
||||||
|
|
||||||
def run_prediction_server(app, mode="development"):
|
|
||||||
if mode == "development":
|
|
||||||
app.run(host=CONFIG.webserver.host, port=CONFIG.webserver.port, debug=True)
|
|
||||||
elif mode == "production":
|
|
||||||
serve(app, host=CONFIG.webserver.host, port=CONFIG.webserver.port)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logging_level = CONFIG.service.logging_level
|
logging.basicConfig(level=CONFIG.service.logging_level)
|
||||||
logging.basicConfig(level=logging_level)
|
|
||||||
logging.getLogger("flask").setLevel(logging.ERROR)
|
|
||||||
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
|
||||||
logging.getLogger("werkzeug").setLevel(logging.ERROR)
|
|
||||||
logging.getLogger("waitress").setLevel(logging.ERROR)
|
|
||||||
logging.getLogger("PIL").setLevel(logging.ERROR)
|
|
||||||
logging.getLogger("h5py").setLevel(logging.ERROR)
|
|
||||||
|
|
||||||
show_banner()
|
show_banner()
|
||||||
|
|
||||||
|
|||||||
543
test/conftest.py
543
test/conftest.py
@ -1,70 +1,515 @@
|
|||||||
import os.path
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
import tempfile
|
||||||
|
from functools import partial
|
||||||
|
from itertools import starmap
|
||||||
|
from operator import itemgetter
|
||||||
|
|
||||||
|
import fpdf
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
from PIL import Image
|
||||||
|
from funcy import rcompose, merge
|
||||||
|
|
||||||
from image_prediction.predictor import Predictor
|
from image_prediction.classifier.classifier import Classifier
|
||||||
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||||
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||||
|
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||||
|
from image_prediction.exceptions import (
|
||||||
|
UnknownEstimatorAdapter,
|
||||||
|
UnknownImageExtractor,
|
||||||
|
UnknownDatabaseType,
|
||||||
|
UnknownLabelFormat,
|
||||||
|
)
|
||||||
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
|
from image_prediction.image_extractor.extractors.mock import ImageExtractorMock
|
||||||
|
from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor
|
||||||
|
from image_prediction.info import Info
|
||||||
|
from image_prediction.label_mapper.mappers.numeric import IndexMapper
|
||||||
|
from image_prediction.label_mapper.mappers.probability import ProbabilityMapper, ProbabilityMapperKeys
|
||||||
|
from image_prediction.locations import TEST_DATA_DIR
|
||||||
|
from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock
|
||||||
|
from image_prediction.model_loader.loader import ModelLoader
|
||||||
|
from image_prediction.model_loader.loaders.mlflow import MlflowConnector
|
||||||
|
from image_prediction.pipeline import load_pipeline
|
||||||
|
from image_prediction.redai_adapter.mlflow import MlflowModelReader
|
||||||
|
from image_prediction.redai_adapter.model import PredictionModelHandle
|
||||||
|
from image_prediction.utils import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def mute_logger():
    """Raise the project logger's level above CRITICAL for the test, then restore the previous level."""
    project_logger = get_logger()
    previous_level = project_logger.level
    project_logger.setLevel(logging.CRITICAL + 1)
    yield
    project_logger.setLevel(previous_level)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def predictions():
|
def image_extractor(extractor_type):
|
||||||
|
if extractor_type == "mock":
|
||||||
|
return ImageExtractorMock()
|
||||||
|
elif extractor_type == "parsable_pdf":
|
||||||
|
return ParsablePDFImageExtractor()
|
||||||
|
elif extractor_type == "default":
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def image_classifier(classifier, monkeypatch, batch_of_expected_string_labels):
    """Wrap the ``classifier`` fixture in an ImageClassifier that uses a BasicPreprocessor."""
    preprocessor = BasicPreprocessor()
    return ImageClassifier(classifier, preprocessor=preprocessor)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def classifier(estimator_adapter, label_mapper):
    """Build a Classifier from the estimator adapter and label mapper fixtures."""
    return Classifier(estimator_adapter, label_mapper)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def estimator_mock():
    """Provide a stand-in estimator that answers every batch item with None."""

    class EstimatorMock:
        """Callable estimator stub; calling it delegates to ``predict``."""

        @staticmethod
        def predict(batch):
            return [None for _item in batch]

        @staticmethod
        def predict_proba(batch):
            return [None for _item in batch]

        def __call__(self, batch):
            return self.predict(batch)

    return EstimatorMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def label_mapper(label_format, classes):
    """Return the label mapper matching ``label_format``.

    Raises:
        UnknownLabelFormat: If ``label_format`` is neither "index" nor "probability".
    """
    if label_format == "index":
        return IndexMapper(classes)
    if label_format == "probability":
        return ProbabilityMapper(classes)
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=["index"])
def label_format(request):
    """Parametrized label format; yields each value listed in ``params``."""
    selected_format = request.param
    return selected_format
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def expected_predictions_mapped(
    label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings
):
    """Select the expected mapped predictions that belong to ``label_format``.

    Raises:
        UnknownLabelFormat: If ``label_format`` is neither "index" nor "probability".
    """
    if label_format == "index":
        return batch_of_expected_string_labels
    if label_format == "probability":
        return batch_of_expected_label_to_probability_mappings
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def expected_predictions(label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays):
    """Select the expected raw predictions that belong to ``label_format``.

    Raises:
        UnknownLabelFormat: If ``label_format`` is neither "index" nor "probability".
    """
    if label_format == "index":
        return batch_of_expected_numeric_labels
    if label_format == "probability":
        return batch_of_expected_probability_arrays
    raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def estimator_adapter(
    estimator_type, estimator_mock, keras_model, model_handle_mock, output_batch_generator, monkeypatch
):
    """Build an EstimatorAdapter for ``estimator_type`` whose ``predict`` returns the expected predictions.

    Raises:
        UnknownEstimatorAdapter: If ``estimator_type`` is not "mock", "keras" or "redai".
    """
    if estimator_type == "mock":
        adapter = EstimatorAdapter(estimator_mock)
    elif estimator_type == "keras":
        adapter = EstimatorAdapter(keras_model)
    elif estimator_type == "redai":
        adapter = EstimatorAdapter(PredictionModelHandle(model_handle_mock))
    else:
        raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.")

    # Keep a handle on the real predict before it is patched away below.
    real_predict = adapter.predict

    def mock_predict(batch):
        # The real predict still runs so mechanical problems surface, but the values
        # handed back are the externally defined ones, one per real result, so callers
        # of the adapter can be checked against the expected predictions.
        return [next(output_batch_generator) for _result in real_predict(batch)]

    monkeypatch.setattr(adapter, "predict", mock_predict)

    return adapter
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def keras_model(input_size):
    """Build and compile a small Keras model (one Conv2D layer followed by two Dense layers)."""
    import os

    # Must be set before TensorFlow is imported below (presumably to quiet its native logging).
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    import tensorflow as tf

    keras = tf.keras
    keras.backend.set_image_data_format("channels_last")

    net_input = keras.Input(shape=input_size)
    conv_layer = keras.layers.Conv2D(3, 3)
    hidden_layer = keras.layers.Dense(10)

    net_output = keras.layers.Dense(10)(hidden_layer(conv_layer(net_input)))
    network = keras.Model(inputs=net_input, outputs=net_output)
    network.compile()

    return network
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def images(input_batch):
    """Convert every array of the input batch into an image."""
    return [array_to_image(array) for array in input_batch]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def input_batch(batch_size, input_size):
    """Random array of shape ``(batch_size, *input_size)`` drawn with ``np.random.random_sample``."""
    batch_shape = (batch_size, *input_size)
    return np.random.random_sample(size=batch_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[0, 1, 2, 16, 32])
def batch_size(request):
    """Parametrized batch size; yields each value listed in ``params``."""
    size = request.param
    return size
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def input_size(alpha, __input_size):
    """Return ``(width, height, depth + alpha)`` based on the ``__input_size`` fixture."""
    width, height, depth = __input_size
    return width, height, depth + alpha
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[False])
def alpha(request):
    """Parametrized alpha flag; yields each value listed in ``params``."""
    flag = request.param
    return flag
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}])
def __input_size(request):
    """Return the parametrized size as a ``(width, height, depth)`` tuple."""
    size = request.param
    return size["width"], size["height"], size["depth"]
|
||||||
|
|
||||||
|
|
||||||
|
def array_to_image(array):
    """Convert an array with values in [0, 1] and 3 or 4 channels on the last axis into a PIL image.

    Raises:
        ValueError: If the last axis has neither 3 nor 4 entries.
    """
    assert np.all(array <= 1)
    assert np.all(array >= 0)

    channels = array.shape[-1]
    mode_by_channels = {3: "RGB", 4: "RGBA"}
    if channels not in mode_by_channels:
        raise ValueError(f"Unexpected number of channels {array.shape[-1]}. Expected 3 or 4.")

    pixels = np.uint8(array * 255)
    # noinspection PyTypeChecker
    return Image.fromarray(pixels, mode=mode_by_channels[channels])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def batch_of_expected_string_labels(batch_of_expected_numeric_labels, classes):
    """Expected numeric labels translated to their class names."""
    string_labels = map_labels(batch_of_expected_numeric_labels, classes)
    return string_labels
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def batch_of_expected_numeric_labels(batch_size, classes):
    """Draw ``batch_size`` random class indices (with replacement)."""
    class_indices = range(len(classes))
    return random.choices(class_indices, k=batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def batch_of_expected_label_to_probability_mappings(batch_of_expected_probability_arrays, classes):
    """Turn each expected probability array into a mapping of the top label and all label probabilities.

    Probabilities are rounded to 4 decimals, converted to float and ordered from
    highest to lowest; the first label of that ordering is reported as the label.
    """

    def round_probability(probability):
        return float(np.round(probability, decimals=4))

    def map_probabilities(probabilities):
        rounded = map(round_probability, probabilities)
        ranked = dict(sorted(zip(classes, rounded), key=lambda pair: pair[1], reverse=True))
        top_label = list(ranked)[0]
        return {ProbabilityMapperKeys.LABEL: top_label, ProbabilityMapperKeys.PROBABILITIES: ranked}

    return [map_probabilities(probabilities) for probabilities in batch_of_expected_probability_arrays]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def batch_of_expected_probability_arrays(batch_size, classes):
    """Create ``batch_size`` arrays of uniform random values, one value per class."""
    n_classes = len(classes)
    return [np.random.uniform(size=n_classes) for _ in range(batch_size)]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def output_batch_generator(expected_predictions):
    """Iterator over the expected predictions."""
    prediction_iterator = iter(expected_predictions)
    return prediction_iterator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def classes():
    """Class names used throughout the tests."""
    return list("ABC")
|
||||||
|
|
||||||
|
|
||||||
|
def map_labels(numeric_labels, classes):
    """Look up the entry of ``classes`` for each numeric label and return the results as a list."""
    return list(map(classes.__getitem__, numeric_labels))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata):
    """Pair each mapped prediction with its metadata in one dict, the prediction under "classification"."""
    combined = []
    for prediction, image_metadata in zip(expected_predictions_mapped, metadata):
        combined.append({"classification": prediction, **image_metadata})
    return combined
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted):
|
||||||
return [
|
return [
|
||||||
{
|
{"classification": epm, **mdt}
|
||||||
"class": "signature",
|
for epm, mdt in zip(expected_predictions_mapped_and_formatted, metadata_formatted)
|
||||||
"probabilities": {
|
|
||||||
"signature": 1.0,
|
|
||||||
"logo": 9.150285377746546e-19,
|
|
||||||
"other": 4.374506412383356e-19,
|
|
||||||
"formula": 3.582569597002796e-24,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def metadata():
|
def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
|
||||||
return [
|
return [{k.value: v for k, v in epm.items()} for epm in expected_predictions_mapped]
|
||||||
{
|
|
||||||
"page_height": 612.0,
|
|
||||||
"page_width": 792.0,
|
@pytest.fixture
|
||||||
"height": 61.049999999999955,
|
def metadata(images, info_label_map):
|
||||||
"width": 139.35000000000002,
|
page_idx = 0
|
||||||
"page_idx": 8,
|
|
||||||
"x1": 63.5,
|
def current_page_idx():
|
||||||
"x2": 202.85000000000002,
|
nonlocal page_idx
|
||||||
"y1": 472.0,
|
page_idx += random.randint(0, 3)
|
||||||
"y2": 533.05,
|
return min(page_idx, len(images) - 1)
|
||||||
|
|
||||||
|
def build_image_metadata(image):
|
||||||
|
width, height = image.size
|
||||||
|
page_width = 595
|
||||||
|
page_height = 842
|
||||||
|
x1 = random.randint(0, page_width - width)
|
||||||
|
x2 = x1 + width
|
||||||
|
y1 = random.randint(0, page_height - height)
|
||||||
|
y2 = y1 + height
|
||||||
|
metadata = {
|
||||||
|
info_label_map.PAGE_WIDTH: page_width,
|
||||||
|
info_label_map.PAGE_HEIGHT: page_height,
|
||||||
|
info_label_map.PAGE_IDX: current_page_idx(),
|
||||||
|
info_label_map.WIDTH: width,
|
||||||
|
info_label_map.HEIGHT: height,
|
||||||
|
info_label_map.X1: x1,
|
||||||
|
info_label_map.X2: x2,
|
||||||
|
info_label_map.Y1: y1,
|
||||||
|
info_label_map.Y2: y2,
|
||||||
|
info_label_map.ALPHA: image.mode == "RGBA",
|
||||||
}
|
}
|
||||||
]
|
return metadata
|
||||||
|
|
||||||
|
return list(map(build_image_metadata, images))
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def response():
|
def info_label_map():
|
||||||
return [
|
return Info
|
||||||
{
|
|
||||||
"classification": {
|
|
||||||
"label": "signature",
|
@pytest.fixture
|
||||||
"probabilities": {"formula": 0.0, "logo": 0.0, "other": 0.0, "signature": 1.0},
|
def metadata_formatted(metadata):
|
||||||
},
|
def format_metadata(metadata):
|
||||||
"filters": {
|
return {key.value: val for key, val in metadata.items()}
|
||||||
"allPassed": True,
|
|
||||||
"geometry": {
|
return list(map(format_metadata, metadata))
|
||||||
"imageFormat": {"quotient": 2.282555282555285, "tooTall": False, "tooWide": False},
|
|
||||||
"imageSize": {"quotient": 0.13248234868245012, "tooLarge": False, "tooSmall": False},
|
|
||||||
},
|
@pytest.fixture
|
||||||
"probability": {"unconfident": False},
|
def image_metadata_pairs(images, metadata):
|
||||||
},
|
return list(starmap(ImageMetadataPair, zip(images, metadata)))
|
||||||
"geometry": {"height": 61.049999999999955, "width": 139.35000000000002},
|
|
||||||
"position": {"pageNumber": 9, "x1": 63.5, "x2": 202.85000000000002, "y1": 472.0, "y2": 533.05},
|
|
||||||
|
@pytest.fixture
def pdf(image_metadata_pairs):
    """Build a PDF (point units) containing all image/metadata pairs and return it as bytes."""
    document = fpdf.FPDF(unit="pt")

    for image_metadata_pair in image_metadata_pairs:
        add_image(document, image_metadata_pair)

    return pdf_stream(document)
|
||||||
|
|
||||||
|
|
||||||
|
def add_image(pdf, image_metadata_pair, suffix="png"):
    """Add pages until the image's page index exists, then place the image on the last page."""
    while True:
        target_page_idx = image_metadata_pair.metadata[Info.PAGE_IDX]
        if not fewer_pages_then_required(target_page_idx, pdf):
            break
        pdf.add_page()

    add_image_to_last_page(pdf, image_metadata_pair, suffix=suffix)
|
||||||
|
|
||||||
|
|
||||||
|
def fewer_pages_then_required(page_idx, pdf):
    """Return True when ``page_idx`` lies beyond the index of the last page counted by ``pdf.page``."""
    last_page_idx = pdf.page - 1
    return page_idx > last_page_idx
|
||||||
|
|
||||||
|
|
||||||
|
def pdf_stream(pdf: fpdf.fpdf.FPDF):
    """Render the document to a string and return it encoded as latin1 bytes."""
    rendered = pdf.output(dest="S")
    return rendered.encode("latin1")
|
||||||
|
|
||||||
|
|
||||||
|
def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair, suffix):
    """Save the image to a temporary file and draw it into the PDF at the position given by its metadata."""
    image, metadata = image_metadata_pair
    left = metadata[Info.X1]
    top = metadata[Info.Y1]
    width = metadata[Info.WIDTH]
    height = metadata[Info.HEIGHT]

    # The temporary file exists only for the duration of the with-block.
    with tempfile.NamedTemporaryFile(suffix=f".{suffix}") as temp_image:
        image_path = temp_image.name
        image.save(image_path)
        pdf.image(image_path, x=left, y=top, w=width, h=height, type=suffix)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def model():
    """Provide a stub model whose ``predict`` and ``predict_proba`` always return True."""

    class Model:
        """Minimal model stand-in; both methods ignore their arguments."""

        @staticmethod
        def predict(*_args):
            return True

        @staticmethod
        def predict_proba(*_args):
            return True

    return Model()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def model_database_record_identifier():
    """Random identifier made of 10 distinct ASCII letters."""
    letters = random.sample(string.ascii_letters, k=10)
    return "".join(letters)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def model_database_record(model, classes):
    """Database record holding the model and its classes."""
    return dict(model=model, classes=classes)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def model_database(model_database_record, model_database_record_identifier):
    """Single-entry database mapping the record identifier to the record."""
    database = {}
    database[model_database_record_identifier] = model_database_record
    return database
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def database_connector(database_type, model_database, mlflow_reader):
    """Return the connector matching ``database_type``.

    Raises:
        UnknownDatabaseType: If ``database_type`` is neither "mock" nor "mlflow".
    """
    if database_type == "mock":
        return DatabaseConnectorMock(model_database)

    if database_type == "mlflow":
        return MlflowConnector(mlflow_reader)

    raise UnknownDatabaseType(f"No connector for database type {database_type} was specified.")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def model_loader(database_connector):
    """ModelLoader built on the ``database_connector`` fixture."""
    loader = ModelLoader(database_connector)
    return loader
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mlflow_run_id():
    """Run id read from ``CONFIG.service.run_id``."""
    from image_prediction.config import CONFIG

    service_config = CONFIG.service
    return service_config.run_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mlruns_dir():
    """The project's ``MLRUNS_DIR`` location."""
    from image_prediction.locations import MLRUNS_DIR

    runs_directory = MLRUNS_DIR
    return runs_directory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mlflow_reader(mlruns_dir):
    """MlflowModelReader constructed with the mlruns directory."""
    reader = MlflowModelReader(mlruns_dir)
    return reader
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def model_handle_mock(estimator_mock):
    """Provide a stub model handle that wraps the estimator mock and answers every batch item with None."""

    class ModelHandleMock:
        """Model handle stand-in exposing ``model``, ``prep_images``, ``predict`` and ``predict_proba``."""

        def __init__(self):
            self.model = estimator_mock

        def prep_images(self, batch):
            return [None for _item in batch]

        def predict(self, batch):
            return [None for _item in batch]

        def predict_proba(self, batch):
            return [None for _item in batch]

    return ModelHandleMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def real_pdf():
    """Bytes of the sample PDF stored in the test data directory."""
    pdf_path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf")
    with open(pdf_path, "rb") as pdf_file:
        yield pdf_file.read()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def real_expected_service_response():
    """Parsed JSON with the expected predictions for the sample PDF."""
    json_path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json")
    with open(json_path, "r") as json_file:
        yield json.load(json_file)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def pipeline():
    """Pipeline loaded with ``verbose=False``."""
    return load_pipeline(verbose=False)
|
||||||
|
|
||||||
|
|
||||||
|
def transform_equal(a, b):
    """Compare ``a`` with ``b``; a ``map`` object passed as ``a`` is turned into a list first."""
    if isinstance(a, map):
        a = list(a)
    return a == b
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_position_metadata(width, height, page_width, page_height):
|
||||||
|
return {
|
||||||
|
Info.WIDTH: width,
|
||||||
|
Info.HEIGHT: height,
|
||||||
|
Info.PAGE_IDX: 0,
|
||||||
|
Info.PAGE_WIDTH: page_width,
|
||||||
|
Info.PAGE_HEIGHT: page_height,
|
||||||
}
|
}
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def predictor():
|
def base_patch_metadata(width, height, page_width, page_height):
|
||||||
return Predictor()
|
metadata = get_base_position_metadata(width, height, page_width, page_height)
|
||||||
|
metadata = merge(metadata, {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height})
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(params=[33, 100])
|
||||||
def test_pdf():
|
def height(request):
|
||||||
with open("./test/test_data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf", "rb") as f:
|
return request.param
|
||||||
return f.read()
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[10, 31])
def width(request):
    """Parametrized width; yields each value listed in ``params``."""
    value = request.param
    return value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[220, 30])
def page_height(request):
    """Parametrized page height; yields each value listed in ``params``."""
    value = request.param
    return value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[100, 310])
def page_width(request):
    """Parametrized page width; yields each value listed in ``params``."""
    value = request.param
    return value
|
||||||
|
|
||||||
|
|
||||||
|
def random_single_color_image_from_metadata(metadata):
    """Create an RGB image of the size given in ``metadata``, filled with one random color."""
    size = (metadata[Info.WIDTH], metadata[Info.HEIGHT])
    rgb = tuple(int(channel) for channel in np.random.uniform(size=3) * 255)
    return Image.new("RGB", size, color=rgb)
|
||||||
|
|
||||||
|
|
||||||
|
def gray_image_from_metadata(metadata):
    """Create an RGB image of the size given in ``metadata``, filled with the gray (100, 100, 100)."""
    size = (metadata[Info.WIDTH], metadata[Info.HEIGHT])
    return Image.new("RGB", size, color=(100, 100, 100))
|
||||||
|
|||||||
42
test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json
Normal file
42
test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"classification": {
|
||||||
|
"label": "formula",
|
||||||
|
"probabilities": {
|
||||||
|
"formula": 1.0,
|
||||||
|
"logo": 0.0,
|
||||||
|
"other": 0.0,
|
||||||
|
"signature": 0.0
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"position": {
|
||||||
|
"x1": 321,
|
||||||
|
"x2": 515,
|
||||||
|
"y1": 348,
|
||||||
|
"y2": 542,
|
||||||
|
"pageNumber": 2
|
||||||
|
},
|
||||||
|
"geometry": {
|
||||||
|
"width": 194,
|
||||||
|
"height": 194
|
||||||
|
},
|
||||||
|
"filters": {
|
||||||
|
"geometry": {
|
||||||
|
"imageSize": {
|
||||||
|
"quotient": 0.2741,
|
||||||
|
"tooLarge": false,
|
||||||
|
"tooSmall": false
|
||||||
|
},
|
||||||
|
"imageFormat": {
|
||||||
|
"quotient": 1.0,
|
||||||
|
"tooTall": false,
|
||||||
|
"tooWide": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"probability": {
|
||||||
|
"unconfident": false
|
||||||
|
},
|
||||||
|
"allPassed": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
92
test/data/stitching_with_tolerance.json
Normal file
92
test/data/stitching_with_tolerance.json
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
{
|
||||||
|
"input": [
|
||||||
|
{
|
||||||
|
"width": 100,
|
||||||
|
"height": 8,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 0,
|
||||||
|
"y1": 0,
|
||||||
|
"x2": 100,
|
||||||
|
"y2": 8
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"width": 100,
|
||||||
|
"height": 9,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 0,
|
||||||
|
"y1": 9,
|
||||||
|
"x2": 100,
|
||||||
|
"y2": 18
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"width": 100,
|
||||||
|
"height": 35,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 0,
|
||||||
|
"y1": 18,
|
||||||
|
"x2": 100,
|
||||||
|
"y2": 53
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"width": 47,
|
||||||
|
"height": 46,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 0,
|
||||||
|
"y1": 54,
|
||||||
|
"x2": 47,
|
||||||
|
"y2": 100
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"width": 31,
|
||||||
|
"height": 46,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 48,
|
||||||
|
"y1": 54,
|
||||||
|
"x2": 79,
|
||||||
|
"y2": 100
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"width": 20,
|
||||||
|
"height": 19,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 80,
|
||||||
|
"y1": 54,
|
||||||
|
"x2": 100,
|
||||||
|
"y2": 73
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"width": 20,
|
||||||
|
"height": 27,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 80,
|
||||||
|
"y1": 73,
|
||||||
|
"x2": 100,
|
||||||
|
"y2": 100
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"target": {
|
||||||
|
"width": 100,
|
||||||
|
"height": 100,
|
||||||
|
"page_idx": 0,
|
||||||
|
"page_width": 100,
|
||||||
|
"page_height": 100,
|
||||||
|
"x1": 0,
|
||||||
|
"y1": 0,
|
||||||
|
"x2": 100,
|
||||||
|
"y2": 100
|
||||||
|
}
|
||||||
|
}
|
||||||
32
test/exploration_tests/funcy_test.py
Normal file
32
test/exploration_tests/funcy_test.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import pytest
|
||||||
|
from funcy import rcompose, chunks
|
||||||
|
|
||||||
|
|
||||||
|
def test_rcompose():
    """rcompose applies its functions left to right: square, stringify, then repeat the string."""
    square_then_repeat = rcompose(lambda x: x ** 2, str, lambda x: x * 2)
    result = square_then_repeat(3)
    assert result == "99"
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunk_iterable_exact_split():
    """Ten items in chunks of five give exactly two full chunks."""
    first, second = chunks(5, iter(range(10)))
    assert first == [0, 1, 2, 3, 4]
    assert second == [5, 6, 7, 8, 9]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunk_iterable_no_split():
    """With a chunk size equal to the item count, the first chunk holds everything."""
    chunk_iterator = chunks(10, iter(range(10)))
    first = next(chunk_iterator)
    assert first == list(range(10))
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunk_iterable_last_partial():
    """Ten items in chunks of three give four chunks, the last one holding the remainder."""
    _first, _second, _third, last = chunks(3, iter(range(10)))
    assert last == [9]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunk_iterable_empty():
    """Chunking an empty iterator yields nothing, so next() raises StopIteration."""
    empty_chunks = chunks(3, iter(range(0)))
    with pytest.raises(StopIteration):
        next(empty_chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunk_iterable_less_than_chunk_size_elements():
    """Fewer items than the chunk size come back as a single short chunk."""
    only_chunk = next(chunks(5, iter(range(2))))
    assert only_chunk == [0, 1]
|
||||||
102
test/integration_tests/actual_server_test.py
Normal file
102
test/integration_tests/actual_server_test.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
import socket
|
||||||
|
from multiprocessing import Process
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from funcy import retry
|
||||||
|
from waitress import serve
|
||||||
|
|
||||||
|
from image_prediction.flask import make_prediction_server
|
||||||
|
from image_prediction.pipeline import load_pipeline
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def host():
    """Address the test server is bound to."""
    address = "127.0.0.1"
    return address
|
||||||
|
|
||||||
|
|
||||||
|
def get_free_port(host):
    """Ask the operating system for a currently unused port on ``host``.

    Binding to port 0 lets the OS choose a free port. The probing socket is closed
    before returning instead of being left to the garbage collector, so the
    descriptor is not leaked.

    Args:
        host: Address to bind the probing socket to.

    Returns:
        The port number the OS assigned.
    """
    with socket.socket() as sock:
        sock.bind((host, 0))
        # NOTE: the port is released again on close; nothing reserves it for the caller.
        return sock.getsockname()[1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def port(host):
    """Free port on ``host`` as reported by ``get_free_port``."""
    free_port = get_free_port(host)
    return free_port
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def url(host, port):
    """Base HTTP URL of the test server."""
    base_url = f"http://{host}:{port}"
    return base_url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=["dummy", "actual"])
def server_type(request):
    """Parametrized server type; yields each value listed in ``params``."""
    selected_type = request.param
    return selected_type
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def server(server_type):
    """Build the prediction server application for ``server_type``.

    Raises:
        ValueError: If ``server_type`` is neither "dummy" nor "actual".
    """

    def halve(payload):
        # Dummy operation: decode the request body as an integer and halve it.
        return int(payload.decode()) // 2

    def run_pipeline(payload):
        # The pipeline is loaded anew for every request.
        return list(load_pipeline(verbose=False)(payload))

    if server_type == "dummy":
        return make_prediction_server(halve)

    if server_type == "actual":
        return make_prediction_server(run_pipeline)

    raise ValueError(f"Unknown server type {server_type}.")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def host_and_port(host, port, server):
    """Host and port bundled as a dict with the keys "host" and "port"."""
    return dict(host=host, port=port)
|
||||||
|
|
||||||
|
|
||||||
|
@retry(tries=5, timeout=1)
def server_ready(url):
    """Query ``{url}/ready`` and report whether it answered with status 200.

    Error statuses raise via ``raise_for_status``; the call is wrapped in funcy's
    ``retry`` with ``tries=5, timeout=1``.
    """
    ready_response = requests.get(f"{url}/ready")
    ready_response.raise_for_status()
    status = ready_response.status_code
    return status == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True, scope="function")
def server_process(server, host_and_port, url):
    """Serve the application in a child process for the duration of one test.

    The process is started, the fixture waits until ``server_ready`` succeeds,
    then hands control to the test. The process is always killed, joined and
    closed afterwards, including when the readiness probe fails, so no stray
    server process survives a failed setup.

    Raises:
        RuntimeError: If the readiness probe returns a falsy result.
    """
    # Use a distinct name so the process does not shadow the `server` app.
    process = Process(target=serve, kwargs={"app": server, **host_and_port})
    process.start()

    try:
        if not server_ready(url):
            # Fail setup explicitly instead of ending the generator without a yield.
            raise RuntimeError(f"Server at {url} did not become ready.")
        yield
    finally:
        process.kill()
        process.join()
        process.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("server_type", ["actual"])
def test_server_predict(url, real_pdf, real_expected_service_response):
    """Post a PDF to ``/predict`` on the actual server and compare the JSON reply."""
    endpoint = f"{url}/predict"
    reply = requests.post(endpoint, data=real_pdf)
    reply.raise_for_status()

    payload = reply.json()
    assert payload == real_expected_service_response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("server_type", ["dummy"])
def test_server_dummy_operation(url):
    """Post ``b"42"`` to ``/predict`` on the dummy server and expect 21 back."""
    endpoint = f"{url}/predict"
    reply = requests.post(endpoint, data=b"42")
    reply.raise_for_status()

    payload = reply.json()
    assert payload == 21
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("server_type", ["dummy"])
def test_server_health_check(url):
    """Request ``/health`` on the dummy server and expect status code 200."""
    endpoint = f"{url}/health"
    reply = requests.get(endpoint)
    reply.raise_for_status()

    status = reply.status_code
    assert status == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("server_type", ["dummy"])
def test_server_ready_check(url):
    """Check that the readiness probe succeeds against the dummy server."""
    is_ready = server_ready(url)
    assert is_ready
|
||||||
29
test/unit_tests/box_validation_test.py
Normal file
29
test/unit_tests/box_validation_test.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from image_prediction.exceptions import InvalidBox
|
||||||
|
from image_prediction.info import Info
|
||||||
|
from image_prediction.stitching.utils import validate_box_size, validate_box_coords
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_fail_too_short():
    """A box with width 1 and height 0 must fail the size validation."""
    zero_height_box = {Info.WIDTH: 1, Info.HEIGHT: 0}

    with pytest.raises(InvalidBox):
        validate_box_size(zero_height_box)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_fail_too_thin():
    """A box with width 0 and height 1 must fail the size validation."""
    zero_width_box = {Info.WIDTH: 0, Info.HEIGHT: 1}

    with pytest.raises(InvalidBox):
        validate_box_size(zero_width_box)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_fail_xs_width_mismatch():
    """A box declaring width 2 while X1=0 and X2=1 must fail coordinate validation."""
    inconsistent_box = {
        Info.WIDTH: 2,
        Info.HEIGHT: 4,
        Info.X1: 0,
        Info.Y1: 0,
        Info.X2: 1,
        Info.Y2: 4,
    }

    with pytest.raises(InvalidBox):
        validate_box_coords(inconsistent_box)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_fail_ys_height_mismatch():
    """A box declaring height 3 while Y1=0 and Y2=4 must fail coordinate validation."""
    inconsistent_box = {
        Info.WIDTH: 2,
        Info.HEIGHT: 3,
        Info.X1: 0,
        Info.Y1: 0,
        Info.X2: 2,
        Info.Y2: 4,
    }

    with pytest.raises(InvalidBox):
        validate_box_coords(inconsistent_box)
|
||||||
19
test/unit_tests/classifier_test.py
Normal file
19
test/unit_tests/classifier_test.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("estimator_type", ["mock", "keras", "redai"])
@pytest.mark.parametrize("label_format", ["index", "probability"])
def test_classifier(classifier, input_batch, expected_predictions_mapped):
    """Run the classifier on the input batch and compare with the expected predictions.

    Parametrized over every combination of estimator type and label format.
    """
    actual = classifier(input_batch)
    expected = expected_predictions_mapped
    assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_format(input_batch):
    """Check that the input batch has a last axis of size 3 and exactly four dimensions."""
    # Channels-last layout: the trailing axis holds the 3 channels.
    trailing_axis_size = input_batch.shape[-1]
    assert trailing_axis_size == 3

    # Fourth-order tensor.
    rank = input_batch.ndim
    assert rank == 4
|
||||||
32
test/unit_tests/compositor_test.py
Normal file
32
test/unit_tests/compositor_test.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from image_prediction.compositor.compositor import TransformerCompositor
|
||||||
|
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
|
||||||
|
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||||
|
from image_prediction.formatter.formatters.identity import IdentityFormatter
|
||||||
|
from test.conftest import transform_equal
|
||||||
|
|
||||||
|
|
||||||
|
def test_identity(metadata):
    """A compositor holding only the identity formatter must return its input unchanged."""
    identity_only = TransformerCompositor(IdentityFormatter())
    transformed = identity_only(metadata)
    assert transform_equal(transformed, metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def test_composition(metadata, metadata_formatted):
    """Composing the identity and enum formatters must yield the formatted metadata."""
    formatters = (IdentityFormatter(), EnumFormatter())
    pipeline = TransformerCompositor(*formatters)
    transformed = pipeline(metadata)
    assert transform_equal(transformed, metadata_formatted)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def compositor_test_enum_metadata(info_label_map):
    """Provide a one-element metadata list keyed by ``info_label_map`` members."""
    record = {info_label_map.WIDTH: 100, info_label_map.PAGE_WIDTH: 200}
    return [record]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def compositor_test_camel_case_metadata(info_label_map):
    """Provide a one-element metadata list with camelCase string keys.

    The ``info_label_map`` fixture is requested but not used in the body.
    """
    record = {"width": 100, "pageWidth": 200}
    return [record]
|
||||||
|
|
||||||
|
|
||||||
|
def test_enum_to_camel_case(compositor_test_enum_metadata, compositor_test_camel_case_metadata):
    """Chaining the enum and snake-to-camel-case formatters must yield camelCase keys."""
    formatters = (EnumFormatter(), Snake2CamelCaseKeyFormatter())
    pipeline = TransformerCompositor(*formatters)
    transformed = pipeline(compositor_test_enum_metadata)
    assert transform_equal(transformed, compositor_test_camel_case_metadata)
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user