From ddd8d4685e72319fb61814aad8f3b0147bdd0351 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Mon, 25 Apr 2022 12:25:41 +0200 Subject: [PATCH] Pull request #9: Tdd refactoring Merge in RR/image-prediction from tdd_refactoring to master Squashed commit of the following: commit f6c64430007590f5d2b234a7f784e26025d06484 Author: Matthias Bisping Date: Mon Apr 25 12:18:47 2022 +0200 renaming commit 8f40b51282191edf3e2a5edcd6d6acb388ada453 Author: Matthias Bisping Date: Mon Apr 25 12:07:18 2022 +0200 adjusted expetced output for alpha channel in response commit 7e666302d5eadb1e84b70cae27e8ec6108d7a135 Author: Matthias Bisping Date: Mon Apr 25 11:52:51 2022 +0200 added alpha channel check result to response commit a6b9f64b51cd888fc0c427a38bd43ae2ae2cb051 Author: Matthias Bisping Date: Mon Apr 25 11:27:57 2022 +0200 readme updated commit 0d06ad657e3c21dcef361c53df37b05aba64528b Author: Matthias Bisping Date: Mon Apr 25 11:19:35 2022 +0200 readme updated and config commit 75748a1d82f0ebdf3ad7d348c6d820c8858aa3cb Author: Matthias Bisping Date: Mon Apr 25 11:19:26 2022 +0200 refactoring commit 60101337828d11f5ee5fed0d8c4ec80cde536d8a Author: Matthias Bisping Date: Mon Apr 25 11:18:23 2022 +0200 multiple reoutes for prediction commit c8476cb5f55e470b831ae4557a031a2c1294eb86 Author: Matthias Bisping Date: Mon Apr 25 11:17:49 2022 +0200 add banner.txt to container commit 26ef5fce8a9bc015f1c35f32d40e8bea50a96454 Author: Matthias Bisping Date: Mon Apr 25 10:08:49 2022 +0200 Pull request #8: Pipeline refactoring Merge in RR/image-prediction from pipeline_refactoring to tdd_refactoring Squashed commit of the following: commit 6989fcb3313007b7eecf4bba39077fcde6924a9a Author: Matthias Bisping Date: Mon Apr 25 09:49:49 2022 +0200 removed obsolete module commit 7428aeee37b11c31cffa597c85b018ba71e79a1d Author: Matthias Bisping Date: Mon Apr 25 09:45:45 2022 +0200 refactoring commit 0dcd3894154fdf34bd3ba4ef816362434474f472 Author: Matthias Bisping Date: Mon Apr 25 08:57:21 2022 +0200 
refactoring; removed obsolete extractor-classifier commit 1078aa81144f4219149b3fcacdae8b09c4b905c0 Author: Matthias Bisping Date: Fri Apr 22 17:18:10 2022 +0200 removed obsolete imports commit 71f61fc5fc915da3941cf5ed5d9cc90fccc49031 Author: Matthias Bisping Date: Fri Apr 22 17:16:25 2022 +0200 comment changed commit b582726cd1de233edb55c5a76c91e99f9dd3bd13 Author: Matthias Bisping Date: Fri Apr 22 17:12:11 2022 +0200 refactoring commit 8abc9010048078868b235d6793ac6c8b20abb985 Author: Matthias Bisping Date: Thu Apr 21 21:25:47 2022 +0200 formatting commit 2c87c419fe3185a25c27139e7fcf79f60971ad24 Author: Matthias Bisping Date: Thu Apr 21 21:24:05 2022 +0200 formatting commit 50b161192db43a84464125c6d79650225e1010d6 Author: Matthias Bisping Date: Thu Apr 21 21:20:18 2022 +0200 refactoring commit 9a1446cccfa070852a5d9c0bdbc36037b82541fc Author: Matthias Bisping Date: Thu Apr 21 21:04:57 2022 +0200 refactoring commit 6c10b55ff8e61412cb2fe5a5625e660ecaf1d7d1 Author: Matthias Bisping Date: Thu Apr 21 19:48:05 2022 +0200 refactoring commit 72e785e3e31c132ab352119e9921725f91fac9e2 Author: Matthias Bisping Date: Thu Apr 21 19:43:39 2022 +0200 refactoring commit f036ee55e6747daf31e3929bdc2d93dc5f2a56ca Author: Matthias Bisping Date: Wed Apr 20 18:30:41 2022 +0200 refactoring pipeline WIP commit 120721f5f1a7e910c0c2ebc79dc87c2908794c80 Author: Matthias Bisping Date: Wed Apr 20 15:39:58 2022 +0200 rm debug ls commit 81226d4f8599af0db0e9718fbb1789cfad91a855 Author: Matthias Bisping Date: Wed Apr 20 15:28:27 2022 +0200 no compose down commit 943f7799d49b6a6b0fed985a76ed4fe725dfaeef Author: Matthias Bisping Date: Wed Apr 20 15:22:17 2022 +0200 coverage combine commit d4cd96607157ea414db417cfd7133f56cb56afe1 Author: Matthias Bisping Date: Wed Apr 20 14:43:09 2022 +0200 model builder path in mlruns adjusted commit 5b90bb47c3421feb6123c179eb68d1125d58ff1e Author: Matthias Bisping Date: Wed Apr 20 10:56:58 2022 +0200 dvc pull in test running script commit 
a935cacf2305a4a78a15ff571f368962f4538369 Author: Matthias Bisping Date: Wed Apr 20 10:50:36 2022 +0200 no clean working dir commit ba09df7884485b8ab8efbf42a8058de9af60c75c Author: Matthias Bisping Date: Wed Apr 20 10:43:22 2022 +0200 debug ls commit 71263a9983dbfe2060ef5b74de7cc2cbbad43416 Author: Matthias Bisping Date: Wed Apr 20 09:11:03 2022 +0200 debug ls commit 41fbadc331e65e4ffe6d053e2d925e5e0543d8b7 Author: Matthias Bisping Date: Tue Apr 19 20:08:08 2022 +0200 debug echo commit bb19698d640b3a99ea404e5b4b06d719a9bfe9e9 Author: Matthias Bisping Date: Tue Apr 19 20:01:59 2022 +0200 skip server predict test commit 5094015a87fc0976c9d3ff5d1f4c6fdbd96b7eae Author: Matthias Bisping Date: Tue Apr 19 19:05:50 2022 +0200 sonar stage after build stage ... and 253 more commits --- .coveragerc | 11 +- .dvc/config | 3 +- .gitignore | 6 +- .gitmodules | 3 - Dockerfile | 8 +- Dockerfile_tests | 23 ++ README.md | 133 +++++++++- .../src/main/java/buildjob/PlanSpec.java | 14 +- .../main/resources/scripts/docker-build.sh | 5 +- .../src/main/resources/scripts/sonar-scan.sh | 12 +- banner.txt | 11 + config.yaml | 16 +- data/.gitignore | 1 + data/base_weights.h5.dvc | 4 - data/mlruns.dvc | 6 +- doc/tests.drawio | 1 + .../classifier}/__init__.py | 0 image_prediction/classifier/classifier.py | 35 +++ .../classifier/image_classifier.py | 32 +++ image_prediction/compositor/__init__.py | 0 image_prediction/compositor/compositor.py | 16 ++ image_prediction/config.py | 6 +- image_prediction/default_objects.py | 38 +++ image_prediction/estimator/__init__.py | 0 .../estimator/adapter/__init__.py | 0 image_prediction/estimator/adapter/adapter.py | 15 ++ .../estimator/adapter/adapters/__init__.py | 0 .../estimator/preprocessor/__init__.py | 0 .../estimator/preprocessor/preprocessor.py | 10 + .../preprocessor/preprocessors/__init__.py | 0 .../preprocessor/preprocessors/basic.py | 10 + .../preprocessor/preprocessors/identity.py | 10 + .../estimator/preprocessor/utils.py | 10 + 
image_prediction/exceptions.py | 34 +++ image_prediction/extraction.py | 13 + image_prediction/flask.py | 69 ++--- image_prediction/formatter/__init__.py | 0 image_prediction/formatter/formatter.py | 15 ++ .../formatter/formatters/__init__.py | 0 .../formatter/formatters/camel_case.py | 11 + image_prediction/formatter/formatters/enum.py | 23 ++ .../formatter/formatters/identity.py | 6 + .../formatter/formatters/key_formatter.py | 28 ++ image_prediction/image_extractor/__init__.py | 0 image_prediction/image_extractor/extractor.py | 19 ++ .../image_extractor/extractors/__init__.py | 0 .../image_extractor/extractors/mock.py | 7 + .../image_extractor/extractors/parsable.py | 179 +++++++++++++ image_prediction/info.py | 14 + image_prediction/label_mapper/__init__.py | 0 image_prediction/label_mapper/mapper.py | 10 + .../label_mapper/mappers/__init__.py | 0 .../label_mapper/mappers/numeric.py | 20 ++ .../label_mapper/mappers/probability.py | 39 +++ image_prediction/locations.py | 21 +- image_prediction/model_loader/__init__.py | 0 .../model_loader/database/__init__.py | 0 .../model_loader/database/connector.py | 7 + .../database/connectors/__init__.py | 0 .../model_loader/database/connectors/mock.py | 9 + image_prediction/model_loader/loader.py | 18 ++ .../model_loader/loaders/__init__.py | 0 .../model_loader/loaders/mlflow.py | 10 + image_prediction/pipeline.py | 64 +++++ image_prediction/predictor.py | 122 --------- image_prediction/redai_adapter/__init__.py | 0 .../redai_adapter/efficient_net_wrapper.py | 45 ++++ image_prediction/redai_adapter/mlflow.py | 72 +++++ image_prediction/redai_adapter/model.py | 19 ++ .../redai_adapter/model_wrapper.py | 42 +++ image_prediction/stitching/__init__.py | 0 image_prediction/stitching/grouping.py | 63 +++++ image_prediction/stitching/merging.py | 189 ++++++++++++++ image_prediction/stitching/split_mapper.py | 40 +++ image_prediction/stitching/stitching.py | 15 ++ image_prediction/stitching/utils.py | 67 +++++ 
image_prediction/transformer/__init__.py | 0 image_prediction/transformer/transformer.py | 20 ++ .../transformer/transformers/__init__.py | 0 .../transformers/coordinate/__init__.py | 0 .../coordinate/coordinate_transformer.py | 22 ++ .../transformers/coordinate/fitz.py | 10 + .../transformers/coordinate/fpdf.py | 10 + .../transformers/coordinate/pdfnet.py | 18 ++ .../transformers}/response.py | 37 +-- image_prediction/utils.py | 65 ----- image_prediction/utils/__init__.py | 1 + image_prediction/utils/banner.py | 21 ++ image_prediction/utils/generic.py | 15 ++ image_prediction/utils/logger.py | 27 ++ image_prediction/utils/pdf_annotation.py | 99 +++++++ incl/redai_image | 1 - pytest.ini | 5 +- requirements.txt | 38 +-- run_tests.sh | 15 ++ scripts/keras_MnWE.py | 2 +- scripts/pyinfra_mock.py | 2 +- scripts/run_pipeline.py | 55 ++++ setup/docker.sh | 15 -- src/serve.py | 44 +--- test/conftest.py | 85 ++---- .../f2dc689ca794fccb8cd38b95f2bf6ba9.pdf | Bin ...ca794fccb8cd38b95f2bf6ba9_predictions.json | 43 +++ test/data/stitching_with_tolerance.json | 92 +++++++ test/exploration_tests/funcy_test.py | 32 +++ test/fixtures/__init__.py | 0 test/fixtures/extractor.py | 17 ++ test/fixtures/image.py | 14 + test/fixtures/image_metadata_pair.py | 10 + test/fixtures/input.py | 7 + test/fixtures/label.py | 25 ++ test/fixtures/metadata.py | 56 ++++ test/fixtures/model.py | 113 ++++++++ test/fixtures/model_store.py | 61 +++++ test/fixtures/parameters.py | 38 +++ test/fixtures/pdf.py | 23 ++ test/fixtures/target.py | 91 +++++++ test/integration_tests/actual_server_test.py | 103 ++++++++ test/unit_tests/box_validation_test.py | 29 +++ test/unit_tests/classifier_test.py | 19 ++ test/unit_tests/compositor_test.py | 33 +++ test/unit_tests/config_test.py | 38 +++ .../unit_tests/coordinate_transformer_test.py | 239 +++++++++++++++++ test/unit_tests/formatter_test.py | 27 ++ test/unit_tests/image_classifier_test.py | 7 + test/unit_tests/image_extractor_test.py | 77 ++++++ 
test/unit_tests/image_stitching_test.py | 246 ++++++++++++++++++ test/unit_tests/label_mapper_test.py | 21 ++ test/unit_tests/mocked_server_test.py | 48 ++++ test/unit_tests/model_loader_test.py | 21 ++ test/unit_tests/pipeline_test.py | 7 + test/unit_tests/preprocessor_test.py | 38 +++ test/unit_tests/split_mapper_test.py | 50 ++++ test/unit_tests/test_predictor.py | 26 -- test/unit_tests/test_response.py | 5 - test/unit_tests/utils_test.py | 8 + test/utils/__init__.py | 0 test/utils/comparison.py | 40 +++ test/utils/generation/__init__.py | 0 test/utils/generation/image.py | 31 +++ test/utils/generation/pdf.py | 30 +++ test/utils/label.py | 2 + test/utils/metadata.py | 11 + test/utils/stitching.py | 80 ++++++ 144 files changed, 3735 insertions(+), 459 deletions(-) create mode 100644 Dockerfile_tests create mode 100644 banner.txt create mode 100644 data/.gitignore delete mode 100644 data/base_weights.h5.dvc create mode 100644 doc/tests.drawio rename {test/unit_tests => image_prediction/classifier}/__init__.py (100%) create mode 100644 image_prediction/classifier/classifier.py create mode 100644 image_prediction/classifier/image_classifier.py create mode 100644 image_prediction/compositor/__init__.py create mode 100644 image_prediction/compositor/compositor.py create mode 100644 image_prediction/default_objects.py create mode 100644 image_prediction/estimator/__init__.py create mode 100644 image_prediction/estimator/adapter/__init__.py create mode 100644 image_prediction/estimator/adapter/adapter.py create mode 100644 image_prediction/estimator/adapter/adapters/__init__.py create mode 100644 image_prediction/estimator/preprocessor/__init__.py create mode 100644 image_prediction/estimator/preprocessor/preprocessor.py create mode 100644 image_prediction/estimator/preprocessor/preprocessors/__init__.py create mode 100644 image_prediction/estimator/preprocessor/preprocessors/basic.py create mode 100644 image_prediction/estimator/preprocessor/preprocessors/identity.py 
create mode 100644 image_prediction/estimator/preprocessor/utils.py create mode 100644 image_prediction/exceptions.py create mode 100644 image_prediction/extraction.py create mode 100644 image_prediction/formatter/__init__.py create mode 100644 image_prediction/formatter/formatter.py create mode 100644 image_prediction/formatter/formatters/__init__.py create mode 100644 image_prediction/formatter/formatters/camel_case.py create mode 100644 image_prediction/formatter/formatters/enum.py create mode 100644 image_prediction/formatter/formatters/identity.py create mode 100644 image_prediction/formatter/formatters/key_formatter.py create mode 100644 image_prediction/image_extractor/__init__.py create mode 100644 image_prediction/image_extractor/extractor.py create mode 100644 image_prediction/image_extractor/extractors/__init__.py create mode 100644 image_prediction/image_extractor/extractors/mock.py create mode 100644 image_prediction/image_extractor/extractors/parsable.py create mode 100644 image_prediction/info.py create mode 100644 image_prediction/label_mapper/__init__.py create mode 100644 image_prediction/label_mapper/mapper.py create mode 100644 image_prediction/label_mapper/mappers/__init__.py create mode 100644 image_prediction/label_mapper/mappers/numeric.py create mode 100644 image_prediction/label_mapper/mappers/probability.py create mode 100644 image_prediction/model_loader/__init__.py create mode 100644 image_prediction/model_loader/database/__init__.py create mode 100644 image_prediction/model_loader/database/connector.py create mode 100644 image_prediction/model_loader/database/connectors/__init__.py create mode 100644 image_prediction/model_loader/database/connectors/mock.py create mode 100644 image_prediction/model_loader/loader.py create mode 100644 image_prediction/model_loader/loaders/__init__.py create mode 100644 image_prediction/model_loader/loaders/mlflow.py create mode 100644 image_prediction/pipeline.py delete mode 100644 
image_prediction/predictor.py create mode 100644 image_prediction/redai_adapter/__init__.py create mode 100644 image_prediction/redai_adapter/efficient_net_wrapper.py create mode 100644 image_prediction/redai_adapter/mlflow.py create mode 100644 image_prediction/redai_adapter/model.py create mode 100644 image_prediction/redai_adapter/model_wrapper.py create mode 100644 image_prediction/stitching/__init__.py create mode 100644 image_prediction/stitching/grouping.py create mode 100644 image_prediction/stitching/merging.py create mode 100644 image_prediction/stitching/split_mapper.py create mode 100644 image_prediction/stitching/stitching.py create mode 100644 image_prediction/stitching/utils.py create mode 100644 image_prediction/transformer/__init__.py create mode 100644 image_prediction/transformer/transformer.py create mode 100644 image_prediction/transformer/transformers/__init__.py create mode 100644 image_prediction/transformer/transformers/coordinate/__init__.py create mode 100644 image_prediction/transformer/transformers/coordinate/coordinate_transformer.py create mode 100644 image_prediction/transformer/transformers/coordinate/fitz.py create mode 100644 image_prediction/transformer/transformers/coordinate/fpdf.py create mode 100644 image_prediction/transformer/transformers/coordinate/pdfnet.py rename image_prediction/{ => transformer/transformers}/response.py (68%) create mode 100644 image_prediction/utils/__init__.py create mode 100644 image_prediction/utils/banner.py create mode 100644 image_prediction/utils/generic.py create mode 100644 image_prediction/utils/logger.py create mode 100644 image_prediction/utils/pdf_annotation.py delete mode 160000 incl/redai_image create mode 100755 run_tests.sh create mode 100644 scripts/run_pipeline.py delete mode 100755 setup/docker.sh rename test/{test_data => data}/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf (100%) create mode 100644 test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json create mode 100644 
test/data/stitching_with_tolerance.json create mode 100644 test/exploration_tests/funcy_test.py create mode 100644 test/fixtures/__init__.py create mode 100644 test/fixtures/extractor.py create mode 100644 test/fixtures/image.py create mode 100644 test/fixtures/image_metadata_pair.py create mode 100644 test/fixtures/input.py create mode 100644 test/fixtures/label.py create mode 100644 test/fixtures/metadata.py create mode 100644 test/fixtures/model.py create mode 100644 test/fixtures/model_store.py create mode 100644 test/fixtures/parameters.py create mode 100644 test/fixtures/pdf.py create mode 100644 test/fixtures/target.py create mode 100644 test/integration_tests/actual_server_test.py create mode 100644 test/unit_tests/box_validation_test.py create mode 100644 test/unit_tests/classifier_test.py create mode 100644 test/unit_tests/compositor_test.py create mode 100644 test/unit_tests/config_test.py create mode 100644 test/unit_tests/coordinate_transformer_test.py create mode 100644 test/unit_tests/formatter_test.py create mode 100644 test/unit_tests/image_classifier_test.py create mode 100644 test/unit_tests/image_extractor_test.py create mode 100644 test/unit_tests/image_stitching_test.py create mode 100644 test/unit_tests/label_mapper_test.py create mode 100644 test/unit_tests/mocked_server_test.py create mode 100644 test/unit_tests/model_loader_test.py create mode 100644 test/unit_tests/pipeline_test.py create mode 100644 test/unit_tests/preprocessor_test.py create mode 100644 test/unit_tests/split_mapper_test.py delete mode 100644 test/unit_tests/test_predictor.py delete mode 100644 test/unit_tests/test_response.py create mode 100644 test/unit_tests/utils_test.py create mode 100644 test/utils/__init__.py create mode 100644 test/utils/comparison.py create mode 100644 test/utils/generation/__init__.py create mode 100644 test/utils/generation/image.py create mode 100644 test/utils/generation/pdf.py create mode 100644 test/utils/label.py create mode 100644 
test/utils/metadata.py create mode 100644 test/utils/stitching.py diff --git a/.coveragerc b/.coveragerc index 81a0e9a..2465361 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,6 +1,9 @@ # .coveragerc to control coverage.py [run] branch = True +parallel = True +command_line = -m pytest +concurrency = multiprocessing omit = */site-packages/* */distutils/* @@ -11,9 +14,11 @@ omit = */env/* */build_venv/* */build_env/* + */utils/banner.py + */utils/logger.py + */src/* source = image_prediction - src relative_files = True data_file = .coverage @@ -44,6 +49,10 @@ omit = */env/* */build_venv/* */build_env/* + */utils/banner.py + */utils/logger.py + */src/* + */pdf_annotation.py ignore_errors = True diff --git a/.dvc/config b/.dvc/config index 9277694..45a3243 100644 --- a/.dvc/config +++ b/.dvc/config @@ -1,5 +1,6 @@ [core] remote = vector + autostage = true ['remote "vector"'] - url = ssh://vector.iqser.com/research/image_service/ + url = ssh://vector.iqser.com/research/image-prediction/ port = 22 diff --git a/.gitignore b/.gitignore index a14b81f..9e95b52 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,9 @@ **/classpath-data.json **/dependencies-and-licenses-overview.txt +.coverage +.coverage\.*\.* + *__pycache__ *.egg-info* @@ -44,7 +47,6 @@ *misc /coverage_html_report/ -.coverage # Created by https://www.toptal.com/developers/gitignore/api/linux,pycharm # Edit at https://www.toptal.com/developers/gitignore?templates=linux,pycharm @@ -171,5 +173,3 @@ fabric.properties .idea/codestream.xml # End of https://www.toptal.com/developers/gitignore/api/linux,pycharm -/image_prediction/data/mlruns/ -/data/mlruns/ diff --git a/.gitmodules b/.gitmodules index 1ee8d73..e69de29 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "incl/redai_image"] - path = incl/redai_image - url = ssh://git@git.iqser.com:2222/rr/redai_image.git diff --git a/Dockerfile b/Dockerfile index a8c37bd..fedf264 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,23 +1,19 @@ -ARG 
BASE_ROOT="nexus.iqser.com:5001/red/" -ARG VERSION_TAG="latest" - -FROM ${BASE_ROOT}image-prediction-base:${VERSION_TAG} +FROM image-prediction-base WORKDIR /app/service COPY src src COPY data data COPY image_prediction image_prediction -COPY incl/redai_image/redai incl/redai_image/redai COPY setup.py setup.py COPY requirements.txt requirements.txt COPY config.yaml config.yaml +COPY banner.txt banner.txt # Install dependencies differing from base image. RUN python3 -m pip install -r requirements.txt RUN python3 -m pip install -e . -RUN python3 -m pip install -e incl/redai_image/redai EXPOSE 5000 EXPOSE 8080 diff --git a/Dockerfile_tests b/Dockerfile_tests new file mode 100644 index 0000000..a05a4a3 --- /dev/null +++ b/Dockerfile_tests @@ -0,0 +1,23 @@ +ARG BASE_ROOT="nexus.iqser.com:5001/red/" +ARG VERSION_TAG="dev" + +FROM ${BASE_ROOT}image-prediction:${VERSION_TAG} + +WORKDIR /app/service + +COPY src src +COPY data data +COPY image_prediction image_prediction +COPY setup.py setup.py +COPY requirements.txt requirements.txt +COPY config.yaml config.yaml + +# Install module & dependencies +RUN python3 -m pip install -e . +RUN python3 -m pip install -r requirements.txt + +RUN apt update --yes +RUN apt install vim --yes +RUN apt install poppler-utils --yes + +CMD coverage run -m pytest test/ --tb=native -q -s -vvv -x && coverage combine && coverage report -m && coverage xml diff --git a/README.md b/README.md index f913627..6280f91 100644 --- a/README.md +++ b/README.md @@ -1,25 +1,140 @@ -### Building +### Setup Build base image ```bash -setup/docker.sh -``` - -Build head image -```bash -docker build -f Dockerfile -t image-prediction . --build-arg BASE_ROOT="" +docker build -f Dockerfile_base -t image-prediction-base . +docker build -f Dockerfile -t image-prediction . 
``` ### Usage +#### Without Docker + + +```bash +py scripts/run_pipeline.py /path/to/a/pdf +``` + +#### With Docker + Shell 1 ```bash -docker run --rm --net=host --rm image-prediction +docker run --rm --net=host image-prediction ``` Shell 2 ```bash -python scripts/pyinfra_mock.py --pdf_path /path/to/a/pdf +python scripts/pyinfra_mock.py /path/to/a/pdf ``` + +### Tests + +Run for example this command to execute all tests and get a coverage report: + +```bash +coverage run -m pytest test --tb=native -q -s -vvv -x && coverage combine && coverage report -m +``` + +After having built the service container as specified above, you can also run tests in a container as follows: + +```bash +./run_tests.sh +``` + +### Message Body Formats + + +#### Request Format + +The request messages need to provide the fields `"dossierId"` and `"fileId"`. A request should look like this: + +```json +{ + "dossierId": "", + "fileId": "" +} +``` + +Any additional keys are ignored. + + +#### Response Format + +Response bodies contain information about the identified class of the image, the confidence of the classification, the +position and size of the image as well as the results of additional convenience filters which can be configured through +environment variables. A response body looks like this: + +```json +{ + "dossierId": "debug", + "fileId": "13ffa9851740c8d20c4c7d1706d72f2a", + "data": [...] 
+} +``` + +An image metadata record (entry in `"data"` field of a response body) looks like this: + +```json +{ + "classification": { + "label": "logo", + "probabilities": { + "logo": 1.0, + "signature": 1.1599173226749333e-17, + "other": 2.994595513398207e-23, + "formula": 4.352109377281029e-31 + } + }, + "position": { + "x1": 475.95, + "x2": 533.4, + "y1": 796.47, + "y2": 827.62, + "pageNumber": 6 + }, + "geometry": { + "width": 57.44999999999999, + "height": 31.149999999999977 + }, + "alpha": false, + "filters": { + "geometry": { + "imageSize": { + "quotient": 0.05975350599135938, + "tooLarge": false, + "tooSmall": false + }, + "imageFormat": { + "quotient": 1.8443017656500813, + "tooTall": false, + "tooWide": false + } + }, + "probability": { + "unconfident": false + }, + "allPassed": true + } +} +``` + + +## Configuration + +A configuration file is located under `config.yaml`. All relevant variables can be configured via +exporting environment variables. + +| __Environment Variable__ | Default | Description | +|------------------------------------|------------------------------------|----------------------------------------------------------------------------------------| +| __LOGGING_LEVEL_ROOT__ | "INFO" | Logging level for log file messages | +| __VERBOSE__ | *true* | Service prints document processing progress to stdout | +| __BATCH_SIZE__ | 16 | Number of images in memory simultaneously per service instance | +| __RUN_ID__ | "fabfb1f192c745369b88cab34471aba7" | The ID of the mlflow run to load the image classifier from | +| __MIN_REL_IMAGE_SIZE__ | 0.05 | Minimally permissible image size to page size ratio | +| __MAX_REL_IMAGE_SIZE__ | 0.75 | Maximally permissible image size to page size ratio | +| __MIN_IMAGE_FORMAT__ | 0.1 | Minimally permissible image width to height ratio | +| __MAX_IMAGE_FORMAT__ | 10 | Maximally permissible image width to height ratio | + +See also: 
https://git.iqser.com/projects/RED/repos/helm/browse/redaction/templates/image-service-v2 diff --git a/bamboo-specs/src/main/java/buildjob/PlanSpec.java b/bamboo-specs/src/main/java/buildjob/PlanSpec.java index 269cd70..590c4e0 100644 --- a/bamboo-specs/src/main/java/buildjob/PlanSpec.java +++ b/bamboo-specs/src/main/java/buildjob/PlanSpec.java @@ -73,8 +73,8 @@ public class PlanSpec { project(), SERVICE_NAME, new BambooKey(SERVICE_KEY)) .description("Docker build for image-prediction.") - // .variables() - .stages(new Stage("Build Stage") + .stages( + new Stage("Build Stage") .jobs( new Job("Build Job", new BambooKey("BUILD")) .tasks( @@ -84,9 +84,6 @@ public class PlanSpec { new VcsCheckoutTask() .description("Checkout default repository.") .checkoutItems(new CheckoutItem().defaultRepository()), - new VcsCheckoutTask() - .description("Checkout redai_image research repository.") - .checkoutItems(new CheckoutItem().repository("RR / redai_image").path("redai_image")), new ScriptTask() .description("Set config and keys.") .inlineBody("mkdir -p ~/.ssh\n" + @@ -102,7 +99,9 @@ public class PlanSpec { .dockerConfiguration( new DockerConfiguration() .image("nexus.iqser.com:5001/infra/release_build:4.2.0") - .volume("/var/run/docker.sock", "/var/run/docker.sock")), + .volume("/var/run/docker.sock", "/var/run/docker.sock"))), + new Stage("Sonar Stage") + .jobs( new Job("Sonar Job", new BambooKey("SONAR")) .tasks( new CleanWorkingDirectoryTask() @@ -111,9 +110,6 @@ public class PlanSpec { new VcsCheckoutTask() .description("Checkout default repository.") .checkoutItems(new CheckoutItem().defaultRepository()), - new VcsCheckoutTask() - .description("Checkout redai_image repository.") - .checkoutItems(new CheckoutItem().repository("RR / redai_image").path("redai_image")), new ScriptTask() .description("Set config and keys.") .inlineBody("mkdir -p ~/.ssh\n" + diff --git a/bamboo-specs/src/main/resources/scripts/docker-build.sh 
b/bamboo-specs/src/main/resources/scripts/docker-build.sh index f17638f..4cc2704 100755 --- a/bamboo-specs/src/main/resources/scripts/docker-build.sh +++ b/bamboo-specs/src/main/resources/scripts/docker-build.sh @@ -10,10 +10,11 @@ python3 -m pip install --upgrade pip pip install dvc pip install 'dvc[ssh]' +echo "Pulling dvc data" dvc pull echo "index-url = https://${bamboo_nexus_user}:${bamboo_nexus_password}@nexus.iqser.com/repository/python-combind/simple" >> pip.conf -docker build -f Dockerfile_base -t nexus.iqser.com:5001/red/$SERVICE_NAME_BASE:${bamboo_version_tag} . -docker build -f Dockerfile -t nexus.iqser.com:5001/red/$SERVICE_NAME:${bamboo_version_tag} --build-arg VERSION_TAG=${bamboo_version_tag} . +docker build -f Dockerfile_base -t $SERVICE_NAME_BASE . +docker build -f Dockerfile -t nexus.iqser.com:5001/red/$SERVICE_NAME:${bamboo_version_tag} . echo "${bamboo_nexus_password}" | docker login --username "${bamboo_nexus_user}" --password-stdin nexus.iqser.com:5001 docker push nexus.iqser.com:5001/red/$SERVICE_NAME:${bamboo_version_tag} diff --git a/bamboo-specs/src/main/resources/scripts/sonar-scan.sh b/bamboo-specs/src/main/resources/scripts/sonar-scan.sh index 6381dcd..9286748 100755 --- a/bamboo-specs/src/main/resources/scripts/sonar-scan.sh +++ b/bamboo-specs/src/main/resources/scripts/sonar-scan.sh @@ -6,11 +6,17 @@ export JAVA_HOME=/usr/bin/sonar-scanner/jre python3 -m venv build_venv source build_venv/bin/activate python3 -m pip install --upgrade pip +python3 -m pip install dependency-check +python3 -m pip install coverage -echo "dev setup for unit test and coverage 💖" +echo "coverage report generation" -pip install -e . -pip install -r requirements.txt +bash run_tests.sh + +if [ ! -f reports/coverage.xml ] +then + exit 1 +fi SERVICE_NAME=$1 diff --git a/banner.txt b/banner.txt new file mode 100644 index 0000000..2ef04ff --- /dev/null +++ b/banner.txt @@ -0,0 +1,11 @@ ++----------------------------------------------------+ +| ___ | +| __/_ `. 
.-"""-. | +|_._ _,-'""`-._ \_,` | \-' / )`-')| +|(,-.`._,'( |\`-/| "") `"` \ ((`"` | +| `-.-' \ )-`( , o o) ___Y , .'7 /| | +| `- \`_`"'- (_,___/...-` (_/_/ | +| | ++----------------------------------------------------+ +| Image Classification Service | ++----------------------------------------------------+ \ No newline at end of file diff --git a/config.yaml b/config.yaml index dbc5b87..6a6111a 100644 --- a/config.yaml +++ b/config.yaml @@ -1,20 +1,18 @@ webserver: host: $SERVER_HOST|"127.0.0.1" # webserver address port: $SERVER_PORT|5000 # webserver port - mode: $SERVER_MODE|production # webserver mode: {development, production} service: - logging_level: $LOGGING_LEVEL_ROOT|DEBUG # Logging level for service logger - progressbar: True # Whether a progress bar over the pages of a document is displayed while processing - batch_size: $BATCH_SIZE|32 # Number of images in memory simultaneously + logging_level: $LOGGING_LEVEL_ROOT|INFO # Logging level for service logger verbose: $VERBOSE|True # Service prints document processing progress to stdout - run_id: $RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the model from + batch_size: $BATCH_SIZE|16 # Number of images in memory simultaneously + mlflow_run_id: $MLFLOW_RUN_ID|fabfb1f192c745369b88cab34471aba7 # The ID of the mlflow run to load the service_estimator from -# These variables control filters that are applied to either images, image metadata or model predictions. The filter -# result values are reported in the service responses. For convenience the response to a request contains a -# "filters.allPassed" field, which is set to false if any of the filters returned values did not meet its specified -# required value. +# These variables control filters that are applied to either images, image metadata or service_estimator predictions. +# The filter result values are reported in the service responses. 
For convenience the response to a request contains a +# "filters.allPassed" field, which is set to false if any of the values returned by the filters did not meet its +# specified required value. filters: image_to_page_quotient: # Image size to page size ratio (ratio of geometric means of areas) diff --git a/data/.gitignore b/data/.gitignore new file mode 100644 index 0000000..c9213f4 --- /dev/null +++ b/data/.gitignore @@ -0,0 +1 @@ +/mlruns diff --git a/data/base_weights.h5.dvc b/data/base_weights.h5.dvc deleted file mode 100644 index 9f07d13..0000000 --- a/data/base_weights.h5.dvc +++ /dev/null @@ -1,4 +0,0 @@ -outs: -- md5: 6d0186c1f25e889d531788f168fa6cf0 - size: 16727296 - path: base_weights.h5 diff --git a/data/mlruns.dvc b/data/mlruns.dvc index d390fed..c1050a5 100644 --- a/data/mlruns.dvc +++ b/data/mlruns.dvc @@ -1,5 +1,5 @@ outs: -- md5: d1c708270bab6fcd344d4a8b05d1103d.dir - size: 150225383 - nfiles: 178 +- md5: ad061d607f615afc149643f62dbf37cc.dir + size: 166952700 + nfiles: 179 path: mlruns diff --git a/doc/tests.drawio b/doc/tests.drawio new file mode 100644 index 0000000..c335abc --- /dev/null +++ b/doc/tests.drawio @@ -0,0 +1 @@ +1ZZRT6QwEMc/DY8mQHdRX93z9JLbmNzGmNxbQ0daLQzpDrL46a/IsCzinneJcd0XaP+dtsN/fkADscg3V06WeokKbBCHahOIb0Ecnydzf22FphPmyXknZM6oTooGYWWegcWQ1cooWI8CCdGSKcdiikUBKY006RzW47B7tONdS5nBRFil0k7VO6NId+rZPBz0azCZ7neOQh7JZR/MwlpLhfWOJC4DsXCI1LXyzQJs613vSzfv+57RbWIOCvqXCZqW9PBref27aZ7xsQ5vTn/cnvAqT9JW/MCwJuNzR8dZU9Nb4bAqFLSrhYG4qLUhWJUybUdrX3uvacqt70W+yeuCI9jsTTja2uDxAcyBXONDeILonWN04hn366EQUR+jd4qQsCa59tl26cEe32CH/sOt+TueoCONGRbS/kQs2YkHIGoYbFkRvuUTqAmFr1zyu2LlUvhLdjG/HtJlQO/VfOq6AyvJPI3z+HAL4wlwpbp/2V0qODxzUTJmLjo4c8nEkxaWFXcLLPzt4ithKI4BQzHBMOc/l8UvAeLrj9/hQTw9NhBnxwDibB+IB+ZvdvZ5/PnucAx6Gds5S4rLPw== \ No newline at end of file diff --git a/test/unit_tests/__init__.py b/image_prediction/classifier/__init__.py similarity index 100% rename from test/unit_tests/__init__.py rename to image_prediction/classifier/__init__.py diff --git 
a/image_prediction/classifier/classifier.py b/image_prediction/classifier/classifier.py new file mode 100644 index 0000000..8e752cf --- /dev/null +++ b/image_prediction/classifier/classifier.py @@ -0,0 +1,35 @@ +from typing import List, Union, Tuple + +import numpy as np +from PIL.Image import Image +from funcy import rcompose + +from image_prediction.estimator.adapter.adapter import EstimatorAdapter +from image_prediction.label_mapper.mapper import LabelMapper +from image_prediction.utils import get_logger + +logger = get_logger() + + +class Classifier: + def __init__(self, estimator_adapter: EstimatorAdapter, label_mapper: LabelMapper): + """Abstraction layer over different estimator backends (e.g. keras or scikit-learn). For each backend to be used + an EstimatorAdapter must be implemented. + + Args: + estimator_adapter: adapter for a given estimator backend + """ + self.__estimator_adapter = estimator_adapter + self.__label_mapper = label_mapper + self.__pipe = rcompose(self.__estimator_adapter, self.__label_mapper) + + def predict(self, batch: Union[np.array, Tuple[Image]]) -> List[str]: + + if isinstance(batch, np.ndarray) and batch.shape[0] == 0: + return [] + + return self.__pipe(batch) + + def __call__(self, batch: np.array) -> List[str]: + logger.debug("Classifier.predict") + return self.predict(batch) diff --git a/image_prediction/classifier/image_classifier.py b/image_prediction/classifier/image_classifier.py new file mode 100644 index 0000000..f01cfd4 --- /dev/null +++ b/image_prediction/classifier/image_classifier.py @@ -0,0 +1,32 @@ +from itertools import chain +from typing import Iterable + +from PIL.Image import Image +from funcy import rcompose, chunks + +from image_prediction.classifier.classifier import Classifier +from image_prediction.estimator.preprocessor.preprocessor import Preprocessor +from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor +from image_prediction.utils import get_logger + +logger = 
get_logger() + + +class ImageClassifier: + """Combines a classifier with a preprocessing pipeline: Receives images, chunks into batches, converts to tensors, + applies transformations and finally sends to internal classifier. + """ + + def __init__(self, classifier: Classifier, preprocessor: Preprocessor = None): + self.estimator = classifier + self.preprocessor = preprocessor if preprocessor else IdentityPreprocessor() + self.pipe = rcompose(self.preprocessor, self.estimator) + + def predict(self, images: Iterable[Image], batch_size=16): + batches = chunks(batch_size, images) + predictions = chain.from_iterable(map(self.pipe, batches)) + return predictions + + def __call__(self, images: Iterable[Image], batch_size=16): + logger.debug("ImageClassifier.predict") + yield from self.predict(images, batch_size=batch_size) diff --git a/image_prediction/compositor/__init__.py b/image_prediction/compositor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/compositor/compositor.py b/image_prediction/compositor/compositor.py new file mode 100644 index 0000000..5a3c49a --- /dev/null +++ b/image_prediction/compositor/compositor.py @@ -0,0 +1,16 @@ +from funcy import rcompose + +from image_prediction.transformer.transformer import Transformer +from image_prediction.utils import get_logger + +logger = get_logger() + + +class TransformerCompositor(Transformer): + def __init__(self, formatter: Transformer, *formatters: Transformer): + formatters = (formatter, *formatters) + self.pipe = rcompose(*formatters) + + def transform(self, obj): + logger.debug("TransformerCompositor.transform") + return self.pipe(obj) diff --git a/image_prediction/config.py b/image_prediction/config.py index f37658f..4696191 100644 --- a/image_prediction/config.py +++ b/image_prediction/config.py @@ -18,12 +18,12 @@ class DotIndexable: def __getattr__(self, item): return _get_item_and_maybe_make_dotindexable(self.x, item) - def __setitem__(self, key, value): - 
self.x[key] = value - def __repr__(self): return self.x.__repr__() + def __getitem__(self, item): + return self.__getattr__(item) + class Config: def __init__(self, config_path): diff --git a/image_prediction/default_objects.py b/image_prediction/default_objects.py new file mode 100644 index 0000000..97c089e --- /dev/null +++ b/image_prediction/default_objects.py @@ -0,0 +1,38 @@ +from funcy import juxt + +from image_prediction.classifier.classifier import Classifier +from image_prediction.classifier.image_classifier import ImageClassifier +from image_prediction.compositor.compositor import TransformerCompositor +from image_prediction.estimator.adapter.adapter import EstimatorAdapter +from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter +from image_prediction.formatter.formatters.enum import EnumFormatter +from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor +from image_prediction.label_mapper.mappers.probability import ProbabilityMapper +from image_prediction.model_loader.loader import ModelLoader +from image_prediction.model_loader.loaders.mlflow import MlflowConnector +from image_prediction.redai_adapter.mlflow import MlflowModelReader +from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer +from image_prediction.transformer.transformers.response import ResponseTransformer + + +def get_mlflow_model_loader(mlruns_dir): + model_loader = ModelLoader(MlflowConnector(MlflowModelReader(mlruns_dir))) + return model_loader + + +def get_image_classifier(model_loader, model_identifier): + model, classes = juxt(model_loader.load_model, model_loader.load_classes)(model_identifier) + return ImageClassifier(Classifier(EstimatorAdapter(model), ProbabilityMapper(classes))) + + +def get_extractor(**kwargs): + image_extractor = ParsablePDFImageExtractor(**kwargs) + + return image_extractor + + +def get_formatter(): + formatter = TransformerCompositor( + 
PDFNetCoordinateTransformer(), EnumFormatter(), ResponseTransformer(), Snake2CamelCaseKeyFormatter() + ) + return formatter diff --git a/image_prediction/estimator/__init__.py b/image_prediction/estimator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/estimator/adapter/__init__.py b/image_prediction/estimator/adapter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/estimator/adapter/adapter.py b/image_prediction/estimator/adapter/adapter.py new file mode 100644 index 0000000..8aac9c9 --- /dev/null +++ b/image_prediction/estimator/adapter/adapter.py @@ -0,0 +1,15 @@ +from image_prediction.utils import get_logger + +logger = get_logger() + + +class EstimatorAdapter: + def __init__(self, estimator): + self.estimator = estimator + + def predict(self, batch): + return self.estimator(batch) + + def __call__(self, batch): + logger.debug("EstimatorAdapter.predict") + return self.predict(batch) diff --git a/image_prediction/estimator/adapter/adapters/__init__.py b/image_prediction/estimator/adapter/adapters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/estimator/preprocessor/__init__.py b/image_prediction/estimator/preprocessor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/estimator/preprocessor/preprocessor.py b/image_prediction/estimator/preprocessor/preprocessor.py new file mode 100644 index 0000000..58af954 --- /dev/null +++ b/image_prediction/estimator/preprocessor/preprocessor.py @@ -0,0 +1,10 @@ +import abc + + +class Preprocessor(abc.ABC): + @abc.abstractmethod + def preprocess(self, batch): + raise NotImplementedError + + def __call__(self, batch): + return self.preprocess(batch) diff --git a/image_prediction/estimator/preprocessor/preprocessors/__init__.py b/image_prediction/estimator/preprocessor/preprocessors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/image_prediction/estimator/preprocessor/preprocessors/basic.py b/image_prediction/estimator/preprocessor/preprocessors/basic.py new file mode 100644 index 0000000..03e94e6 --- /dev/null +++ b/image_prediction/estimator/preprocessor/preprocessors/basic.py @@ -0,0 +1,10 @@ +from image_prediction.estimator.preprocessor.preprocessor import Preprocessor +from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor + + +class BasicPreprocessor(Preprocessor): + """Converts images to tensors""" + + @staticmethod + def preprocess(images): + return images_to_batch_tensor(images) diff --git a/image_prediction/estimator/preprocessor/preprocessors/identity.py b/image_prediction/estimator/preprocessor/preprocessors/identity.py new file mode 100644 index 0000000..199b1b0 --- /dev/null +++ b/image_prediction/estimator/preprocessor/preprocessors/identity.py @@ -0,0 +1,10 @@ +from image_prediction.estimator.preprocessor.preprocessor import Preprocessor + + +class IdentityPreprocessor(Preprocessor): + @staticmethod + def preprocess(images): + return images + + def __call__(self, images): + return self.preprocess(images) diff --git a/image_prediction/estimator/preprocessor/utils.py b/image_prediction/estimator/preprocessor/utils.py new file mode 100644 index 0000000..dbab144 --- /dev/null +++ b/image_prediction/estimator/preprocessor/utils.py @@ -0,0 +1,10 @@ +import numpy as np +from PIL.Image import Image + + +def image_to_normalized_tensor(image: Image) -> np.ndarray: + return np.array(image) / 255 + + +def images_to_batch_tensor(images) -> np.ndarray: + return np.array(list(map(image_to_normalized_tensor, images))) diff --git a/image_prediction/exceptions.py b/image_prediction/exceptions.py new file mode 100644 index 0000000..1b88f0d --- /dev/null +++ b/image_prediction/exceptions.py @@ -0,0 +1,34 @@ +class UnknownEstimatorAdapter(ValueError): + pass + + +class UnknownImageExtractor(ValueError): + pass + + +class UnknownModelLoader(ValueError): + pass + + 
+class UnknownDatabaseType(ValueError): + pass + + +class UnknownLabelFormat(ValueError): + pass + + +class UnexpectedLabelFormat(ValueError): + pass + + +class IncorrectInstantiation(RuntimeError): + pass + + +class IntentionalTestException(RuntimeError): + pass + + +class InvalidBox(Exception): + pass diff --git a/image_prediction/extraction.py b/image_prediction/extraction.py new file mode 100644 index 0000000..b996ed7 --- /dev/null +++ b/image_prediction/extraction.py @@ -0,0 +1,13 @@ +from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor + + +def extract_images_from_pdf(pdf, extractor=None): + + if not extractor: + extractor = ParsablePDFImageExtractor() + + try: + images_extracted, metadata_extracted = zip(*extractor(pdf)) + return images_extracted, metadata_extracted + except ValueError: + return [], [] diff --git a/image_prediction/flask.py b/image_prediction/flask.py index 5cf40c2..7ab4005 100644 --- a/image_prediction/flask.py +++ b/image_prediction/flask.py @@ -1,4 +1,5 @@ import multiprocessing +import traceback from typing import Callable from flask import Flask, request, jsonify @@ -8,8 +9,30 @@ from image_prediction.utils import get_logger logger = get_logger() -def make_prediction_server(predict_fn: Callable): +def run_in_process(func): + p = multiprocessing.Process(target=func) + p.start() + p.join() + +def wrap_in_process(func_to_wrap): + def build_function_and_run_in_process(*args, **kwargs): + def func(): + try: + result = func_to_wrap(*args, **kwargs) + return_dict["result"] = result + except: + logger.error(traceback.format_exc()) + + manager = multiprocessing.Manager() + return_dict = manager.dict() + run_in_process(func) + return return_dict.get("result", None) + + return build_function_and_run_in_process + + +def make_prediction_server(predict_fn: Callable): app = Flask(__name__) @app.route("/ready", methods=["GET"]) @@ -24,42 +47,28 @@ def make_prediction_server(predict_fn: Callable): resp.status_code = 
200 return resp + def __failure(): + response = jsonify("Analysis failed") + response.status_code = 500 + return response + + @app.route("/predict", methods=["POST"]) @app.route("/", methods=["POST"]) def predict(): - def predict_fn_wrapper(pdf, return_dict): - return_dict["result"] = predict_fn(pdf) - def process(): - # Tensorflow does not free RAM. Workaround is running model in process. - # https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution - pdf = request.data - manager = multiprocessing.Manager() - return_dict = manager.dict() - p = multiprocessing.Process( - target=predict_fn_wrapper, - args=( - pdf, - return_dict, - ), - ) - p.start() - p.join() - try: - return dict(return_dict)["result"] - except KeyError: - raise + # Tensorflow does not free RAM. Workaround: Run prediction function (which instantiates a model) in sub-process. + # See: https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution + predict_fn_wrapped = wrap_in_process(predict_fn) - logger.debug("Running predictor on document...") - try: - predictions = process() + logger.info("Analysing...") + predictions = predict_fn_wrapped(request.data) + + if predictions: response = jsonify(predictions) logger.info("Analysis completed.") return response - except Exception as err: + else: logger.error("Analysis failed.") - logger.exception(err) - response = jsonify("Analysis failed.") - response.status_code = 500 - return response + return __failure() return app diff --git a/image_prediction/formatter/__init__.py b/image_prediction/formatter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/formatter/formatter.py b/image_prediction/formatter/formatter.py new file mode 100644 index 0000000..3f3a1f8 --- /dev/null +++ b/image_prediction/formatter/formatter.py @@ -0,0 +1,15 @@ +import abc + +from image_prediction.transformer.transformer import Transformer + + +class Formatter(Transformer): + 
@abc.abstractmethod + def format(self, obj): + raise NotImplementedError + + def transform(self, obj): + raise NotImplementedError() + + def __call__(self, obj): + return self.format(obj) diff --git a/image_prediction/formatter/formatters/__init__.py b/image_prediction/formatter/formatters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/formatter/formatters/camel_case.py b/image_prediction/formatter/formatters/camel_case.py new file mode 100644 index 0000000..caa240a --- /dev/null +++ b/image_prediction/formatter/formatters/camel_case.py @@ -0,0 +1,11 @@ +from image_prediction.formatter.formatters.key_formatter import KeyFormatter + + +class Snake2CamelCaseKeyFormatter(KeyFormatter): + def format_key(self, key): + + if isinstance(key, str): + head, *tail = key.split("_") + return head + "".join(map(str.title, tail)) + else: + return key diff --git a/image_prediction/formatter/formatters/enum.py b/image_prediction/formatter/formatters/enum.py new file mode 100644 index 0000000..45e5629 --- /dev/null +++ b/image_prediction/formatter/formatters/enum.py @@ -0,0 +1,23 @@ +from enum import Enum + +from image_prediction.formatter.formatters.key_formatter import KeyFormatter + + +class EnumFormatter(KeyFormatter): + def format_key(self, key): + return key.value if isinstance(key, Enum) else key + + def transform(self, obj): + raise NotImplementedError + + +class ReverseEnumFormatter(KeyFormatter): + def __init__(self, enum): + self.enum = enum + self.reverse_enum = {e.value: e for e in enum} + + def format_key(self, key): + return self.reverse_enum.get(key, key) + + def transform(self, obj): + raise NotImplementedError diff --git a/image_prediction/formatter/formatters/identity.py b/image_prediction/formatter/formatters/identity.py new file mode 100644 index 0000000..0eaca6a --- /dev/null +++ b/image_prediction/formatter/formatters/identity.py @@ -0,0 +1,6 @@ +from image_prediction.formatter.formatter import Formatter + + +class 
IdentityFormatter(Formatter): + def format(self, obj): + return obj diff --git a/image_prediction/formatter/formatters/key_formatter.py b/image_prediction/formatter/formatters/key_formatter.py new file mode 100644 index 0000000..e2ed3a3 --- /dev/null +++ b/image_prediction/formatter/formatters/key_formatter.py @@ -0,0 +1,28 @@ +import abc +from typing import Iterable + +from image_prediction.formatter.formatter import Formatter + + +class KeyFormatter(Formatter): + @abc.abstractmethod + def format_key(self, key): + raise NotImplementedError + + def __format(self, data): + + # If we wanted to do this properly, we would need handlers for all expected types and dispatch based + # on a type comparison. This is too much engineering for the limited use-case of this class though. + if isinstance(data, Iterable) and not isinstance(data, dict) and not isinstance(data, str): + f = map(self.__format, data) + return type(data)(f) if not isinstance(data, map) else f + + if not isinstance(data, dict): + return data + + keys_formatted = list(map(self.format_key, data)) + + return dict(zip(keys_formatted, map(self.__format, data.values()))) + + def format(self, data): + return self.__format(data) diff --git a/image_prediction/image_extractor/__init__.py b/image_prediction/image_extractor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/image_extractor/extractor.py b/image_prediction/image_extractor/extractor.py new file mode 100644 index 0000000..ca6392e --- /dev/null +++ b/image_prediction/image_extractor/extractor.py @@ -0,0 +1,19 @@ +import abc +from collections import namedtuple +from typing import Iterable + +from image_prediction.utils import get_logger + +ImageMetadataPair = namedtuple("ImageMetadataPair", ["image", "metadata"]) + +logger = get_logger() + + +class ImageExtractor(abc.ABC): + @abc.abstractmethod + def extract(self, obj) -> Iterable[ImageMetadataPair]: + raise NotImplementedError + + def __call__(self, obj, **kwargs): + 
logger.debug("ImageExtractor.extract") + return self.extract(obj, **kwargs) diff --git a/image_prediction/image_extractor/extractors/__init__.py b/image_prediction/image_extractor/extractors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/image_extractor/extractors/mock.py b/image_prediction/image_extractor/extractors/mock.py new file mode 100644 index 0000000..86ced9a --- /dev/null +++ b/image_prediction/image_extractor/extractors/mock.py @@ -0,0 +1,7 @@ +from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair + + +class ImageExtractorMock(ImageExtractor): + def extract(self, image_container): + for i, image in enumerate(image_container): + yield ImageMetadataPair(image, {"image_id": i}) diff --git a/image_prediction/image_extractor/extractors/parsable.py b/image_prediction/image_extractor/extractors/parsable.py new file mode 100644 index 0000000..a022396 --- /dev/null +++ b/image_prediction/image_extractor/extractors/parsable.py @@ -0,0 +1,179 @@ +import atexit +import io +from functools import partial, lru_cache +from itertools import chain, starmap, filterfalse +from operator import itemgetter +from typing import List + +import fitz +from PIL import Image +from funcy import rcompose, merge, pluck, curry, compose + +from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair +from image_prediction.info import Info +from image_prediction.stitching.stitching import stitch_pairs +from image_prediction.stitching.utils import validate_box_coords, validate_box_size +from image_prediction.utils.generic import lift + + +class ParsablePDFImageExtractor(ImageExtractor): + def __init__(self, verbose=False, tolerance=0): + """ + + Args: + verbose: Whether to show progressbar + tolerance: The tolerance in pixels for the distance images beyond which they will not be stitched together + """ + self.doc: fitz.fitz.Document = None + self.verbose = verbose + self.tolerance = 
tolerance + + def extract(self, pdf: bytes, page_range: range = None): + self.doc = fitz.Document(stream=pdf) + + pages = extract_pages(self.doc, page_range) if page_range else self.doc + + image_metadata_pairs = chain.from_iterable(map(self.__process_images_on_page, pages)) + + yield from image_metadata_pairs + + def __process_images_on_page(self, page: fitz.fitz.Page): + images = get_images_on_page(self.doc, page) + metadata = get_metadata_for_images_on_page(self.doc, page) + clear_caches() + + image_metadata_pairs = starmap(ImageMetadataPair, filter(all, zip(images, metadata))) + image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance) + + yield from image_metadata_pairs + + +def extract_pages(doc, page_range): + page_range = range(page_range.start + 1, page_range.stop + 1) + pages = map(doc.load_page, page_range) + + yield from pages + + +@lru_cache(maxsize=None) +def get_images_on_page(doc, page: fitz.Page): + image_infos = get_image_infos(page) + xrefs = map(itemgetter("xref"), image_infos) + images = map(partial(xref_to_image, doc), xrefs) + + yield from images + + +def get_metadata_for_images_on_page(doc, page: fitz.Page): + + metadata = map(get_image_metadata, get_image_infos(page)) + metadata = validate_coords_and_passthrough(metadata) + + metadata = filter_out_tiny_images(metadata) + metadata = validate_size_and_passthrough(metadata) + + metadata = add_page_metadata(page, metadata) + + metadata = add_alpha_channel_info(doc, page, metadata) + + yield from metadata + + +@lru_cache(maxsize=None) +def get_image_infos(page: fitz.Page) -> List[dict]: + return page.get_image_info(xrefs=True) + + +@lru_cache(maxsize=None) +def xref_to_image(doc, xref) -> Image: + maybe_image = load_image_handle_from_xref(doc, xref) + return Image.open(io.BytesIO(maybe_image["image"])) if maybe_image else None + + +def get_image_metadata(image_info): + + x1, y1, x2, y2 = map(rounder, image_info["bbox"]) + + width = abs(x2 - x1) + height = abs(y2 - 
y1) + + return { + Info.WIDTH: width, + Info.HEIGHT: height, + Info.X1: x1, + Info.X2: x2, + Info.Y1: y1, + Info.Y2: y2, + } + + +def validate_coords_and_passthrough(metadata): + yield from map(validate_box_coords, metadata) + + +def filter_out_tiny_images(metadata): + yield from filterfalse(tiny, metadata) + + +def validate_size_and_passthrough(metadata): + yield from map(validate_box_size, metadata) + + +def add_page_metadata(page, metadata): + yield from map(partial(merge, get_page_metadata(page)), metadata) + + +def add_alpha_channel_info(doc, page, metadata): + + page_to_xrefs = compose(curry(pluck)("xref"), get_image_infos) + xref_to_alpha = partial(has_alpha_channel, doc) + page_to_alpha_value_per_image = compose(lift(xref_to_alpha), page_to_xrefs) + alpha_to_dict = compose(dict, lambda a: [(Info.ALPHA, a)]) + page_to_alpha_mapping_per_image = compose(lift(alpha_to_dict), page_to_alpha_value_per_image) + + metadata = starmap(merge, zip(page_to_alpha_mapping_per_image(page), metadata)) + + yield from metadata + + +@lru_cache(maxsize=None) +def load_image_handle_from_xref(doc, xref): + return doc.extract_image(xref) + + +rounder = rcompose(round, int) + + +def get_page_metadata(page): + page_width, page_height = map(rounder, page.mediabox_size) + + return { + Info.PAGE_WIDTH: page_width, + Info.PAGE_HEIGHT: page_height, + Info.PAGE_IDX: page.number, + } + + +def has_alpha_channel(doc, xref): + + maybe_image = load_image_handle_from_xref(doc, xref) + maybe_smask = maybe_image["smask"] if maybe_image else None + + if maybe_smask: + return any([doc.extract_image(maybe_smask) is not None, bool(fitz.Pixmap(doc, maybe_smask).alpha)]) + else: + return bool(fitz.Pixmap(doc, xref).alpha) + + +def tiny(metadata): + return metadata[Info.WIDTH] * metadata[Info.HEIGHT] <= 4 + + +def clear_caches(): + get_image_infos.cache_clear() + load_image_handle_from_xref.cache_clear() + get_images_on_page.cache_clear() + xref_to_image.cache_clear() + + +atexit.register(clear_caches) 
diff --git a/image_prediction/info.py b/image_prediction/info.py new file mode 100644 index 0000000..344274a --- /dev/null +++ b/image_prediction/info.py @@ -0,0 +1,14 @@ +from enum import Enum + + +class Info(Enum): + PAGE_WIDTH = "page_width" + PAGE_HEIGHT = "page_height" + PAGE_IDX = "page_idx" + WIDTH = "width" + HEIGHT = "height" + X1 = "x1" + X2 = "x2" + Y1 = "y1" + Y2 = "y2" + ALPHA = "alpha" diff --git a/image_prediction/label_mapper/__init__.py b/image_prediction/label_mapper/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/label_mapper/mapper.py b/image_prediction/label_mapper/mapper.py new file mode 100644 index 0000000..4cdff9e --- /dev/null +++ b/image_prediction/label_mapper/mapper.py @@ -0,0 +1,10 @@ +import abc + + +class LabelMapper(abc.ABC): + @abc.abstractmethod + def map_labels(self, items): + raise NotImplementedError + + def __call__(self, items): + return self.map_labels(items) diff --git a/image_prediction/label_mapper/mappers/__init__.py b/image_prediction/label_mapper/mappers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/label_mapper/mappers/numeric.py b/image_prediction/label_mapper/mappers/numeric.py new file mode 100644 index 0000000..93b08cd --- /dev/null +++ b/image_prediction/label_mapper/mappers/numeric.py @@ -0,0 +1,20 @@ +from typing import Mapping, Iterable + +from image_prediction.exceptions import UnexpectedLabelFormat +from image_prediction.label_mapper.mapper import LabelMapper + + +class IndexMapper(LabelMapper): + def __init__(self, labels: Mapping[int, str]): + self.__labels = labels + + def __validate_index_label_format(self, index_label: int) -> None: + if not 0 <= index_label < len(self.__labels): + raise UnexpectedLabelFormat(f"Received index label '{index_label}' that has no associated string label.") + + def __map_label(self, index_label: int) -> str: + self.__validate_index_label_format(index_label) + return self.__labels[index_label] + + 
def map_labels(self, index_labels: Iterable[int]) -> Iterable[str]: + return map(self.__map_label, index_labels) diff --git a/image_prediction/label_mapper/mappers/probability.py b/image_prediction/label_mapper/mappers/probability.py new file mode 100644 index 0000000..b2a0e63 --- /dev/null +++ b/image_prediction/label_mapper/mappers/probability.py @@ -0,0 +1,39 @@ +from enum import Enum +from operator import itemgetter +from typing import Mapping, Iterable + +import numpy as np +from funcy import rcompose, rpartial + +from image_prediction.exceptions import UnexpectedLabelFormat +from image_prediction.label_mapper.mapper import LabelMapper + + +class ProbabilityMapperKeys(Enum): + LABEL = "label" + PROBABILITIES = "probabilities" + + +class ProbabilityMapper(LabelMapper): + def __init__(self, labels: Mapping[int, str]): + self.__labels = labels + # String conversion in the middle due to floating point precision issues. + # See: https://stackoverflow.com/questions/56820/round-doesnt-seem-to-be-rounding-properly + self.__rounder = rcompose(rpartial(round, 4), str, float) + + def __validate_array_label_format(self, probabilities: np.ndarray) -> None: + if not len(probabilities) == len(self.__labels): + raise UnexpectedLabelFormat( + f"Received fewer probabilities ({len(probabilities)}) than labels were passed ({len(self.__labels)})." 
+ ) + + def __map_array(self, probabilities: np.ndarray) -> dict: + self.__validate_array_label_format(probabilities) + cls2prob = dict( + sorted(zip(self.__labels, list(map(self.__rounder, probabilities))), key=itemgetter(1), reverse=True) + ) + most_likely = [*cls2prob][0] + return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: cls2prob} + + def map_labels(self, probabilities: Iterable[np.ndarray]) -> Iterable[dict]: + return map(self.__map_array, probabilities) diff --git a/image_prediction/locations.py b/image_prediction/locations.py index 4bf8d4b..1f14c1a 100644 --- a/image_prediction/locations.py +++ b/image_prediction/locations.py @@ -1,10 +1,17 @@ -from os import path +"""Defines constant paths relative to the module root path.""" -MODULE_DIR = path.dirname(path.abspath(__file__)) -PACKAGE_ROOT_DIR = path.dirname(MODULE_DIR) +from pathlib import Path -CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml") +MODULE_DIR = Path(__file__).resolve().parents[0] -DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data") -MLRUNS_DIR = path.join(DATA_DIR, "mlruns") -BASE_WEIGHTS = path.join(DATA_DIR, "base_weights.h5") +PACKAGE_ROOT_DIR = MODULE_DIR.parents[0] + +CONFIG_FILE = PACKAGE_ROOT_DIR / "config.yaml" + +BANNER_FILE = PACKAGE_ROOT_DIR / "banner.txt" + +DATA_DIR = PACKAGE_ROOT_DIR / "data" + +MLRUNS_DIR = str(DATA_DIR / "mlruns") + +TEST_DATA_DIR = PACKAGE_ROOT_DIR / "test" / "data" diff --git a/image_prediction/model_loader/__init__.py b/image_prediction/model_loader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/model_loader/database/__init__.py b/image_prediction/model_loader/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/model_loader/database/connector.py b/image_prediction/model_loader/database/connector.py new file mode 100644 index 0000000..f265ad5 --- /dev/null +++ b/image_prediction/model_loader/database/connector.py @@ -0,0 +1,7 @@ +import 
abc + + +class DatabaseConnector(abc.ABC): + @abc.abstractmethod + def get_object(self, identifier): + raise NotImplementedError diff --git a/image_prediction/model_loader/database/connectors/__init__.py b/image_prediction/model_loader/database/connectors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/model_loader/database/connectors/mock.py b/image_prediction/model_loader/database/connectors/mock.py new file mode 100644 index 0000000..9dfcc50 --- /dev/null +++ b/image_prediction/model_loader/database/connectors/mock.py @@ -0,0 +1,9 @@ +from image_prediction.model_loader.database.connector import DatabaseConnector + + +class DatabaseConnectorMock(DatabaseConnector): + def __init__(self, store: dict): + self.store = store + + def get_object(self, identifier): + return self.store[identifier] diff --git a/image_prediction/model_loader/loader.py b/image_prediction/model_loader/loader.py new file mode 100644 index 0000000..7130a8a --- /dev/null +++ b/image_prediction/model_loader/loader.py @@ -0,0 +1,18 @@ +from functools import lru_cache + +from image_prediction.model_loader.database.connector import DatabaseConnector + + +class ModelLoader: + def __init__(self, database_connector: DatabaseConnector): + self.database_connector = database_connector + + @lru_cache(maxsize=None) + def __get_object(self, identifier): + return self.database_connector.get_object(identifier) + + def load_model(self, identifier): + return self.__get_object(identifier)["model"] + + def load_classes(self, identifier): + return self.__get_object(identifier)["classes"] diff --git a/image_prediction/model_loader/loaders/__init__.py b/image_prediction/model_loader/loaders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/model_loader/loaders/mlflow.py b/image_prediction/model_loader/loaders/mlflow.py new file mode 100644 index 0000000..2bbd126 --- /dev/null +++ b/image_prediction/model_loader/loaders/mlflow.py @@ -0,0 +1,10 
@@ +from image_prediction.model_loader.database.connector import DatabaseConnector +from image_prediction.redai_adapter.mlflow import MlflowModelReader + + +class MlflowConnector(DatabaseConnector): + def __init__(self, mlflow_reader: MlflowModelReader): + self.mlflow_reader = mlflow_reader + + def get_object(self, run_id): + return self.mlflow_reader[run_id] diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py new file mode 100644 index 0000000..721cbb9 --- /dev/null +++ b/image_prediction/pipeline.py @@ -0,0 +1,64 @@ +import os +from functools import partial +from itertools import chain, tee + +from funcy import rcompose, first, compose, second, chunks, identity +from tqdm import tqdm + +from image_prediction.config import CONFIG +from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor +from image_prediction.locations import MLRUNS_DIR +from image_prediction.utils.generic import lift, starlift + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + + +def load_pipeline(**kwargs): + model_loader = get_mlflow_model_loader(MLRUNS_DIR) + model_identifier = CONFIG.service.mlflow_run_id + + pipeline = Pipeline(model_loader, model_identifier, **kwargs) + + return pipeline + + +def parallel(*fs): + return lambda *args: (f(a) for f, a in zip(fs, args)) + + +def star(f): + return lambda x: f(*x) + + +class Pipeline: + def __init__(self, model_loader, model_identifier, batch_size=16, verbose=True, **kwargs): + self.verbose = verbose + + extract = get_extractor(**kwargs) + classifier = get_image_classifier(model_loader, model_identifier) + reformat = get_formatter() + + split = compose(star(parallel(*map(lift, (first, second)))), tee) + classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size)) + pairwise_apply = compose(star, parallel) + join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), star(zip)) + + # +>--classify--v + # --extract-->--split--| 
|--join-->reformat + # +>--identity--^ + + self.pipe = rcompose( + extract, # ... image-metadata-pairs as a stream + split, # ... into an image stream and a metadata stream + pairwise_apply(classify, identity), # ... apply functions to the streams pairwise + join, # ... the streams by zipping + reformat, # ... the items + ) + + def __call__(self, pdf: bytes, page_range: range = None): + yield from tqdm( + self.pipe(pdf, page_range=page_range), + desc="Processing images from document", + unit=" images", + disable=not self.verbose, + ) diff --git a/image_prediction/predictor.py b/image_prediction/predictor.py deleted file mode 100644 index 8e83f2c..0000000 --- a/image_prediction/predictor.py +++ /dev/null @@ -1,122 +0,0 @@ -from itertools import chain -from operator import itemgetter -from typing import List, Dict, Iterable - -import numpy as np - -from image_prediction.config import CONFIG -from image_prediction.locations import MLRUNS_DIR, BASE_WEIGHTS -from image_prediction.utils import temporary_pdf_file, get_logger -from incl.redai_image.redai.redai.backend.model.model_handle import ModelHandle -from incl.redai_image.redai.redai.backend.pdf.image_extraction import extract_and_stitch -from incl.redai_image.redai.redai.utils.mlflow_reader import MlflowModelReader -from incl.redai_image.redai.redai.utils.shared import chunk_iterable - -logger = get_logger() - - -class Predictor: - """`ModelHandle` wrapper. Forwards to wrapped model handle for prediction and produces structured output that is - interpretable independently of the wrapped model (e.g. with regard to a .classes_ attribute). - """ - - def __init__(self, model_handle: ModelHandle = None): - """Initializes a ServiceEstimator. - - Args: - model_handle: ModelHandle object to forward to for prediction. By default, a model handle is loaded from the - mlflow database via CONFIG.service.run_id. 
- """ - try: - if model_handle is None: - reader = MlflowModelReader(run_id=CONFIG.service.run_id, mlruns_dir=MLRUNS_DIR) - self.model_handle = reader.get_model_handle(BASE_WEIGHTS) - else: - self.model_handle = model_handle - - self.classes = self.model_handle.model.classes_ - self.classes_readable = np.array(self.model_handle.classes) - self.classes_readable_aligned = self.classes_readable[self.classes[list(range(len(self.classes)))]] - except Exception as e: - logger.info(f"Service estimator initialization failed: {e}") - - def __make_predictions_human_readable(self, probs: np.ndarray) -> List[Dict[str, float]]: - """Translates an n x m matrix of probabilities over classes into an n-element list of mappings from classes to - probabilities. - - Args: - probs: probability matrix (items x classes) - - Returns: - list of mappings from classes to probabilities. - """ - classes = np.argmax(probs, axis=1) - classes = self.classes[classes] - classes_readable = [self.model_handle.classes[c] for c in classes] - return classes_readable - - def predict(self, images: List, probabilities: bool = False, **kwargs): - """Gathers predictions for list of images. Assigns each image a class and optionally a probability distribution - over all classes. - - Args: - images (List[PIL.Image]) : Images to gather predictions for. - probabilities: Whether to return dictionaries of the following form instead of strings: - { - "class": predicted class, - "probabilities": { - "class 1" : class 1 probability, - "class 2" : class 2 probability, - ... - } - } - - Returns: - By default the return value is a list of classes (meaningful class name strings). Alternatively a list of - dictionaries with an additional probability field for estimated class probabilities per image can be - returned. 
- """ - X = self.model_handle.prep_images(list(images)) - - probs_per_item = self.model_handle.model.predict_proba(X, **kwargs).astype(float) - classes = self.__make_predictions_human_readable(probs_per_item) - - class2prob_per_item = [dict(zip(self.classes_readable_aligned, probs)) for probs in probs_per_item] - class2prob_per_item = [ - dict(sorted(c2p.items(), key=itemgetter(1), reverse=True)) for c2p in class2prob_per_item - ] - - predictions = [{"class": c, "probabilities": c2p} for c, c2p in zip(classes, class2prob_per_item)] - - return predictions if probabilities else classes - - def predict_pdf(self, pdf, verbose=False): - with temporary_pdf_file(pdf) as pdf_path: - image_metadata_pairs = self.__extract_image_metadata_pairs(pdf_path, verbose=verbose) - return self.__predict_images(image_metadata_pairs) - - def __predict_images(self, image_metadata_pairs: Iterable, batch_size: int = CONFIG.service.batch_size): - def process_chunk(chunk): - images, metadata = zip(*chunk) - predictions = self.predict(images, probabilities=True) - return predictions, metadata - - def predict(image_metadata_pair_generator): - chunks = chunk_iterable(image_metadata_pair_generator, n=batch_size) - return map(chain.from_iterable, zip(*map(process_chunk, chunks))) - - try: - predictions, metadata = predict(image_metadata_pairs) - return predictions, metadata - - except ValueError: - return [], [] - - @staticmethod - def __extract_image_metadata_pairs(pdf_path: str, **kwargs): - def image_is_large_enough(metadata: dict): - x1, x2, y1, y2 = itemgetter("x1", "x2", "y1", "y2")(metadata) - - return abs(x1 - x2) > 2 and abs(y1 - y2) > 2 - - yield from extract_and_stitch(pdf_path, convert_to_rgb=True, filter_fn=image_is_large_enough, **kwargs) diff --git a/image_prediction/redai_adapter/__init__.py b/image_prediction/redai_adapter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/redai_adapter/efficient_net_wrapper.py 
b/image_prediction/redai_adapter/efficient_net_wrapper.py new file mode 100644 index 0000000..12a11b0 --- /dev/null +++ b/image_prediction/redai_adapter/efficient_net_wrapper.py @@ -0,0 +1,45 @@ +import tensorflow as tf + +from image_prediction.redai_adapter.model_wrapper import ModelWrapper + + +class EfficientNetWrapper(ModelWrapper): + def __init__(self, classes, base_weights_path=None, weights_path=None): + self.__input_shape = (224, 224, 3) + super().__init__(classes=classes, base_weights_path=base_weights_path, weights_path=weights_path) + + @property + def input_shape(self): + return self.__input_shape + + def _ModelWrapper__preprocess_tensor(self, tensor): + return tf.keras.applications.efficientnet.preprocess_input(tensor) + + def _ModelWrapper__build(self, base_weights=None) -> tf.keras.models.Model: + input_img = tf.keras.layers.Input(shape=self.input_shape) + + pretrained = tf.keras.applications.efficientnet.EfficientNetB0( + include_top=False, input_tensor=tf.keras.layers.Input(shape=self.input_shape), weights=base_weights + ) + + pretrained.trainable = False + + for layer in pretrained.layers: + layer.trainable = False + + pretrained = pretrained(input_img) + + finetuned = tf.keras.layers.Flatten()(pretrained) + finetuned = tf.keras.layers.Dense(512, activation="relu")(finetuned) + finetuned = tf.keras.layers.Dropout(0.2)(finetuned) + finetuned = tf.keras.layers.Dense(128, activation="relu")(finetuned) + finetuned = tf.keras.layers.Dropout(0.2)(finetuned) + finetuned = tf.keras.layers.Dense(32, activation="relu")(finetuned) + finetuned = tf.keras.layers.Dropout(0.2)(finetuned) + finetuned = tf.keras.layers.Dense(len(self.classes), activation="softmax")(finetuned) + + model = tf.keras.models.Model(inputs=input_img, outputs=finetuned) + + model.compile() + + return model diff --git a/image_prediction/redai_adapter/mlflow.py b/image_prediction/redai_adapter/mlflow.py new file mode 100644 index 0000000..7930ae3 --- /dev/null +++ 
b/image_prediction/redai_adapter/mlflow.py @@ -0,0 +1,72 @@ +import importlib +import json +import os +from functools import lru_cache + +import mlflow + +from image_prediction.redai_adapter.model import PredictionModelHandle + + +class MlflowModelReader: + def __init__(self, mlruns_dir=None): + self.mlruns_dir = mlruns_dir + mlflow.set_tracking_uri(self.mlruns_dir) + + @staticmethod + def __correct_artifact_uri(run_artifact_uri, base_path): + _, suffix = run_artifact_uri.split("mlruns/") + return os.path.join(base_path, suffix) + + def __get_weights_path(self, run_id, prefix="tt"): + run = self.__get_run(run_id) + + artifact_uri = self.__correct_artifact_uri(run.info.to_proto().artifact_uri, self.mlruns_dir) + path = os.path.join(artifact_uri, prefix, "train_dev", "estimator") + + base_path = os.path.join(path, "base_weights.h5") + weights_path = os.path.join(path, "weights.h5") + + return base_path, weights_path + + @lru_cache(maxsize=None) + def __get_run(self, run_id): + return mlflow.get_run(run_id) + + def __get_classes(self, run_id, prefix="tt"): + run = self.__get_run(run_id) + + classes = json.loads(run.data.params[os.path.join(prefix, "train_dev/estimator/classes")].replace("'", '"')) + + return classes + + def __get_model_handle(self, run_id): + run = self.__get_run(run_id) + + model_handle_builder = load_object(run.data.params["model_handle_builder"].strip()) + + base_weights_path, weights_path = self.__get_weights_path(run_id) + + model_handle = model_handle_builder( + self.__get_classes(run_id), base_weights_path=base_weights_path, weights_path=weights_path + ) + + return model_handle + + def __get_model(self, run_id) -> PredictionModelHandle: + model_handle = self.__get_model_handle(run_id) + model = PredictionModelHandle(model_handle) + return model + + def __getitem__(self, run_id): + return {"model": self.__get_model(run_id), "classes": self.__get_classes(run_id)} + + +def load_object(object_path): + path_fragments = object_path.split(".") + + 
module_path = ".".join(path_fragments[:-1]) + object_name = path_fragments[-1] + + module = importlib.import_module(module_path) + return getattr(module, object_name) diff --git a/image_prediction/redai_adapter/model.py b/image_prediction/redai_adapter/model.py new file mode 100644 index 0000000..c5450f5 --- /dev/null +++ b/image_prediction/redai_adapter/model.py @@ -0,0 +1,19 @@ +from funcy import rcompose + +from image_prediction.utils import get_logger + +logger = get_logger() + + +class PredictionModelHandle: + """Simplifies usage of ModelHandle instances for prediction purposes.""" + + def __init__(self, model_handle): + self.__predict = rcompose(model_handle.prep_images, model_handle.model.predict) + + def predict(self, *args, **kwargs): + return self.__predict(*args, **kwargs) + + def __call__(self, *args, **kwargs): + logger.debug("PredictionModelHandle.predict") + return self.predict(*args, **kwargs) diff --git a/image_prediction/redai_adapter/model_wrapper.py b/image_prediction/redai_adapter/model_wrapper.py new file mode 100644 index 0000000..776931e --- /dev/null +++ b/image_prediction/redai_adapter/model_wrapper.py @@ -0,0 +1,42 @@ +import abc + +import numpy as np +import tensorflow as tf + + +class ModelWrapper(abc.ABC): + def __init__(self, classes, base_weights_path=None, weights_path=None): + self.__classes = classes + self.model = self.__build(base_weights_path) + self.model.load_weights(weights_path) + + @property + @abc.abstractmethod + def input_shape(self): + raise NotImplementedError + + @property + def classes(self): + return self.__classes + + @abc.abstractmethod + def __preprocess_tensor(self, tensor): + raise NotImplementedError + + @staticmethod + def __images_to_tensor(images): + return np.array(list(map(tf.keras.preprocessing.image.img_to_array, images))) + + def __resize_and_convert(self, image): + return image.resize(self.input_shape[:-1]).convert("RGB") + + def prep_images(self, images): + images = map(self.__resize_and_convert, 
images) + tensor = self.__images_to_tensor(images) + tensor = self.__preprocess_tensor(tensor) + + return tensor + + @abc.abstractmethod + def __build(self, base_weights=None) -> tf.keras.models.Model: + raise NotImplementedError diff --git a/image_prediction/stitching/__init__.py b/image_prediction/stitching/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/stitching/grouping.py b/image_prediction/stitching/grouping.py new file mode 100644 index 0000000..0ff9283 --- /dev/null +++ b/image_prediction/stitching/grouping.py @@ -0,0 +1,63 @@ +from functools import lru_cache +from itertools import groupby + +import numpy as np +from funcy import compose, second + +from image_prediction.stitching.utils import make_coord_getter + + +class CoordGrouper: + def __init__(self, axis, tolerance=0): + self.c1_getter = make_coord_getter(f"{other_axis(axis)}1") + self.c2_getter = make_coord_getter(f"{other_axis(axis)}2") + self.tolerance = tolerance + + def group_pairs_by_lesser_coordinate(self, pairs): + return group_by_coordinate(pairs, self.c1_getter, self.tolerance) + + def group_pairs_by_greater_coordinate(self, pairs): + return group_by_coordinate(pairs, self.c2_getter, self.tolerance) + + +def other_axis(axis): + return "y" if axis == "x" else "x" + + +def fuzzify(func, tolerance): + def inner(item): + nonlocal mid_points + nonlocal lower_bounds + nonlocal upper_bounds + + value = func(item) + fits = (array(lower_bounds_array()) <= value) & (value <= array(upper_bounds_array())) + if any(fits): + return mid_points[np.argmax(fits)] + else: + mid_points = [*mid_points, value] + lower_bounds = [*lower_bounds, value - tolerance] + upper_bounds = [*upper_bounds, value + tolerance] + return value + + def lower_bounds_array(): + return tuple(lower_bounds) + + def upper_bounds_array(): + return tuple(upper_bounds) + + @lru_cache(maxsize=None) + def array(tpl): + return np.array(tpl) + + lower_bounds = [] + upper_bounds = [] + mid_points = [] + + 
return inner + + +def group_by_coordinate(pairs, coord_getter, tolerance=0): + coord_getter = fuzzify(coord_getter, tolerance) + pairs = sorted(pairs, key=coord_getter) + return map(compose(list, second), groupby(pairs, coord_getter)) diff --git a/image_prediction/stitching/merging.py b/image_prediction/stitching/merging.py new file mode 100644 index 0000000..2c3fbc4 --- /dev/null +++ b/image_prediction/stitching/merging.py @@ -0,0 +1,189 @@ +from copy import deepcopy +from functools import reduce +from typing import Iterable, Callable, List + +from PIL import Image +from funcy import juxt, first, rest, rcompose, rpartial, complement, ilen + +from image_prediction.image_extractor.extractor import ImageMetadataPair +from image_prediction.info import Info +from image_prediction.stitching.grouping import CoordGrouper +from image_prediction.stitching.split_mapper import HorizontalSplitMapper, VerticalSplitMapper +from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once, validate_box +from image_prediction.utils.generic import until + + +def make_merger_sentinel(): + def no_new_mergers(pairs): + nonlocal number_of_pairs_so_far + + number_of_pairs_now = len(pairs) + + if number_of_pairs_now == number_of_pairs_so_far: + return True + + else: + number_of_pairs_so_far = number_of_pairs_now + return False + + number_of_pairs_so_far = -1 + + return no_new_mergers + + +def merge_along_both_axes(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]: + pairs = merge_along_axis(pairs, "x", tolerance=tolerance) + pairs = list(merge_along_axis(pairs, "y", tolerance=tolerance)) + + return pairs + + +def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis, tolerance=0) -> Iterable[ImageMetadataPair]: + """Partially merges image-metadata pairs of adjacent images along a given axis. Needs to be iterated with + alternating axes until no more merges happen to merge all adjacent images. 
+ + Explanation: + + Merging algorithm works as follows: + A dot represents a pair, a bracket a group and a colon a merged pair. + 1) Start with pairs: (........) + 2) Align on lesser: ([....] [....]) + 3) Align on greater: ([[..] [..]] [[....]]) + 4) Flatten once: ([..] [..] [....]) + 5) Merge orthogonally: ([:] [..] [:..]) + 6) Flatten once: (:..:..) + """ + + def group_pairs_within_groups_by_greater_coordinate(groups): + return map(CoordGrouper(axis, tolerance=tolerance).group_pairs_by_greater_coordinate, groups) + + def merge_groups_along_orthogonal_axis(groups): + return map(rpartial(make_group_merger(axis), tolerance), groups) + + def group_pairs_by_lesser_coordinate(pairs): + return CoordGrouper(axis, tolerance=tolerance).group_pairs_by_lesser_coordinate(pairs) + + return rcompose( + group_pairs_by_lesser_coordinate, + group_pairs_within_groups_by_greater_coordinate, + flatten_groups_once, + merge_groups_along_orthogonal_axis, + flatten_groups_once, + )(pairs) + + +def make_group_merger(axis): + return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis] + + +def merge_group_vertically(group: Iterable[ImageMetadataPair], tolerance=0): + return merge_group(group, "y", tolerance=tolerance) + + +def merge_group_horizontally(group: Iterable[ImageMetadataPair], tolerance=0): + return merge_group(group, "x", tolerance=tolerance) + + +def merge_group(group: Iterable[ImageMetadataPair], direction, tolerance=0): + reduce_group = make_merger_aggregator(direction, tolerance=tolerance) + no_new_mergers = make_merger_sentinel() + return until(no_new_mergers, reduce_group, group) + + +def make_merger_aggregator(axis, tolerance=0) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]: + """Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the + head H and aggregates non-adjacent in the tail T. 
+ + Note: + When tolerance > 0, the bounding box of the merged image no longer matches the bounding box of the merged + metadata. This is intended behaviour, but might not be expected by the caller. + """ + + def merger_aggregator(pairs: Iterable[ImageMetadataPair]): + def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair): + """Keeps the image that is being merged with as the head and aggregates non-mergeables in the tail.""" + aggr, non_aggr = juxt(first, rest)(pairs_aggr) + if abs(c2_getter(aggr) - c1_getter(pair)) <= tolerance: + aggr = pair_merger(aggr, pair) + return aggr, *non_aggr + else: + return aggr, pair, *non_aggr + + # Requires H to be the least element in image-concatenation direction by c1, since the concatenation happens + # only in c1 -> c2 direction. + pairs = sorted(pairs, key=c1_getter) + head_pair, pairs = juxt(first, rest)(pairs) + return list(reduce(merge_on_head_and_aggregate_in_tail, pairs, [head_pair])) + + assert tolerance >= 0 + + c1_getter = make_coord_getter(f"{axis}1") + c2_getter = make_coord_getter(f"{axis}2") + pair_merger = make_pair_merger(axis) + + return merger_aggregator + + + def make_pair_merger(axis): + return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis] + + + def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair): + metadata_merged = merge_metadata_vertically(p1.metadata, p2.metadata) + image_concatenated = concat_images_vertically(p1.image, p2.image, metadata_merged) + return ImageMetadataPair(image_concatenated, metadata_merged) + + + def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair): + metadata_merged = merge_metadata_horizontally(p1.metadata, p2.metadata) + image_concatenated = concat_images_horizontally(p1.image, p2.image, metadata_merged) + return ImageMetadataPair(image_concatenated, metadata_merged) + + + def merge_metadata_vertically(m1: dict, m2: dict): + m1, m2 = map(VerticalSplitMapper, [m1, m2]) + 
return merge_metadata(m1, m2) + + +def merge_metadata_horizontally(m1: dict, m2: dict): + m1, m2 = map(HorizontalSplitMapper, [m1, m2]) + return merge_metadata(m1, m2) + + +def merge_metadata(m1: dict, m2: dict): + + c1 = min(m1.c1, m2.c1) + c2 = max(m1.c2, m2.c2) + dim = abs(c2 - c1) + + merged = deepcopy(m1) + merged.dim = dim + merged.c1 = c1 + merged.c2 = c2 + + validate_box(merged.wrapped) + + return merged.wrapped + + +def concat_images_vertically(im1: Image, im2: Image, metadata: dict): + return concat_images(im1, im2, metadata, 1) + + +def concat_images_horizontally(im1: Image, im2: Image, metadata: dict): + return concat_images(im1, im2, metadata, 0) + + +def concat_images(im1: Image, im2: Image, metadata: dict, axis): + + im_aggr = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT])) + + images = [im1, im2] + + offsets = 0, im1.size[axis], im_aggr.size[axis] - im2.size[axis] + + for im, offset in zip(images, offsets): + box = (offset, 0) if not axis else (0, offset) + im_aggr.paste(im, box=box) + + return im_aggr diff --git a/image_prediction/stitching/split_mapper.py b/image_prediction/stitching/split_mapper.py new file mode 100644 index 0000000..2863fa8 --- /dev/null +++ b/image_prediction/stitching/split_mapper.py @@ -0,0 +1,40 @@ +from copy import deepcopy +from dataclasses import field, dataclass +from operator import attrgetter + +from image_prediction.info import Info + + +@dataclass +class SplitMapper: + """Manages access into a mapping M by indirection through a specified access mapping to achieve a common + interface between various M_i. 
+ """ + + __access_mapping: dict + wrapped: dict + __wrapped: dict = field(init=False) + + def __post_init__(self): + for k, v in self.__access_mapping.items(): + setattr(self, k, self.__wrapped[v]) + + @property + def wrapped(self): + ret = deepcopy(self.__wrapped) + ret.update(dict(zip(self.__access_mapping.values(), attrgetter(*self.__access_mapping.keys())(self)))) + return ret + + @wrapped.setter + def wrapped(self, wrapped): + self.__wrapped = wrapped + + +class HorizontalSplitMapper(SplitMapper): + def __init__(self, wrapped: dict): + super().__init__({"dim": Info.WIDTH, "c1": Info.X1, "c2": Info.X2}, wrapped) + + +class VerticalSplitMapper(SplitMapper): + def __init__(self, wrapped: dict): + super().__init__({"dim": Info.HEIGHT, "c1": Info.Y1, "c2": Info.Y2}, wrapped) diff --git a/image_prediction/stitching/stitching.py b/image_prediction/stitching/stitching.py new file mode 100644 index 0000000..9d98bd3 --- /dev/null +++ b/image_prediction/stitching/stitching.py @@ -0,0 +1,15 @@ +from typing import Iterable, List + +from funcy import rpartial + +from image_prediction.image_extractor.extractor import ImageMetadataPair +from image_prediction.stitching.merging import merge_along_both_axes, make_merger_sentinel +from image_prediction.utils.generic import until + + +def stitch_pairs(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]: + """Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent + images.""" + no_new_mergers = make_merger_sentinel() + merge = rpartial(merge_along_both_axes, tolerance) + return until(no_new_mergers, merge, pairs) diff --git a/image_prediction/stitching/utils.py b/image_prediction/stitching/utils.py new file mode 100644 index 0000000..e5bed7b --- /dev/null +++ b/image_prediction/stitching/utils.py @@ -0,0 +1,67 @@ +import json +from itertools import chain + +from image_prediction.exceptions import InvalidBox +from image_prediction.formatter.formatters.enum 
import EnumFormatter + from image_prediction.info import Info + + + def flatten_groups_once(groups): + return chain.from_iterable(groups) + + + def make_coord_getter(c): + return { + "x1": make_getter(Info.X1), + "x2": make_getter(Info.X2), + "y1": make_getter(Info.Y1), + "y2": make_getter(Info.Y2), + }[c] + + + def make_getter(key): + def getter(pair): + return pair.metadata[key] + + return getter + + + def make_length_getter(dim): + return { + "width": make_getter(Info.WIDTH), + "height": make_getter(Info.HEIGHT), + }[dim] + + + def validate_box(box): + validate_box_coords(box) + validate_box_size(box) + return box + + + def validate_box_coords(box): + + x_diff = box[Info.WIDTH] - (box[Info.X2] - box[Info.X1]) + y_diff = box[Info.HEIGHT] - (box[Info.Y2] - box[Info.Y1]) + + if x_diff: + raise InvalidBox(f"Width and x-coordinates differ by {x_diff} units: {format_box(box)}") + if y_diff: + raise InvalidBox(f"Height and y-coordinates differ by {y_diff} units: {format_box(box)}") + + return box + + + def validate_box_size(box): + + if not box[Info.WIDTH]: + raise InvalidBox(f"Zero width box: {format_box(box)}") + + if not box[Info.HEIGHT]: + raise InvalidBox(f"Zero height box: {format_box(box)}") + + return box + + + def format_box(box): + return json.dumps(EnumFormatter()(box), indent=2) diff --git a/image_prediction/transformer/__init__.py b/image_prediction/transformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/transformer/transformer.py b/image_prediction/transformer/transformer.py new file mode 100644 index 0000000..e2dea6f --- /dev/null +++ b/image_prediction/transformer/transformer.py @@ -0,0 +1,20 @@ +import abc +from typing import Iterable + +from funcy import curry, identity + + +class Transformer(abc.ABC): + @abc.abstractmethod + def transform(self, obj): + raise NotImplementedError + + def __call__(self, obj): + return self._apply(self.transform, obj) + + @staticmethod + def _must_be_mapped_over(obj): + return 
isinstance(obj, Iterable) and not isinstance(obj, dict) + + def _apply(self, func, obj): + return (curry(map) if self._must_be_mapped_over(obj) else identity)(func)(obj) diff --git a/image_prediction/transformer/transformers/__init__.py b/image_prediction/transformer/transformers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/transformer/transformers/coordinate/__init__.py b/image_prediction/transformer/transformers/coordinate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/image_prediction/transformer/transformers/coordinate/coordinate_transformer.py b/image_prediction/transformer/transformers/coordinate/coordinate_transformer.py new file mode 100644 index 0000000..d72fc2a --- /dev/null +++ b/image_prediction/transformer/transformers/coordinate/coordinate_transformer.py @@ -0,0 +1,22 @@ +import abc + +from image_prediction.transformer.transformer import Transformer + + +class CoordinateTransformer(Transformer): + @abc.abstractmethod + def _forward(self, metadata): + raise NotImplementedError + + @abc.abstractmethod + def _backward(self, metadata): + raise NotImplementedError + + def forward(self, metadata): + return self._apply(self._forward, metadata) + + def backward(self, metadata): + return self._apply(self._backward, metadata) + + def transform(self, metadata): + return self.forward(metadata) diff --git a/image_prediction/transformer/transformers/coordinate/fitz.py b/image_prediction/transformer/transformers/coordinate/fitz.py new file mode 100644 index 0000000..87b854e --- /dev/null +++ b/image_prediction/transformer/transformers/coordinate/fitz.py @@ -0,0 +1,10 @@ +from image_prediction.transformer.transformers.coordinate.coordinate_transformer import CoordinateTransformer + + +class FitzCoordinateTransformer(CoordinateTransformer): + def _forward(self, metadata: dict): + """Fitz uses top left corner as origin; we take this as the reference coordinate system.""" + return metadata + + def 
_backward(self, metadata: dict): + return self.forward(metadata) diff --git a/image_prediction/transformer/transformers/coordinate/fpdf.py b/image_prediction/transformer/transformers/coordinate/fpdf.py new file mode 100644 index 0000000..7c5fad6 --- /dev/null +++ b/image_prediction/transformer/transformers/coordinate/fpdf.py @@ -0,0 +1,10 @@ +from image_prediction.transformer.transformers.coordinate.coordinate_transformer import CoordinateTransformer + + +class FPDFCoordinateTransformer(CoordinateTransformer): + def _forward(self, metadata: dict): + """FPDF uses top left corner as origin; we take this as the reference coordinate system.""" + return metadata + + def _backward(self, metadata: dict): + return self.forward(metadata) diff --git a/image_prediction/transformer/transformers/coordinate/pdfnet.py b/image_prediction/transformer/transformers/coordinate/pdfnet.py new file mode 100644 index 0000000..72dfeef --- /dev/null +++ b/image_prediction/transformer/transformers/coordinate/pdfnet.py @@ -0,0 +1,18 @@ +from operator import itemgetter + +from funcy import omit + +from image_prediction.info import Info +from image_prediction.transformer.transformers.coordinate.coordinate_transformer import CoordinateTransformer + + +class PDFNetCoordinateTransformer(CoordinateTransformer): + def _forward(self, metadata: dict): + """PDFNet coordinate system origin is in the bottom left corner.""" + y1, y2, page_height = itemgetter(Info.Y1, Info.Y2, Info.PAGE_HEIGHT)(metadata) + y1_t = page_height - y2 + y2_t = page_height - y1 + return {**omit(metadata, [Info.Y1, Info.Y2]), **{Info.Y1: y1_t, Info.Y2: y2_t}} + + def _backward(self, metadata: dict): + return self.forward(metadata) diff --git a/image_prediction/response.py b/image_prediction/transformer/transformers/response.py similarity index 68% rename from image_prediction/response.py rename to image_prediction/transformer/transformers/response.py index b5cdb7a..ca8ce99 100644 --- a/image_prediction/response.py +++ 
b/image_prediction/transformer/transformers/response.py @@ -1,28 +1,30 @@ -"""Defines functions for constructing service responses.""" - - import math -from itertools import starmap from operator import itemgetter from image_prediction.config import CONFIG +from image_prediction.transformer.transformer import Transformer +from image_prediction.utils import get_logger + +logger = get_logger() -def build_response(predictions: list, metadata: list) -> list: - return list(starmap(build_image_info, zip(predictions, metadata))) +class ResponseTransformer(Transformer): + def transform(self, data): + logger.debug("ResponseTransformer.transform") + return build_image_info(data) -def build_image_info(prediction: dict, metadata: dict) -> dict: +def build_image_info(data: dict) -> dict: def compute_geometric_quotient(): page_area_sqrt = math.sqrt(abs(page_width * page_height)) image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1)) return image_area_sqrt / page_area_sqrt - page_width, page_height, x1, x2, y1, y2, width, height = itemgetter( - "page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height" - )(metadata) + page_width, page_height, x1, x2, y1, y2, width, height, alpha = itemgetter( + "page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height", "alpha" + )(data) - quotient = compute_geometric_quotient() + quotient = round(compute_geometric_quotient(), 4) min_image_to_page_quotient_breached = bool(quotient < CONFIG.filters.image_to_page_quotient.min) max_image_to_page_quotient_breached = bool(quotient > CONFIG.filters.image_to_page_quotient.max) @@ -33,14 +35,15 @@ def build_image_info(prediction: dict, metadata: dict) -> dict: width / height > CONFIG.filters.image_width_to_height_quotient.max ) - min_confidence_breached = bool(max(prediction["probabilities"].values()) < CONFIG.filters.min_confidence) - prediction["label"] = prediction.pop("class") # "class" as field name causes problem for Java objectmapper - prediction["probabilities"] = 
{klass: round(prob, 6) for klass, prob in prediction["probabilities"].items()} + classification = data["classification"] + + min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence) image_info = { - "classification": prediction, - "position": {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": metadata["page_idx"] + 1}, + "classification": classification, + "position": {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": data["page_idx"] + 1}, "geometry": {"width": width, "height": height}, + "alpha": alpha, "filters": { "geometry": { "imageSize": { @@ -49,7 +52,7 @@ def build_image_info(prediction: dict, metadata: dict) -> dict: "tooSmall": min_image_to_page_quotient_breached, }, "imageFormat": { - "quotient": width / height, + "quotient": round(width / height, 4), "tooTall": min_image_width_to_height_quotient_breached, "tooWide": max_image_width_to_height_quotient_breached, }, diff --git a/image_prediction/utils.py b/image_prediction/utils.py index 15badca..b28b04f 100644 --- a/image_prediction/utils.py +++ b/image_prediction/utils.py @@ -1,68 +1,3 @@ -import logging -import tempfile -from contextlib import contextmanager - -from image_prediction.config import CONFIG -@contextmanager -def temporary_pdf_file(pdf: bytes): - with tempfile.NamedTemporaryFile() as f: - f.write(pdf) - yield f.name - -def make_logger_getter(): - - logger = logging.getLogger("imclf") - logger.propagate = False - - handler = logging.StreamHandler() - handler.setLevel(CONFIG.service.logging_level) - - log_format = "[%(levelname)s]: %(message)s" - formatter = logging.Formatter(log_format) - - handler.setFormatter(formatter) - logger.addHandler(handler) - - def get_logger(): - return logger - - return get_logger - - -get_logger = make_logger_getter() - - -def show_banner(): - banner = ''' - ..... . ... .. - .d88888Neu. 'L xH88"`~ .x8X x .d88" oec : - F""""*8888888F .. . : :8888 .f"8888Hf 5888R @88888 - * `"*88*" .888: x888 x888. 
:8888> X8L ^""` '888R 8"*88% - -.... ue=:. ~`8888~'888X`?888f` X8888 X888h 888R 8b. - :88N ` X888 888X '888> 88888 !88888. 888R u888888> - 9888L X888 888X '888> 88888 %88888 888R 8888R - uzu. `8888L X888 888X '888> 88888 '> `8888> 888R 8888P -,""888i ?8888 X888 888X '888> `8888L % ?888 ! 888R *888> -4 9888L %888> "*88%""*88" '888!` `8888 `-*"" / .888B . 4888 -' '8888 '88% `~ " `"` "888. :" ^*888% '888 - "*8Nu.z*" `""***~"` "% 88R - 88> - 48 - '8 - ''' - - logger = logging.getLogger(__name__) - logger.propagate = False - - handler = logging.StreamHandler() - handler.setLevel(logging.INFO) - - formatter = logging.Formatter("") - - handler.setFormatter(formatter) - logger.addHandler(handler) - - logger.info(banner) diff --git a/image_prediction/utils/__init__.py b/image_prediction/utils/__init__.py new file mode 100644 index 0000000..e374e89 --- /dev/null +++ b/image_prediction/utils/__init__.py @@ -0,0 +1 @@ +from .logger import get_logger diff --git a/image_prediction/utils/banner.py b/image_prediction/utils/banner.py new file mode 100644 index 0000000..6a17d93 --- /dev/null +++ b/image_prediction/utils/banner.py @@ -0,0 +1,21 @@ +import logging + +from image_prediction.locations import BANNER_FILE + + +def show_banner(): + with open(BANNER_FILE) as f: + banner = "\n" + "".join(f.readlines()) + "\n" + + logger = logging.getLogger(__name__) + logger.propagate = False + + handler = logging.StreamHandler() + handler.setLevel(logging.INFO) + + formatter = logging.Formatter("") + + handler.setFormatter(formatter) + logger.addHandler(handler) + + logger.info(banner) diff --git a/image_prediction/utils/generic.py b/image_prediction/utils/generic.py new file mode 100644 index 0000000..de71a5c --- /dev/null +++ b/image_prediction/utils/generic.py @@ -0,0 +1,15 @@ +from itertools import starmap + +from funcy import iterate, first, curry, map + + +def until(cond, func, *args, **kwargs): + return first(filter(cond, iterate(func, *args, **kwargs))) + + +def lift(fn): + return 
curry(map)(fn) + + +def starlift(fn): + return curry(starmap)(fn) diff --git a/image_prediction/utils/logger.py b/image_prediction/utils/logger.py new file mode 100644 index 0000000..58f6022 --- /dev/null +++ b/image_prediction/utils/logger.py @@ -0,0 +1,27 @@ +import logging + +from image_prediction.config import CONFIG + + +def make_logger_getter(): + logger = logging.getLogger("imclf") + logger.propagate = False + + handler = logging.StreamHandler() + handler.setLevel(CONFIG.service.logging_level) + + log_format = "%(asctime)s %(levelname)-8s %(message)s" + formatter = logging.Formatter(log_format, datefmt="%Y-%m-%d %H:%M:%S") + + handler.setFormatter(formatter) + logger.addHandler(handler) + + logger.setLevel(CONFIG.service.logging_level) + + def get_logger(): + return logger + + return get_logger + + +get_logger = make_logger_getter() diff --git a/image_prediction/utils/pdf_annotation.py b/image_prediction/utils/pdf_annotation.py new file mode 100644 index 0000000..43b8b12 --- /dev/null +++ b/image_prediction/utils/pdf_annotation.py @@ -0,0 +1,99 @@ +"""Defines utilities for PDF processing.""" + +import json +from operator import itemgetter + +from PDFNetPython3.PDFNetPython import ( + PDFDoc, + PDFNet, + Square, + Rect, + ColorPt, + BorderStyle, + SDFDoc, + Point, + Text, +) + +from image_prediction.utils import get_logger + +logger = get_logger() + + +def annotate_image(doc, image_info): + def draw_box(): + sq = Square.Create(doc.GetSDFDoc(), Rect(*coords)) + sq.SetColor(ColorPt(*color), 3) + sq.SetBorderStyle(BorderStyle(BorderStyle.e_dashed, 2, 0, 0, [4, 2])) + sq.SetPadding(4) + sq.RefreshAppearance() + page.AnnotPushBack(sq) + + def add_note(): + txt = Text.Create(doc.GetSDFDoc(), Point(*coords[:2])) + txt.SetContents(json.dumps(image_info, indent=2, ensure_ascii=False)) + txt.SetColor(ColorPt(*color)) + page.AnnotPushBack(txt) + txt.RefreshAppearance() + + red = (1, 0, 0) + green = (0, 1, 0) + blue = (0, 0, 1) + + if image_info["filters"]["allPassed"]: 
+ color = green + elif image_info["filters"]["probability"]["unconfident"]: + color = red + else: + color = blue + + page = doc.GetPage(image_info["position"]["pageNumber"]) + coords = itemgetter("x1", "y1", "x2", "y2")(image_info["position"]) + + draw_box() + add_note() + + +def init(): + PDFNet.Initialize( + "Knecon AG(en.knecon.swiss):OEM:DDA-R::WL+:AMS(20211029):BECC974307DAB4F34B513BC9B2531B24496F6FCB83CD8AC574358A959730B622FABEF5C7" + ) + + +def draw_metadata_box(pdf_path, metadata, store_path): + + init() + + doc = PDFDoc(pdf_path) + + color = (1, 0, 0) + + print(metadata) + + coords = itemgetter("x1", "y1", "x2", "y2")(metadata) + page = doc.GetPage(1) + + sq = Square.Create(doc.GetSDFDoc(), Rect(*coords)) + sq.SetColor(ColorPt(*color), 3) + sq.SetBorderStyle(BorderStyle(BorderStyle.e_dashed, 2, 0, 0, [4, 2])) + sq.SetPadding(4) + sq.RefreshAppearance() + page.AnnotPushBack(sq) + + doc.Save(store_path, SDFDoc.e_linearized) + + logger.info(f"Saved annotated PDF to {store_path}") + + +def annotate_pdf(pdf_path, responses, store_path): + + init() + + doc = PDFDoc(pdf_path) + + for image_info in responses: + annotate_image(doc, image_info) + + doc.Save(store_path, SDFDoc.e_linearized) + + logger.info(f"Saved annotated PDF to {store_path}") diff --git a/incl/redai_image b/incl/redai_image deleted file mode 160000 index 4c3b26d..0000000 --- a/incl/redai_image +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 4c3b26d7673457aaa99e0663dad6950cd36da967 diff --git a/pytest.ini b/pytest.ini index 5922a79..323d825 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,5 @@ [pytest] -norecursedirs = incl \ No newline at end of file +norecursedirs = incl +filterwarnings = + ignore:.*:DeprecationWarning + ignore:.*:DeprecationWarning diff --git a/requirements.txt b/requirements.txt index 217a846..8b33526 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,23 +1,23 @@ -Flask==2.0.2 +Flask==2.1.1 requests==2.27.1 iteration-utilities==0.11.0 -dvc==2.9.3 +dvc==2.10.0 dvc[ssh] 
-frozendict==2.3.0 -waitress==2.0.0 -envyaml~=1.8.210417 +waitress==2.1.1 +envyaml==1.10.211231 dependency-check==0.6.* -envyaml~=1.8.210417 -mlflow~=1.20.2 -numpy~=1.19.3 -PDFNetPython3~=9.1.0 -tqdm~=4.62.2 -pandas~=1.3.1 -mlflow~=1.20.2 -tensorflow~=2.5.0 -PDFNetPython3~=9.1.0 -Pillow~=8.3.2 -PyYAML~=5.4.1 -scikit_learn~=0.24.2 - -pytest~=7.1.0 \ No newline at end of file +mlflow==1.24.0 +numpy==1.22.3 +tqdm==4.64.0 +pandas==1.4.2 +tensorflow==2.8.0 +PyYAML==6.0 +pytest~=7.1.0 +funcy==1.17 +PyMuPDF==1.19.6 +fpdf==1.7.2 +coverage==6.3.2 +Pillow==9.1.0 +PDFNetPython3==9.1.0 +pdf2image==1.16.0 +frozendict==2.3.0 diff --git a/run_tests.sh b/run_tests.sh new file mode 100755 index 0000000..655b8bb --- /dev/null +++ b/run_tests.sh @@ -0,0 +1,15 @@ +echo "${bamboo_nexus_password}" | docker login --username "${bamboo_nexus_user}" --password-stdin nexus.iqser.com:5001 + +pip install dvc +pip install 'dvc[ssh]' +echo "Pulling dvc data" +dvc pull + +docker build -f Dockerfile_tests -t image-prediction-tests . 
+ +rnd=$(date +"%s") +name=image-prediction-tests-${rnd} + +echo "running tests container" + +docker run --rm --name $name -v $PWD:$PWD -w $PWD -v /var/run/docker.sock:/var/run/docker.sock image-prediction-tests diff --git a/scripts/keras_MnWE.py b/scripts/keras_MnWE.py index 05a45dd..88b920a 100644 --- a/scripts/keras_MnWE.py +++ b/scripts/keras_MnWE.py @@ -40,7 +40,7 @@ def make_predict_fn(): model = make_model() def predict(*args): - # model = make_model() + # service_estimator = make_model() return model.predict(np.random.random(size=(1, 784))) return predict diff --git a/scripts/pyinfra_mock.py b/scripts/pyinfra_mock.py index fec12e9..07fddec 100644 --- a/scripts/pyinfra_mock.py +++ b/scripts/pyinfra_mock.py @@ -6,7 +6,7 @@ import requests def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--pdf_path", required=True) + parser.add_argument("pdf_path") args = parser.parse_args() return args diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py new file mode 100644 index 0000000..c2b4bb0 --- /dev/null +++ b/scripts/run_pipeline.py @@ -0,0 +1,55 @@ +import argparse +import json +import os +from glob import glob + +from image_prediction.pipeline import load_pipeline +from image_prediction.utils import get_logger +from image_prediction.utils.pdf_annotation import annotate_pdf + +logger = get_logger() + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("input", help="pdf file or directory") + parser.add_argument("--print", "-p", help="print output to terminal", action="store_true", default=False) + parser.add_argument("--page_interval", "-i", help="page interval [i, j), min index = 0", nargs=2, type=int) + + args = parser.parse_args() + + return args + + +def process_pdf(pipeline, pdf_path, page_range=None): + with open(pdf_path, "rb") as f: + logger.info(f"Processing {pdf_path}") + predictions = list(pipeline(f.read(), page_range=page_range)) + + annotate_pdf( + pdf_path, predictions, 
os.path.join("/tmp", os.path.basename(pdf_path.replace(".pdf", "_annotated.pdf"))) + ) + + return predictions + + +def main(args): + pipeline = load_pipeline(verbose=True, tolerance=3) + + if os.path.isfile(args.input): + pdf_paths = [args.input] + else: + pdf_paths = glob(os.path.join(args.input, "*.pdf")) + page_range = range(*args.page_interval) if args.page_interval else None + + for pdf_path in pdf_paths: + predictions = process_pdf(pipeline, pdf_path, page_range=page_range) + if args.print: + print(pdf_path) + print(json.dumps(predictions, indent=2)) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/setup/docker.sh b/setup/docker.sh deleted file mode 100755 index 7b4a837..0000000 --- a/setup/docker.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash -set -e - -python3 -m venv build_venv -source build_venv/bin/activate -python3 -m pip install --upgrade pip - -pip install dvc -pip install 'dvc[ssh]' -dvc pull - -git submodule update --init --recursive - -docker build -f Dockerfile_base -t image-prediction-base . -docker build -f Dockerfile -t image-prediction . 
diff --git a/src/serve.py b/src/serve.py index 666ca80..005cbb2 100644 --- a/src/serve.py +++ b/src/serve.py @@ -4,45 +4,29 @@ from waitress import serve from image_prediction.config import CONFIG from image_prediction.flask import make_prediction_server -from image_prediction.predictor import Predictor -from image_prediction.response import build_response -from image_prediction.utils import get_logger, show_banner - -logger = get_logger() +from image_prediction.pipeline import load_pipeline +from image_prediction.utils import get_logger +from image_prediction.utils.banner import show_banner def main(): - def predict(pdf): - # Keras model.predict stalls when model was loaded in different process - # https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python - predictor = Predictor() - predictions, metadata = predictor.predict_pdf(pdf, verbose=CONFIG.service.progressbar) - response = build_response(predictions, metadata) - return response + logger = get_logger() - logger.info("Predictor ready.") + def predict(pdf): + # Keras service_estimator.predict stalls when service_estimator was loaded in different process; + # therefore, we re-load the model (part of the pipeline) every time we process a new document. 
+ # https://stackoverflow.com/questions/42504669/keras-tensorflow-and-multiprocessing-in-python + logger.debug("Loading pipeline...") + pipeline = load_pipeline(verbose=CONFIG.service.verbose, batch_size=CONFIG.service.batch_size) + logger.debug("Running pipeline...") + return list(pipeline(pdf)) prediction_server = make_prediction_server(predict) - - run_prediction_server(prediction_server, mode=CONFIG.webserver.mode) - - -def run_prediction_server(app, mode="development"): - if mode == "development": - app.run(host=CONFIG.webserver.host, port=CONFIG.webserver.port, debug=True) - elif mode == "production": - serve(app, host=CONFIG.webserver.host, port=CONFIG.webserver.port) + serve(prediction_server, host=CONFIG.webserver.host, port=CONFIG.webserver.port, _quiet=False) if __name__ == "__main__": - logging_level = CONFIG.service.logging_level - logging.basicConfig(level=logging_level) - logging.getLogger("flask").setLevel(logging.ERROR) - logging.getLogger("urllib3").setLevel(logging.ERROR) - logging.getLogger("werkzeug").setLevel(logging.ERROR) - logging.getLogger("waitress").setLevel(logging.ERROR) - logging.getLogger("PIL").setLevel(logging.ERROR) - logging.getLogger("h5py").setLevel(logging.ERROR) + logging.basicConfig(level=CONFIG.service.logging_level) show_banner() diff --git a/test/conftest.py b/test/conftest.py index 71b37d1..ee2b3d5 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,70 +1,29 @@ -import os.path +import logging import pytest -from image_prediction.predictor import Predictor +from image_prediction.utils import get_logger -@pytest.fixture -def predictions(): - return [ - { - "class": "signature", - "probabilities": { - "signature": 1.0, - "logo": 9.150285377746546e-19, - "other": 4.374506412383356e-19, - "formula": 3.582569597002796e-24, - }, - } - ] +pytest_plugins = [ + "test.fixtures.extractor", + "test.fixtures.image", + "test.fixtures.image_metadata_pair", + "test.fixtures.input", + "test.fixtures.label", + 
"test.fixtures.metadata", + "test.fixtures.model", + "test.fixtures.model_store", + "test.fixtures.parameters", + "test.fixtures.pdf", + "test.fixtures.target", +] -@pytest.fixture -def metadata(): - return [ - { - "page_height": 612.0, - "page_width": 792.0, - "height": 61.049999999999955, - "width": 139.35000000000002, - "page_idx": 8, - "x1": 63.5, - "x2": 202.85000000000002, - "y1": 472.0, - "y2": 533.05, - } - ] - - -@pytest.fixture -def response(): - return [ - { - "classification": { - "label": "signature", - "probabilities": {"formula": 0.0, "logo": 0.0, "other": 0.0, "signature": 1.0}, - }, - "filters": { - "allPassed": True, - "geometry": { - "imageFormat": {"quotient": 2.282555282555285, "tooTall": False, "tooWide": False}, - "imageSize": {"quotient": 0.13248234868245012, "tooLarge": False, "tooSmall": False}, - }, - "probability": {"unconfident": False}, - }, - "geometry": {"height": 61.049999999999955, "width": 139.35000000000002}, - "position": {"pageNumber": 9, "x1": 63.5, "x2": 202.85000000000002, "y1": 472.0, "y2": 533.05}, - } - ] - - -@pytest.fixture -def predictor(): - return Predictor() - - -@pytest.fixture -def test_pdf(): - with open("./test/test_data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf", "rb") as f: - return f.read() +@pytest.fixture(autouse=True) +def mute_logger(): + logger = get_logger() + level = logger.level + logger.setLevel(logging.CRITICAL + 1) + yield + logger.setLevel(level) diff --git a/test/test_data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf b/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf similarity index 100% rename from test/test_data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf rename to test/data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf diff --git a/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json b/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json new file mode 100644 index 0000000..a2171bb --- /dev/null +++ b/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json @@ -0,0 +1,43 @@ +[ + { + "classification": { + 
"label": "formula", + "probabilities": { + "formula": 1.0, + "logo": 0.0, + "other": 0.0, + "signature": 0.0 + } + }, + "position": { + "x1": 321, + "x2": 515, + "y1": 348, + "y2": 542, + "pageNumber": 2 + }, + "geometry": { + "width": 194, + "height": 194 + }, + "alpha": false, + "filters": { + "geometry": { + "imageSize": { + "quotient": 0.2741, + "tooLarge": false, + "tooSmall": false + }, + "imageFormat": { + "quotient": 1.0, + "tooTall": false, + "tooWide": false + } + }, + "probability": { + "unconfident": false + }, + "allPassed": true + } + } +] \ No newline at end of file diff --git a/test/data/stitching_with_tolerance.json b/test/data/stitching_with_tolerance.json new file mode 100644 index 0000000..f7f1049 --- /dev/null +++ b/test/data/stitching_with_tolerance.json @@ -0,0 +1,92 @@ +{ + "input": [ + { + "width": 100, + "height": 8, + "page_idx": 0, + "page_width": 100, + "page_height": 100, + "x1": 0, + "y1": 0, + "x2": 100, + "y2": 8 + }, + { + "width": 100, + "height": 9, + "page_idx": 0, + "page_width": 100, + "page_height": 100, + "x1": 0, + "y1": 9, + "x2": 100, + "y2": 18 + }, + { + "width": 100, + "height": 35, + "page_idx": 0, + "page_width": 100, + "page_height": 100, + "x1": 0, + "y1": 18, + "x2": 100, + "y2": 53 + }, + { + "width": 47, + "height": 46, + "page_idx": 0, + "page_width": 100, + "page_height": 100, + "x1": 0, + "y1": 54, + "x2": 47, + "y2": 100 + }, + { + "width": 31, + "height": 46, + "page_idx": 0, + "page_width": 100, + "page_height": 100, + "x1": 48, + "y1": 54, + "x2": 79, + "y2": 100 + }, + { + "width": 20, + "height": 19, + "page_idx": 0, + "page_width": 100, + "page_height": 100, + "x1": 80, + "y1": 54, + "x2": 100, + "y2": 73 + }, + { + "width": 20, + "height": 27, + "page_idx": 0, + "page_width": 100, + "page_height": 100, + "x1": 80, + "y1": 73, + "x2": 100, + "y2": 100 + } + ], + "target": { + "width": 100, + "height": 100, + "page_idx": 0, + "page_width": 100, + "page_height": 100, + "x1": 0, + "y1": 0, + "x2": 100, + 
"y2": 100 + } +} diff --git a/test/exploration_tests/funcy_test.py b/test/exploration_tests/funcy_test.py new file mode 100644 index 0000000..30c2cef --- /dev/null +++ b/test/exploration_tests/funcy_test.py @@ -0,0 +1,32 @@ +import pytest +from funcy import rcompose, chunks + + +def test_rcompose(): + f = rcompose(lambda x: x ** 2, str, lambda x: x * 2) + assert f(3) == "99" + + +def test_chunk_iterable_exact_split(): + a, b = chunks(5, iter(range(10))) + assert a == list(range(5)) + assert b == list(range(5, 10)) + + +def test_chunk_iterable_no_split(): + a = next(chunks(10, iter(range(10)))) + assert a == list(range(10)) + + +def test_chunk_iterable_last_partial(): + a, b, c, d = chunks(3, iter(range(10))) + assert d == [9] + + +def test_chunk_iterable_empty(): + with pytest.raises(StopIteration): + next(chunks(3, iter(range(0)))) + + +def test_chunk_iterable_less_than_chunk_size_elements(): + assert next(chunks(5, iter(range(2)))) == [0, 1] diff --git a/test/fixtures/__init__.py b/test/fixtures/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/fixtures/extractor.py b/test/fixtures/extractor.py new file mode 100644 index 0000000..8aa6db2 --- /dev/null +++ b/test/fixtures/extractor.py @@ -0,0 +1,17 @@ +import pytest + +from image_prediction.exceptions import UnknownImageExtractor +from image_prediction.image_extractor.extractors.mock import ImageExtractorMock +from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor + + +@pytest.fixture +def image_extractor(extractor_type): + if extractor_type == "mock": + return ImageExtractorMock() + elif extractor_type == "parsable_pdf": + return ParsablePDFImageExtractor() + elif extractor_type == "default": + return None + else: + raise UnknownImageExtractor(f"No image extractor for type {extractor_type} was specified.") diff --git a/test/fixtures/image.py b/test/fixtures/image.py new file mode 100644 index 0000000..fb1b7f3 --- /dev/null +++ b/test/fixtures/image.py @@ 
-0,0 +1,14 @@ +import pytest + +from test.utils.generation.image import array_to_image + + +@pytest.fixture +def images(input_batch): + return list(map(array_to_image, input_batch)) + + +@pytest.fixture +def input_size(alpha, __input_size): + w, h, d = __input_size + return w, h, d + alpha diff --git a/test/fixtures/image_metadata_pair.py b/test/fixtures/image_metadata_pair.py new file mode 100644 index 0000000..9b4b916 --- /dev/null +++ b/test/fixtures/image_metadata_pair.py @@ -0,0 +1,10 @@ +from itertools import starmap + +import pytest + +from image_prediction.image_extractor.extractor import ImageMetadataPair + + +@pytest.fixture +def image_metadata_pairs(images, metadata): + return list(starmap(ImageMetadataPair, zip(images, metadata))) diff --git a/test/fixtures/input.py b/test/fixtures/input.py new file mode 100644 index 0000000..b02f414 --- /dev/null +++ b/test/fixtures/input.py @@ -0,0 +1,7 @@ +import numpy as np +import pytest + + +@pytest.fixture +def input_batch(batch_size, input_size): + return np.random.random_sample(size=(batch_size, *input_size)) diff --git a/test/fixtures/label.py b/test/fixtures/label.py new file mode 100644 index 0000000..571ce23 --- /dev/null +++ b/test/fixtures/label.py @@ -0,0 +1,25 @@ +import pytest + +from image_prediction.exceptions import UnknownLabelFormat +from image_prediction.label_mapper.mappers.numeric import IndexMapper +from image_prediction.label_mapper.mappers.probability import ProbabilityMapper + + +@pytest.fixture +def label_mapper(label_format, classes): + if label_format == "index": + return IndexMapper(classes) + elif label_format == "probability": + return ProbabilityMapper(classes) + else: + raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.") + + +@pytest.fixture(params=["index"]) +def label_format(request): + return request.param + + +@pytest.fixture +def classes(): + return ["A", "B", "C"] diff --git a/test/fixtures/metadata.py b/test/fixtures/metadata.py new 
file mode 100644 index 0000000..8440e24 --- /dev/null +++ b/test/fixtures/metadata.py @@ -0,0 +1,56 @@ +import random + +import pytest +from funcy import merge + +from image_prediction.info import Info +from test.utils.metadata import get_base_position_metadata + + +@pytest.fixture +def metadata(images): + page_idx = 0 + + def current_page_idx(): + nonlocal page_idx + page_idx += random.randint(0, 3) + return min(page_idx, len(images) - 1) + + def build_image_metadata(image): + width, height = image.size + page_width = 595 + page_height = 842 + x1 = random.randint(0, page_width - width) + x2 = x1 + width + y1 = random.randint(0, page_height - height) + y2 = y1 + height + metadata = { + Info.PAGE_WIDTH: page_width, + Info.PAGE_HEIGHT: page_height, + Info.PAGE_IDX: current_page_idx(), + Info.WIDTH: width, + Info.HEIGHT: height, + Info.X1: x1, + Info.X2: x2, + Info.Y1: y1, + Info.Y2: y2, + Info.ALPHA: image.mode == "RGBA", + } + return metadata + + return list(map(build_image_metadata, images)) + + +@pytest.fixture +def metadata_formatted(metadata): + def format_metadata(metadata): + return {key.value: val for key, val in metadata.items()} + + return list(map(format_metadata, metadata)) + + +@pytest.fixture +def base_patch_metadata(width, height, page_width, page_height): + metadata = get_base_position_metadata(width, height, page_width, page_height) + metadata = merge(metadata, {Info.X1: 0, Info.Y1: 0, Info.X2: width, Info.Y2: height}) + return metadata diff --git a/test/fixtures/model.py b/test/fixtures/model.py new file mode 100644 index 0000000..729d234 --- /dev/null +++ b/test/fixtures/model.py @@ -0,0 +1,113 @@ +import pytest + +from image_prediction.classifier.classifier import Classifier +from image_prediction.classifier.image_classifier import ImageClassifier +from image_prediction.estimator.adapter.adapter import EstimatorAdapter +from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor +from image_prediction.exceptions 
import UnknownEstimatorAdapter +from image_prediction.redai_adapter.model import PredictionModelHandle + + +@pytest.fixture +def estimator_mock(): + class EstimatorMock: + @staticmethod + def predict(batch): + return [None for _ in batch] + + @staticmethod + def predict_proba(batch): + return [None for _ in batch] + + def __call__(self, batch): + return self.predict(batch) + + return EstimatorMock() + + +@pytest.fixture +def image_classifier(classifier, monkeypatch, batch_of_expected_string_labels): + return ImageClassifier(classifier, preprocessor=BasicPreprocessor()) + + +@pytest.fixture +def classifier(estimator_adapter, label_mapper): + classifier = Classifier(estimator_adapter, label_mapper) + return classifier + + +@pytest.fixture +def estimator_adapter( + estimator_type, estimator_mock, keras_model, model_handle_mock, output_batch_generator, monkeypatch +): + if estimator_type == "mock": + estimator_adapter = EstimatorAdapter(estimator_mock) + elif estimator_type == "keras": + estimator_adapter = EstimatorAdapter(keras_model) + elif estimator_type == "redai": + estimator_adapter = EstimatorAdapter(PredictionModelHandle(model_handle_mock)) + else: + raise UnknownEstimatorAdapter(f"No adapter for estimator type {estimator_type} was specified.") + + def mock_predict(batch): + # Run real predict function to test for mechanical issues, but return externally defined + # predictions to test the callers of the estimator adapter against the expected predictions + return [next(output_batch_generator) for _ in _predict(batch)] + + _predict = estimator_adapter.predict + monkeypatch.setattr(estimator_adapter, "predict", mock_predict) + + return estimator_adapter + + +@pytest.fixture +def keras_model(input_size): + import os + + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + + import tensorflow as tf + + tf.keras.backend.set_image_data_format("channels_last") + + inputs = tf.keras.Input(shape=input_size) + conv = tf.keras.layers.Conv2D(3, 3) + dense = 
tf.keras.layers.Dense(10) + + outputs = tf.keras.layers.Dense(10)(dense(conv(inputs))) + model = tf.keras.Model(inputs=inputs, outputs=outputs) + model.compile() + + return model + + +@pytest.fixture +def model(): + class Model: + @staticmethod + def predict(*args): + return True + + @staticmethod + def predict_proba(*args): + return True + + return Model() + + +@pytest.fixture +def model_handle_mock(estimator_mock): + class ModelHandleMock: + def __init__(self): + self.model = estimator_mock + + def prep_images(self, batch): + return [None for _ in batch] + + def predict(self, batch): + return [None for _ in batch] + + def predict_proba(self, batch): + return [None for _ in batch] + + return ModelHandleMock() diff --git a/test/fixtures/model_store.py b/test/fixtures/model_store.py new file mode 100644 index 0000000..49fd639 --- /dev/null +++ b/test/fixtures/model_store.py @@ -0,0 +1,61 @@ +import random +import string + +import pytest + +from image_prediction.exceptions import UnknownDatabaseType +from image_prediction.model_loader.database.connectors.mock import DatabaseConnectorMock +from image_prediction.model_loader.loader import ModelLoader +from image_prediction.model_loader.loaders.mlflow import MlflowConnector +from image_prediction.redai_adapter.mlflow import MlflowModelReader + + +@pytest.fixture +def model_database_record_identifier(): + return "".join(random.sample(string.ascii_letters, k=10)) + + +@pytest.fixture +def model_database_record(model, classes): + return {"model": model, "classes": classes} + + +@pytest.fixture +def model_database(model_database_record, model_database_record_identifier): + return {model_database_record_identifier: model_database_record} + + +@pytest.fixture +def database_connector(database_type, model_database, mlflow_reader): + if database_type == "mock": + return DatabaseConnectorMock(model_database) + + elif database_type == "mlflow": + return MlflowConnector(mlflow_reader) + + else: + raise UnknownDatabaseType(f"No 
connector for database type {database_type} was specified.") + + +@pytest.fixture +def model_loader(database_connector): + return ModelLoader(database_connector) + + +@pytest.fixture +def mlflow_run_id(): + from image_prediction.config import CONFIG + + return CONFIG.service.mlflow_run_id + + +@pytest.fixture +def mlruns_dir(): + from image_prediction.locations import MLRUNS_DIR + + return MLRUNS_DIR + + +@pytest.fixture +def mlflow_reader(mlruns_dir): + return MlflowModelReader(mlruns_dir) diff --git a/test/fixtures/parameters.py b/test/fixtures/parameters.py new file mode 100644 index 0000000..a837613 --- /dev/null +++ b/test/fixtures/parameters.py @@ -0,0 +1,38 @@ +from operator import itemgetter + +import pytest + + +@pytest.fixture(params=[220, 30]) +def page_height(request): + return request.param + + +@pytest.fixture(params=[100, 310]) +def page_width(request): + return request.param + + +@pytest.fixture(params=[False]) +def alpha(request): + return request.param + + +@pytest.fixture(params=[{"width": 10, "height": 15, "depth": 3}, {"width": 150, "height": 100, "depth": 3}]) +def __input_size(request): + return itemgetter("width", "height", "depth")(request.param) + + +@pytest.fixture(params=[33, 100]) +def height(request): + return request.param + + +@pytest.fixture(params=[10, 31]) +def width(request): + return request.param + + +@pytest.fixture(params=[0, 1, 2, 16, 32]) +def batch_size(request): + return request.param diff --git a/test/fixtures/pdf.py b/test/fixtures/pdf.py new file mode 100644 index 0000000..7353917 --- /dev/null +++ b/test/fixtures/pdf.py @@ -0,0 +1,23 @@ +import os + +import fpdf +import pytest + +from image_prediction.locations import TEST_DATA_DIR +from test.utils.generation.pdf import add_image, pdf_stream + + +@pytest.fixture +def pdf(image_metadata_pairs): + pdf = fpdf.FPDF(unit="pt") + + for pair in image_metadata_pairs: + add_image(pdf, pair) + + return pdf_stream(pdf) + + +@pytest.fixture +def real_pdf(): + with 
open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf"), "rb") as f: + yield f.read() diff --git a/test/fixtures/target.py b/test/fixtures/target.py new file mode 100644 index 0000000..bcc0e75 --- /dev/null +++ b/test/fixtures/target.py @@ -0,0 +1,91 @@ +import json +import os +import random +from functools import partial +from operator import itemgetter + +import numpy as np +import pytest +from funcy import rcompose + +from image_prediction.exceptions import UnknownLabelFormat +from image_prediction.label_mapper.mappers.probability import ProbabilityMapperKeys +from image_prediction.locations import TEST_DATA_DIR +from test.utils.label import map_labels + + +@pytest.fixture +def expected_predictions_mapped( + label_format, batch_of_expected_string_labels, batch_of_expected_label_to_probability_mappings +): + if label_format == "index": + return batch_of_expected_string_labels + elif label_format == "probability": + return batch_of_expected_label_to_probability_mappings + else: + raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.") + + +@pytest.fixture +def expected_predictions(label_format, batch_of_expected_numeric_labels, batch_of_expected_probability_arrays): + if label_format == "index": + return batch_of_expected_numeric_labels + elif label_format == "probability": + return batch_of_expected_probability_arrays + else: + raise UnknownLabelFormat(f"No label mapper for label format {label_format} was specified.") + + +@pytest.fixture +def batch_of_expected_string_labels(batch_of_expected_numeric_labels, classes): + return map_labels(batch_of_expected_numeric_labels, classes) + + +@pytest.fixture +def batch_of_expected_numeric_labels(batch_size, classes): + return random.choices(range(len(classes)), k=batch_size) + + +@pytest.fixture +def batch_of_expected_label_to_probability_mappings(batch_of_expected_probability_arrays, classes): + def map_probabilities(probabilities): + lbl2prob = 
dict(sorted(zip(classes, map(rounder, probabilities)), key=itemgetter(1), reverse=True)) + most_likely = [*lbl2prob][0] + return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: lbl2prob} + + rounder = rcompose(partial(np.round, decimals=4), float) + return list(map(map_probabilities, batch_of_expected_probability_arrays)) + + +@pytest.fixture +def batch_of_expected_probability_arrays(batch_size, classes): + return [np.random.uniform(size=len(classes)) for _ in range(batch_size)] + + +@pytest.fixture +def output_batch_generator(expected_predictions): + return iter(expected_predictions) + + +@pytest.fixture +def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata): + return [{"classification": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped, metadata)] + + +@pytest.fixture +def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted): + return [ + {"classification": epm, **mdt} + for epm, mdt in zip(expected_predictions_mapped_and_formatted, metadata_formatted) + ] + + +@pytest.fixture +def expected_predictions_mapped_and_formatted(expected_predictions_mapped): + return [{k.value: v for k, v in epm.items()} for epm in expected_predictions_mapped] + + +@pytest.fixture +def real_expected_service_response(): + with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json"), "r") as f: + yield json.load(f) diff --git a/test/integration_tests/actual_server_test.py b/test/integration_tests/actual_server_test.py new file mode 100644 index 0000000..3e42e40 --- /dev/null +++ b/test/integration_tests/actual_server_test.py @@ -0,0 +1,103 @@ +import socket +from multiprocessing import Process + +import pytest +import requests +from funcy import retry +from waitress import serve + +from image_prediction.flask import make_prediction_server +from image_prediction.pipeline import load_pipeline + + +@pytest.fixture +def host(): + return 
"0.0.0.0" + + +def get_free_port(host): + sock = socket.socket() + sock.bind((host, 0)) + return sock.getsockname()[1] + + +@pytest.fixture +def port(host): + return get_free_port(host) + + +@pytest.fixture +def url(host, port): + return f"http://{host}:{port}" + + +@pytest.fixture(params=["dummy", "actual"]) +def server_type(request): + return request.param + + +@pytest.fixture +def server(server_type): + if server_type == "dummy": + return make_prediction_server(lambda x: int(x.decode()) // 2) + + elif server_type == "actual": + return make_prediction_server(lambda x: list(load_pipeline(verbose=False)(x))) + + else: + raise ValueError(f"Unknown server type {server_type}.") + + +@pytest.fixture +def host_and_port(host, port, server): + return {"host": host, "port": port} + + +@retry(tries=5, timeout=1) +def server_ready(url): + response = requests.get(f"{url}/ready") + response.raise_for_status() + return response.status_code == 200 + + +@pytest.fixture(autouse=True, scope="function") +def server_process(server, host_and_port, url): + def get_server_process(): + return Process(target=serve, kwargs={"app": server, **host_and_port}) + + server = get_server_process() + server.start() + + if server_ready(url): + yield + + server.kill() + server.join() + server.close() + + +@pytest.mark.parametrize("server_type", ["actual"]) +@pytest.mark.skip() +def test_server_predict(url, real_pdf, real_expected_service_response): + response = requests.post(f"{url}/predict", data=real_pdf) + response.raise_for_status() + assert response.json() == real_expected_service_response + + +@pytest.mark.parametrize("server_type", ["dummy"]) +def test_server_dummy_operation(url): + response = requests.post(f"{url}/predict", data=b"42") + response.raise_for_status() + assert response.json() == 21 + + +@pytest.mark.parametrize("server_type", ["dummy"]) +def test_server_health_check(url): + response = requests.get(f"{url}/health") + response.raise_for_status() + assert response.status_code == 
200 + + +@pytest.mark.parametrize("server_type", ["dummy"]) +def test_server_ready_check(url): + assert server_ready(url) diff --git a/test/unit_tests/box_validation_test.py b/test/unit_tests/box_validation_test.py new file mode 100644 index 0000000..a9a0a07 --- /dev/null +++ b/test/unit_tests/box_validation_test.py @@ -0,0 +1,29 @@ +import pytest + +from image_prediction.exceptions import InvalidBox +from image_prediction.info import Info +from image_prediction.stitching.utils import validate_box_size, validate_box_coords + + +def test_validate_fail_too_short(): + box = {Info.WIDTH: 1, Info.HEIGHT: 0} + with pytest.raises(InvalidBox): + validate_box_size(box) + + +def test_validate_fail_too_thin(): + box = {Info.WIDTH: 0, Info.HEIGHT: 1} + with pytest.raises(InvalidBox): + validate_box_size(box) + + +def test_validate_fail_xs_width_mismatch(): + box = {Info.WIDTH: 2, Info.HEIGHT: 4, Info.X1: 0, Info.Y1: 0, Info.X2: 1, Info.Y2: 4} + with pytest.raises(InvalidBox): + validate_box_coords(box) + + +def test_validate_fail_ys_height_mismatch(): + box = {Info.WIDTH: 2, Info.HEIGHT: 3, Info.X1: 0, Info.Y1: 0, Info.X2: 2, Info.Y2: 4} + with pytest.raises(InvalidBox): + validate_box_coords(box) diff --git a/test/unit_tests/classifier_test.py b/test/unit_tests/classifier_test.py new file mode 100644 index 0000000..eaea50b --- /dev/null +++ b/test/unit_tests/classifier_test.py @@ -0,0 +1,19 @@ +import pytest + + +@pytest.mark.parametrize("estimator_type", ["mock", "keras", "redai"]) +@pytest.mark.parametrize("label_format", ["index", "probability"]) +def test_classifier(classifier, input_batch, expected_predictions_mapped): + predictions = list(classifier(input_batch)) + assert predictions == expected_predictions_mapped + + +def test_batch_format(input_batch): + def channels_are_last(input_batch): + return input_batch.shape[-1] == 3 + + def is_fourth_order_tensor(input_batch): + return input_batch.ndim == 4 + + assert channels_are_last(input_batch) + assert 
is_fourth_order_tensor(input_batch) diff --git a/test/unit_tests/compositor_test.py b/test/unit_tests/compositor_test.py new file mode 100644 index 0000000..1d59a2e --- /dev/null +++ b/test/unit_tests/compositor_test.py @@ -0,0 +1,33 @@ +import pytest + +from image_prediction.compositor.compositor import TransformerCompositor +from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter +from image_prediction.formatter.formatters.enum import EnumFormatter +from image_prediction.formatter.formatters.identity import IdentityFormatter +from image_prediction.info import Info +from test.utils.comparison import transform_equal + + +def test_identity(metadata): + compositor = TransformerCompositor(IdentityFormatter()) + assert transform_equal(compositor(metadata), metadata) + + +def test_composition(metadata, metadata_formatted): + compositor = TransformerCompositor(IdentityFormatter(), EnumFormatter()) + assert transform_equal(compositor(metadata), metadata_formatted) + + +@pytest.fixture() +def compositor_test_enum_metadata(): + return [{Info.WIDTH: 100, Info.PAGE_WIDTH: 200}] + + +@pytest.fixture() +def compositor_test_camel_case_metadata(): + return [{"width": 100, "pageWidth": 200}] + + +def test_enum_to_camel_case(compositor_test_enum_metadata, compositor_test_camel_case_metadata): + compositor = TransformerCompositor(EnumFormatter(), Snake2CamelCaseKeyFormatter()) + assert transform_equal(compositor(compositor_test_enum_metadata), compositor_test_camel_case_metadata) diff --git a/test/unit_tests/config_test.py b/test/unit_tests/config_test.py new file mode 100644 index 0000000..8c16cf8 --- /dev/null +++ b/test/unit_tests/config_test.py @@ -0,0 +1,38 @@ +import tempfile + +import pytest +import yaml + +from image_prediction.config import Config + + +@pytest.fixture +def config_file_content(): + return {"A": [{"B": [1, 2]}, {"C": 3}, 4], "D": {"E": {"F": True}}} + + +@pytest.fixture +def config(config_file_content): + with 
tempfile.NamedTemporaryFile(suffix=".yaml", mode="w") as f: + yaml.dump(config_file_content, f, default_flow_style=False) + yield Config(f.name) + + +def test_dot_access_key_exists(config): + assert config.A == [{"B": [1, 2]}, {"C": 3}, 4] + assert config.D.E["F"] + + +def test_access_key_exists(config): + assert config["A"] == [{"B": [1, 2]}, {"C": 3}, 4] + assert config["A"][0] == {"B": [1, 2]} + assert config["A"][0]["B"] == [1, 2] + assert config["A"][0]["B"][0] == 1 + + +def test_dot_access_key_does_not_exists(config): + assert config.B is None + + +def test_access_key_does_not_exists(config): + assert config["B"] is None diff --git a/test/unit_tests/coordinate_transformer_test.py b/test/unit_tests/coordinate_transformer_test.py new file mode 100644 index 0000000..1d39ca8 --- /dev/null +++ b/test/unit_tests/coordinate_transformer_test.py @@ -0,0 +1,239 @@ +from operator import itemgetter, attrgetter + +import numpy as np +import pytest +from fpdf import fpdf +from funcy import compose +from pdf2image import pdf2image + +from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor +from image_prediction.image_extractor.extractor import ImageMetadataPair +from image_prediction.info import Info +from image_prediction.transformer.transformers.coordinate.fitz import FitzCoordinateTransformer +from image_prediction.transformer.transformers.coordinate.fpdf import FPDFCoordinateTransformer +from image_prediction.transformer.transformers.coordinate.pdfnet import PDFNetCoordinateTransformer +from test.utils.metadata import get_base_position_metadata +from test.utils.generation.image import array_to_image +from test.utils.generation.pdf import add_image +from test.utils.comparison import transform_equal + + +@pytest.mark.parametrize("coordinate_system", ["fpdf"]) +@pytest.mark.skip(reason="No content") +def test_fpdf_coordinate_transformer(position_metadata_in_given_system, position_metadata_in_reference_system): + """We use FPDF's coordinate 
system as the reference system (arbitrarily). Hence, FPDFCoordinateTransformer + actually does not do anything. This test merely documents the fact, that FPDF is the reference system. + """ + pass + + +@pytest.mark.parametrize("coordinate_system", ["fitz"]) +@pytest.mark.skip(reason="No content") +def test_fitz_coordinate_transformer(position_metadata_in_given_system, position_metadata_in_reference_system): + """How I inferred the transformation: + + - extract images from coordinate_test_fpdf with ParsablePDFImageExtractor (see test_parsable_pdf_image_extractor) + - Compare position_metadata_in_given_system (fitz) with position_metadata_in_reference_system (fpdf) + - Observe that they are identical + """ + pass + + +@pytest.mark.parametrize("coordinate_system", ["pdfnet"]) +@pytest.mark.skip(reason="No content") +def test_pdfnet_coordinate_transformer(position_metadata_in_given_system, position_metadata_in_reference_system): + """How I inferred the transformation: + + - save coordinate_test_fpdf to disk as file f + - draw boxes for position_metadata_in_reference_system in f with draw_metadata_box + - save annotated pdf as file g + - look at discrepancy between the black square and the red box in g + """ + pass + + +@pytest.mark.parametrize("coordinate_system", ["fpdf", "fitz", "pdfnet"]) +def test_coordinate_transformer_by_metadata( + transformer, position_metadata_in_given_system, position_metadata_in_reference_system +): + assert transform_equal( + transformer.forward(position_metadata_in_reference_system), position_metadata_in_given_system + ) + assert transform_equal( + transformer.backward(position_metadata_in_given_system), position_metadata_in_reference_system + ) + assert transform_equal( + compose(transformer.backward, transformer.forward)(position_metadata_in_reference_system), + position_metadata_in_reference_system, + ) + + +@pytest.fixture +def transformer(coordinate_system): + if coordinate_system == "fpdf": + return FPDFCoordinateTransformer() + + 
elif coordinate_system == "fitz": + return FitzCoordinateTransformer() + + elif coordinate_system == "pdfnet": + return PDFNetCoordinateTransformer() + + else: + raise ValueError(f"Unknown coordinate system: {coordinate_system}") + + +@pytest.fixture +def position_metadata_in_given_system(corner, corner2metadata_in_given_system, multiple): + metadata = corner2metadata_in_given_system[corner] + return [metadata, metadata] if multiple else metadata + + +@pytest.fixture +def position_metadata_in_reference_system(corner, corner2metadata_in_reference_system, multiple): + metadata = corner2metadata_in_reference_system[corner] + return [metadata, metadata] if multiple else metadata + + +@pytest.fixture(params=["top_left", "bottom_left", "bottom_right", "top_right"]) +def corner(request): + return request.param + + +@pytest.fixture +def corner2metadata_in_given_system( + coordinate_system, get_fpdf_corner_metadat, get_fitz_corner_metadat, get_pdfnet_corner_metadata +): + if coordinate_system == "fpdf": + return get_fpdf_corner_metadat + + elif coordinate_system == "fitz": + return get_fitz_corner_metadat + + elif coordinate_system == "pdfnet": + return get_pdfnet_corner_metadata + + else: + raise ValueError(f"Unknown coordinate system: {coordinate_system}") + + +@pytest.fixture +def corner2metadata_in_reference_system(get_fpdf_corner_metadat): + return get_fpdf_corner_metadat + + +@pytest.fixture(params=[True, False]) +def multiple(request): + return request.param + + +@pytest.fixture +def get_fpdf_corner_metadat(base_position_metadata, get_metadata_for_coords, get_image_and_page_edge_lengths): + """Origin top left, y1 <= y2; all coords on page are positive + (0,0)--+--(2,0)--+ + |////| |////| + +--(1,1) +--(3,1) + + (0,2)--+ (2,2)--+ + |////| |////| + +--(1,3) +--(3,3) + """ + # noinspection PyTupleAssignmentBalance + width, height, page_width, page_height = get_image_and_page_edge_lengths() + + return { + "top_left": get_metadata_for_coords(0, 0, width, height), + 
"bottom_left": get_metadata_for_coords(0, page_height - height, width, page_height), + "bottom_right": get_metadata_for_coords(page_width - width, page_height - height, page_width, page_height), + "top_right": get_metadata_for_coords(page_width - width, 0, page_width, height), + } + + +@pytest.fixture +def get_fitz_corner_metadat(get_fpdf_corner_metadat): + return get_fpdf_corner_metadat + + +@pytest.fixture +def get_pdfnet_corner_metadata(base_position_metadata, get_metadata_for_coords, get_image_and_page_edge_lengths): + """Origin bottom left, y1 <= y2; all coords on page are positive + +---(1,3) +--(3,3) + |////| |////| + (0,2)--+ (2,2)--+ + + +--(1,1) +--(3,1) + |////| |////| + (0,0)--+ (2,0)--+ + """ + # noinspection PyTupleAssignmentBalance + width, height, page_width, page_height = get_image_and_page_edge_lengths() + + return { + "top_left": get_metadata_for_coords(0, page_height - height, width, page_height), + "bottom_left": get_metadata_for_coords(0, 0, width, height), + "bottom_right": get_metadata_for_coords(page_width - width, 0, page_width, height), + "top_right": get_metadata_for_coords(page_width - width, page_height - height, page_width, page_height), + } + + +@pytest.fixture +def base_position_metadata(width, height, page_width, page_height): + return get_base_position_metadata(width, height, page_width, page_height) + + +@pytest.fixture +def get_metadata_for_coords(base_position_metadata): + def __get_metadata_for_coords(*coords): + meta_data_coords = get_metadata_coords(*coords) + return {**meta_data_coords, **base_position_metadata} + + return __get_metadata_for_coords + + +@pytest.fixture +def get_image_and_page_edge_lengths(base_position_metadata): + def __get_w_h_pw_ph(): + return itemgetter(*attrgetter("WIDTH", "HEIGHT", "PAGE_WIDTH", "PAGE_HEIGHT")(Info))(base_position_metadata) + + return __get_w_h_pw_ph + + +def get_metadata_coords(x1, y1, x2, y2): + return {Info.X1: x1, Info.Y1: y1, Info.X2: x2, Info.Y2: y2} + + 
+@pytest.mark.parametrize("coordinate_system", ["pdfnet"]) +@pytest.mark.parametrize("multiple", [False]) +def test_coordinate_transformer_by_image( + transformer, position_metadata_in_given_system, position_metadata_in_reference_system +): + metadata_transformed = transformer(position_metadata_in_given_system) + + target_image = metadata_to_test_page_image(position_metadata_in_reference_system) + test_image = metadata_to_test_page_image(metadata_transformed) + + assert np.allclose(target_image, test_image) + + +def metadata_to_test_page_image(metadata): + image = get_coordinate_test_image(*itemgetter(*attrgetter("WIDTH", "HEIGHT")(Info))(metadata)) + pdf = get_coordinate_test_fpdf(*itemgetter(*attrgetter("PAGE_WIDTH", "PAGE_HEIGHT")(Info))(metadata)) + add_image(pdf, ImageMetadataPair(image, metadata)) + page_image = fpdf_to_page_tensor(pdf) + return page_image + + +def get_coordinate_test_image(width, height): + return array_to_image(np.zeros(shape=(width, height, 3))) + + +def get_coordinate_test_fpdf(page_width, page_height): + pdf = fpdf.FPDF(unit="pt", format=(page_width, page_height)) + return pdf + + +def fpdf_to_page_tensor(fpdf): + pdf = fpdf.output(dest="S").encode("latin1") + page_image = pdf2image.convert_from_bytes(pdf)[0] + tensor = image_to_normalized_tensor(page_image) + return tensor diff --git a/test/unit_tests/formatter_test.py b/test/unit_tests/formatter_test.py new file mode 100644 index 0000000..b012b9a --- /dev/null +++ b/test/unit_tests/formatter_test.py @@ -0,0 +1,27 @@ +import pytest + +from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter +from image_prediction.formatter.formatters.enum import EnumFormatter + + +def test_enum_formatter(metadata, metadata_formatted): + assert list(EnumFormatter()(metadata)) == metadata_formatted + + +@pytest.mark.parametrize("label_format", ["probability"]) +def test_enum_formatter(metadata_plus_mapped_prediction, metadata_formatted_plus_mapped_prediction_formatted): + 
assert list(EnumFormatter()(metadata_plus_mapped_prediction)) == metadata_formatted_plus_mapped_prediction_formatted + + +def test_camel_case_key_formatter(snake_case_data, camel_case_data): + assert Snake2CamelCaseKeyFormatter()(snake_case_data) == camel_case_data + + +@pytest.fixture +def snake_case_data(): + return {"a_key": {"key": None, "key_2": ["may_not_be_changed", (1, 2, 2.2)]}, 2: {"yet_another_key": 3, 4: "a"}} + + +@pytest.fixture +def camel_case_data(): + return {"aKey": {"key": None, "key2": ["may_not_be_changed", (1, 2, 2.2)]}, 2: {"yetAnotherKey": 3, 4: "a"}} diff --git a/test/unit_tests/image_classifier_test.py b/test/unit_tests/image_classifier_test.py new file mode 100644 index 0000000..5cffdf3 --- /dev/null +++ b/test/unit_tests/image_classifier_test.py @@ -0,0 +1,7 @@ +import pytest + + +@pytest.mark.parametrize("estimator_type", ["mock", "keras"]) +def test_predict(image_classifier, images, batch_of_expected_string_labels): + predictions = list(image_classifier.predict(images)) + assert predictions == batch_of_expected_string_labels diff --git a/test/unit_tests/image_extractor_test.py b/test/unit_tests/image_extractor_test.py new file mode 100644 index 0000000..e52b2b5 --- /dev/null +++ b/test/unit_tests/image_extractor_test.py @@ -0,0 +1,77 @@ +import random +from operator import itemgetter + +import fitz +import fpdf +import pytest +from PIL import Image +from funcy import first, rest + +from image_prediction.extraction import extract_images_from_pdf +from image_prediction.image_extractor.extractor import ImageMetadataPair +from image_prediction.image_extractor.extractors.parsable import extract_pages, get_image_infos, has_alpha_channel +from image_prediction.info import Info +from test.utils.comparison import metadata_equal, image_sets_equal +from test.utils.generation.pdf import add_image, pdf_stream + + +@pytest.mark.parametrize("extractor_type", ["mock"]) +@pytest.mark.parametrize("batch_size", [1, 2, 16]) +def 
test_image_extractor_mock(image_extractor, images): + images_extracted, metadata = map(list, zip(*image_extractor(images))) + assert images_extracted == images + + +@pytest.mark.parametrize("extractor_type", ["parsable_pdf", "default"]) +@pytest.mark.parametrize("input_size", [{"depth": 3, "width": 170, "height": 220}], indirect=["input_size"]) +@pytest.mark.parametrize("alpha", [False, True]) +def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size, alpha): + images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor)) + if not alpha: + assert image_sets_equal(images_extracted, images) + assert metadata_equal(metadata_extracted, metadata) + + +@pytest.mark.parametrize("batch_size", [1, 2, 16]) +def test_extract_pages(pdf): + doc = fitz.Document(stream=pdf) + + max_index = max(0, doc.page_count - 1) + i = random.randint(0, max(0, max_index - 1)) + j = random.randint(i + 1, max_index) if max_index > 0 else 0 + + page_range = range(i, j) + + pages = list(extract_pages(doc, page_range)) + assert all((isinstance(p, fitz.Page) for p in pages)) + assert len(pages) == len(page_range) + + +@pytest.mark.parametrize("suffix", ["gif", "png", "jpeg"]) +@pytest.mark.parametrize("mode", ["RGB", "RGBA"]) +def test_has_alpha_channel(base_patch_metadata, suffix, mode): + + mode = "RGB" if suffix == "jpeg" else mode + + pdf = fpdf.FPDF(unit="pt") + + image = Image.new(mode, itemgetter(Info.WIDTH, Info.HEIGHT)(base_patch_metadata), color=(10, 10, 10)) + + add_image(pdf, ImageMetadataPair(image, base_patch_metadata), suffix=suffix) + + doc = fitz.Document(stream=pdf_stream(pdf)) + + page = first(doc) + + xrefs = map(itemgetter("xref"), get_image_infos(page)) + + result = has_alpha_channel(doc, first(xrefs)) + + if mode == "RGBA": + assert result + if mode == "RGB": + assert not result + + assert not list(rest(xrefs)) + + doc.close() diff --git a/test/unit_tests/image_stitching_test.py 
b/test/unit_tests/image_stitching_test.py new file mode 100644 index 0000000..edf7923 --- /dev/null +++ b/test/unit_tests/image_stitching_test.py @@ -0,0 +1,246 @@ +import json +import os +from copy import deepcopy +from functools import partial +from itertools import starmap, repeat +from operator import itemgetter +from typing import List + +import fpdf +import pdf2image +import pytest +from funcy import juxt, one, first + +from image_prediction.formatter.formatters.enum import ReverseEnumFormatter +from image_prediction.image_extractor.extractor import ImageMetadataPair +from image_prediction.info import Info +from image_prediction.stitching.grouping import group_by_coordinate +from image_prediction.stitching.merging import ( + merge_metadata_horizontally, + merge_metadata_vertically, + merge_pair_horizontally, + merge_pair_vertically, + concat_images_horizontally, + concat_images_vertically, + merge_group_horizontally, + merge_group_vertically, +) +from image_prediction.stitching.stitching import stitch_pairs +from image_prediction.stitching.utils import ( + make_coord_getter, + make_length_getter, +) +from test.utils.comparison import images_equal +from test.utils.generation.image import random_single_color_image_from_metadata, gray_image_from_metadata +from test.utils.generation.pdf import add_image +from test.utils.stitching import BoxSplitter + +x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1", "x2", "y2")) +width_getter, height_getter = map(make_length_getter, ("width", "height")) + + +def test_group_by_coordinate_exact(): + pairs = [(0, 1), (0, 3), (1, 4), (1, 4), (1, 2), (3, 3)] + pairs_grouped = list(group_by_coordinate(pairs, itemgetter(0), tolerance=0)) + assert pairs_grouped == [[(0, 1), (0, 3)], [(1, 4), (1, 4), (1, 2)], [(3, 3)]] + + +def test_group_by_coordinate_fuzzy(): + pairs = [(0, 1), (1, 3), (1, 4), (2, 4), (2, 2), (3, 3)] + pairs_grouped = list(group_by_coordinate(pairs, itemgetter(0), tolerance=1)) + assert 
pairs_grouped == [[(0, 1), (1, 3), (1, 4)], [(2, 4), (2, 2), (3, 3)]] + + +def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image): + pairs_stitched = stitch_pairs(patch_image_metadata_pairs) + pair_stitched = first(pairs_stitched) + + assert len(pairs_stitched) == 1 + assert pair_stitched.metadata == base_patch_metadata + assert images_equal(pair_stitched.image.resize((10, 10)), base_patch_image.resize((10, 10)), atol=0.4) + + +def test_image_stitcher_with_gaps_must_succeed(): + from image_prediction.locations import TEST_DATA_DIR + + with open(os.path.join(TEST_DATA_DIR, "stitching_with_tolerance.json")) as f: + patches_metadata, base_patch_metadata = itemgetter("input", "target")(ReverseEnumFormatter(Info)(json.load(f))) + + images = map(gray_image_from_metadata, patches_metadata) + patch_image_metadata_pairs = list(starmap(ImageMetadataPair, zip(images, patches_metadata))) + + pairs_stitched = stitch_pairs(patch_image_metadata_pairs, tolerance=7) + + assert len(pairs_stitched) == 1 + pair_stitched = first(pairs_stitched) + assert pair_stitched.metadata == base_patch_metadata + + +@pytest.mark.parametrize("noise", [(0, 2)]) +@pytest.mark.parametrize("split_count", [5]) +@pytest.mark.parametrize("width", [100]) +@pytest.mark.parametrize("height", [100]) +@pytest.mark.parametrize("page_width", [100]) +@pytest.mark.parametrize("page_height", [100]) +@pytest.mark.parametrize("execution_number", range(100)) +@pytest.mark.xfail(reason="Does not always succeed due to locally maximizing merging logic.") +def test_image_stitcher_with_gaps_can_fail(patch_image_metadata_pairs, base_patch_metadata, execution_number): + pairs_stitched = stitch_pairs(patch_image_metadata_pairs, tolerance=4) + assert len(pairs_stitched) == 1 and first(pairs_stitched).metadata == base_patch_metadata + + +def test_merge_group_horizontally(horizontal_merge_test_pairs): + pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs + + prs_merged = 
merge_group_horizontally([pr1, pr2]) + assert len(prs_merged) == 1 + assert pair_equal(prs_merged[0], pr_merged_expected) + + mdat3 = deepcopy(pr2.metadata) + mdat3[Info.HEIGHT] += 30 + mdat3[Info.Y2] += 30 + im3 = gray_image_from_metadata(mdat3) + pr3 = ImageMetadataPair(im3, mdat3) + + prs_merged = merge_group_horizontally([pr1, pr2, pr3]) + assert len(prs_merged) == 2 + assert one(partial(pair_equal, pr_merged_expected), prs_merged) + + +def test_merge_group_vertically(vertical_merge_test_pairs): + pr1, pr2, pr_merged_expected = vertical_merge_test_pairs + + prs_merged = merge_group_vertically([pr1, pr2]) + assert len(prs_merged) == 1 + assert pair_equal(prs_merged[0], pr_merged_expected) + + mdat3 = deepcopy(pr2.metadata) + mdat3[Info.WIDTH] += 30 + mdat3[Info.X2] += 30 + im3 = gray_image_from_metadata(mdat3) + pr3 = ImageMetadataPair(im3, mdat3) + + prs_merged = merge_group_vertically([pr1, pr2, pr3]) + assert len(prs_merged) == 2 + assert one(partial(pair_equal, pr_merged_expected), prs_merged) + + +def pair_equal(pr1, pr2): + return pr1.metadata == pr2.metadata and images_equal(pr1.image, pr2.image) + + +def test_merge_pairs_horizontally(horizontal_merge_test_pairs): + pr1, pr2, pr_merged_expected = horizontal_merge_test_pairs + pr_merged = merge_pair_horizontally(pr1, pr2) + assert pair_equal(pr_merged, pr_merged_expected) + + +def test_merge_pairs_vertically(vertical_merge_test_pairs): + pr1, pr2, pr_merged_expected = vertical_merge_test_pairs + pr_merged = merge_pair_vertically(pr1, pr2) + assert pair_equal(pr_merged, pr_merged_expected) + + +@pytest.fixture +def horizontal_merge_test_pairs(horizontal_merge_test_metadata): + images = map(gray_image_from_metadata, horizontal_merge_test_metadata) + return list(starmap(ImageMetadataPair, zip(images, horizontal_merge_test_metadata))) + + +@pytest.fixture +def vertical_merge_test_pairs(vertical_merge_test_metadata): + images = map(gray_image_from_metadata, vertical_merge_test_metadata) + return 
list(starmap(ImageMetadataPair, zip(images, vertical_merge_test_metadata))) + + +def test_merge_metadata_horizontally(horizontal_merge_test_metadata): + mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata + assert merge_metadata_horizontally(mdat1, mdat2) == mdat_merged + + +def test_merge_metadata_vertically(vertical_merge_test_metadata): + mdat1, mdat2, mdat_merged = vertical_merge_test_metadata + assert merge_metadata_vertically(mdat1, mdat2) == mdat_merged + + +@pytest.fixture +def horizontal_merge_test_metadata(merge_test_metadata): + mdat1, mdat2, mdat_merged = merge_test_metadata + + mdat2[Info.X1] = mdat1[Info.X2] + mdat2[Info.X2] = mdat2[Info.X1] + mdat2[Info.WIDTH] + + mdat_merged.update({Info.WIDTH: mdat1[Info.WIDTH] + mdat2[Info.WIDTH], Info.X2: mdat2[Info.X2]}) + + return mdat1, mdat2, mdat_merged + + +@pytest.fixture +def vertical_merge_test_metadata(merge_test_metadata): + mdat1, mdat2, mdat_merged = merge_test_metadata + + mdat2[Info.Y1] = mdat1[Info.Y2] + mdat2[Info.Y2] = mdat2[Info.Y1] + mdat2[Info.HEIGHT] + + mdat_merged.update({Info.HEIGHT: mdat1[Info.HEIGHT] + mdat2[Info.HEIGHT], Info.Y2: mdat2[Info.Y2]}) + + return mdat1, mdat2, mdat_merged + + +@pytest.fixture +def merge_test_metadata(base_patch_metadata): + return juxt(*repeat(deepcopy, 3))(base_patch_metadata) + + +@pytest.fixture +def base_patch_image(stitch_test_pdf): + return pdf2image.convert_from_bytes(stitch_test_pdf)[0] + + +def test_concat_images_horizontally(horizontal_merge_test_metadata): + mdat1, mdat2, mdat_merged = horizontal_merge_test_metadata + im1, im2, im_merged_expected = map(gray_image_from_metadata, [mdat1, mdat2, mdat_merged]) + im_merged = concat_images_horizontally(im1, im2, mdat_merged) + assert im_merged.size == im_merged_expected.size + assert images_equal(im_merged, im_merged_expected) + + +def test_concat_images_vertically(vertical_merge_test_metadata): + mdat1, mdat2, mdat_merged = vertical_merge_test_metadata + im1, im2, im_merged_expected = 
map(gray_image_from_metadata, [mdat1, mdat2, mdat_merged]) + im_merged = concat_images_vertically(im1, im2, mdat_merged) + assert im_merged.size == im_merged_expected.size + assert images_equal(im_merged, im_merged_expected) + + +@pytest.fixture +def stitch_test_pdf(patch_image_metadata_pairs, width, height): + + pdf = fpdf.FPDF(unit="pt", format=(width, height)) + + for pair in patch_image_metadata_pairs: + add_image(pdf, pair) + + return pdf.output(dest="S").encode("latin1") + + +@pytest.fixture +def patch_image_metadata_pairs(patches_metadata) -> List[ImageMetadataPair]: + images = map(random_single_color_image_from_metadata, patches_metadata) + return list(starmap(ImageMetadataPair, zip(images, patches_metadata))) + + +@pytest.fixture +def patches_metadata(base_patch_metadata, noise, split_count): + patches_metadata = list(BoxSplitter(noise).split_box(base_patch_metadata, split_count)) + return patches_metadata + + +@pytest.fixture(params=[(0, 0)]) +def noise(request): + return request.param + + +@pytest.fixture(params=[5]) +def split_count(request): + return request.param diff --git a/test/unit_tests/label_mapper_test.py b/test/unit_tests/label_mapper_test.py new file mode 100644 index 0000000..7312682 --- /dev/null +++ b/test/unit_tests/label_mapper_test.py @@ -0,0 +1,21 @@ +import pytest + +from image_prediction.exceptions import UnexpectedLabelFormat +from image_prediction.label_mapper.mappers.numeric import IndexMapper +from image_prediction.label_mapper.mappers.probability import ProbabilityMapper + + +def test_index_label_mapper(batch_of_expected_numeric_labels, batch_of_expected_string_labels, classes): + mapper = IndexMapper(classes) + assert list(mapper(batch_of_expected_numeric_labels)) == batch_of_expected_string_labels + with pytest.raises(UnexpectedLabelFormat): + list(mapper([len(classes)])) + + +def test_array_label_mapper( + batch_of_expected_probability_arrays, batch_of_expected_label_to_probability_mappings, classes +): + mapper = 
ProbabilityMapper(classes) + assert list(mapper(batch_of_expected_probability_arrays)) == batch_of_expected_label_to_probability_mappings + with pytest.raises(UnexpectedLabelFormat): + list(mapper([[0] * len(classes) + [1]])) diff --git a/test/unit_tests/mocked_server_test.py b/test/unit_tests/mocked_server_test.py new file mode 100644 index 0000000..d64c937 --- /dev/null +++ b/test/unit_tests/mocked_server_test.py @@ -0,0 +1,48 @@ +import json + +import pytest + +from image_prediction.exceptions import IntentionalTestException +from image_prediction.flask import make_prediction_server + + +def predict_fn(x: bytes): + x = int(x.decode()) + if x == 42: + return True + else: + raise IntentionalTestException("This is intended.") + + +@pytest.fixture +def server(): + server = make_prediction_server(predict_fn) + server.config.update({"TESTING": True}) + return server + + +@pytest.fixture +def client(server): + return server.test_client() + + +def test_server_predict_success(client, mute_logger): + response = client.post("/predict", data="42") + assert json.loads(response.data) + + +def test_server_predict_failure(client, mute_logger): + response = client.post("/predict", data="13") + assert response.status_code == 500 + + +def test_server_health_check(client): + response = client.get("/health") + assert response.status_code == 200 + assert response.json == "OK" + + +def test_server_ready_check(client): + response = client.get("/ready") + assert response.status_code == 200 + assert response.json == "OK" diff --git a/test/unit_tests/model_loader_test.py b/test/unit_tests/model_loader_test.py new file mode 100644 index 0000000..30fb679 --- /dev/null +++ b/test/unit_tests/model_loader_test.py @@ -0,0 +1,21 @@ +import pytest + +from image_prediction.redai_adapter.model import PredictionModelHandle + + +@pytest.mark.parametrize("database_type", ["mock"]) +def test_load_model_and_classes(model_loader, model_database_record_identifier, model, classes): + model_loaded = 
model_loader.load_model(model_database_record_identifier) + classes_loaded = model_loader.load_classes(model_database_record_identifier) + + assert model_loaded == model + assert classes_loaded == classes + + +@pytest.mark.parametrize("database_type", ["mlflow"]) +def test_load_model_and_classes_from_mlflow_store(model_loader, mlflow_run_id): + model_loaded = model_loader.load_model(mlflow_run_id) + classes_loaded = model_loader.load_classes(mlflow_run_id) + + assert type(model_loaded) == PredictionModelHandle + assert classes_loaded == ["formula", "logo", "other", "signature"] diff --git a/test/unit_tests/pipeline_test.py b/test/unit_tests/pipeline_test.py new file mode 100644 index 0000000..8d2208b --- /dev/null +++ b/test/unit_tests/pipeline_test.py @@ -0,0 +1,7 @@ +from image_prediction.pipeline import load_pipeline + + +def test_pipeline(real_pdf, real_expected_service_response): + pipeline = load_pipeline(verbose=False) + response = list(pipeline(real_pdf)) + assert response == real_expected_service_response diff --git a/test/unit_tests/preprocessor_test.py b/test/unit_tests/preprocessor_test.py new file mode 100644 index 0000000..b6315e3 --- /dev/null +++ b/test/unit_tests/preprocessor_test.py @@ -0,0 +1,38 @@ +import numpy as np +import pytest +from PIL import Image + +from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor +from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor +from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor, images_to_batch_tensor + + +def image_conversion_is_correct(image): + tensor = image_to_normalized_tensor(image) + image_re = Image.fromarray(np.uint8(tensor * 255), mode="RGB") + return image == image_re and tensor.ndim == 3 + + +def images_conversion_is_correct(images, tensor): + if not (images or tensor.size > 0): + return True + return all([isinstance(tensor, np.ndarray), tensor.ndim == 4, tensor.shape[0] == 
len(images)]) + + +def test_image_to_tensor(images): + assert all(map(image_conversion_is_correct, images)) + + +def test_images_to_batch_tensor(images): + tensor = images_to_batch_tensor(images) + assert images_conversion_is_correct(images, tensor) + + +def test_basic_preprocessor(images): + tensor = BasicPreprocessor()(images) + assert images_conversion_is_correct(images, tensor) + + +def test_identity_preprocessor(images): + images_preprocessed = IdentityPreprocessor()(images) + assert images_preprocessed == images diff --git a/test/unit_tests/split_mapper_test.py b/test/unit_tests/split_mapper_test.py new file mode 100644 index 0000000..cb5759d --- /dev/null +++ b/test/unit_tests/split_mapper_test.py @@ -0,0 +1,50 @@ +from image_prediction.info import Info +from image_prediction.stitching.split_mapper import VerticalSplitMapper, HorizontalSplitMapper + + +def test_split_vertical_mapper(base_patch_metadata): + sm = VerticalSplitMapper(base_patch_metadata) + + sm.c1 += 10 + 3 + sm.c2 += 20 + 3 + sm.dim += 20 + 3 + smw = sm.wrapped + + assert smw[Info.Y1] == sm.c1 == base_patch_metadata[Info.Y1] + 10 + 3 + assert smw[Info.Y2] == sm.c2 == base_patch_metadata[Info.Y2] + 20 + 3 + assert smw[Info.HEIGHT] == sm.dim == base_patch_metadata[Info.HEIGHT] + 20 + 3 + + sm = VerticalSplitMapper(base_patch_metadata) + + sm.c1 = 10 + 3 + sm.c2 = 20 + 3 + sm.dim = 20 + 3 + smw = sm.wrapped + + assert smw[Info.Y1] == sm.c1 == 10 + 3 + assert smw[Info.Y2] == sm.c2 == 20 + 3 + assert smw[Info.HEIGHT] == sm.dim == 20 + 3 + + +def test_split_horizontal_mapper(base_patch_metadata): + sm = HorizontalSplitMapper(base_patch_metadata) + + sm.c1 += 10 + 3 + sm.c2 += 20 + 3 + sm.dim += 20 + 3 + smw = sm.wrapped + + assert smw[Info.X1] == sm.c1 == base_patch_metadata[Info.X1] + 10 + 3 + assert smw[Info.X2] == sm.c2 == base_patch_metadata[Info.X2] + 20 + 3 + assert smw[Info.WIDTH] == sm.dim == base_patch_metadata[Info.WIDTH] + 20 + 3 + + sm = HorizontalSplitMapper(base_patch_metadata) + + 
sm.c1 = 10 + 3 + sm.c2 = 20 + 3 + sm.dim = 20 + 3 + smw = sm.wrapped + + assert smw[Info.X1] == sm.c1 == 10 + 3 + assert smw[Info.X2] == sm.c2 == 20 + 3 + assert smw[Info.WIDTH] == sm.dim == 20 + 3 diff --git a/test/unit_tests/test_predictor.py b/test/unit_tests/test_predictor.py deleted file mode 100644 index 0da6f91..0000000 --- a/test/unit_tests/test_predictor.py +++ /dev/null @@ -1,26 +0,0 @@ -def test_predict_pdf_works(predictor, test_pdf): - # FIXME ugly test since there are '\n's in the dict with unknown heritage - predictions, metadata = predictor.predict_pdf(test_pdf) - predictions = [p for p in predictions][0] - assert predictions["class"] == "formula" - probabilities = predictions["probabilities"] - # Floating point precision problem for output so test only that keys exist not the values - assert all(key in probabilities for key in ("formula", "other", "signature", "logo")) - metadata = list(metadata) - metadata = dict(**metadata[0]) - metadata.pop("document_filename") # temp filename cannot be tested - assert metadata == { - "px_width": 389.0, - "px_height": 389.0, - "width": 194.49999000000003, - "height": 194.49998999999997, - "x1": 320.861, - "x2": 515.36099, - "y1": 347.699, - "y2": 542.19899, - "page_width": 595.2800000000001, - "page_height": 841.89, - "page_rotation": 0, - "page_idx": 1, - "n_pages": 3, - } diff --git a/test/unit_tests/test_response.py b/test/unit_tests/test_response.py deleted file mode 100644 index 696c92b..0000000 --- a/test/unit_tests/test_response.py +++ /dev/null @@ -1,5 +0,0 @@ -from image_prediction.response import build_response - - -def test_build_response_returns_valid_response(predictions, metadata, response): - assert build_response(predictions, metadata) == response diff --git a/test/unit_tests/utils_test.py b/test/unit_tests/utils_test.py new file mode 100644 index 0000000..5263af8 --- /dev/null +++ b/test/unit_tests/utils_test.py @@ -0,0 +1,8 @@ +from image_prediction.utils.generic import until + + +def 
test_until(): + def f(x): + return x / 2 + + assert until(lambda x: x == 0, f, 1) == 0 diff --git a/test/utils/__init__.py b/test/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/utils/comparison.py b/test/utils/comparison.py new file mode 100644 index 0000000..f2677ce --- /dev/null +++ b/test/utils/comparison.py @@ -0,0 +1,40 @@ +from itertools import starmap, product, repeat +from typing import Iterable + +import numpy as np +from PIL.Image import Image +from frozendict import frozendict +from funcy import ilen + +from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor + + +def transform_equal(a, b): + return (list(a) if isinstance(a, map) else a) == b + + +def images_equal(im1: Image, im2: Image, **kwargs): + return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs) + + +def metadata_equal(mdat1: Iterable[dict], mdat2: Iterable[dict]): + return set(map(frozendict, mdat1)) == set(map(frozendict, mdat2)) + + +def image_sets_equal(ims1: Iterable[Image], ims2: Iterable[Image]): + ims1, ims2 = map(lambda x: sorted(map(image_to_normalized_tensor, x), key=np.mean), (ims1, ims2)) + + n = len(ims1) + assert isinstance(ims1, list) + assert len(ims2) == n + + used = set() + covered = set() + + for im1i, im2i in product(*repeat(range(n), 2)): + + if im1i not in covered and im2i not in used and images_equal(ims1[im1i], ims2[im2i]): + covered.add(im1i) + used.add(im2i) + + return len(covered) == len(used) == n diff --git a/test/utils/generation/__init__.py b/test/utils/generation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/utils/generation/image.py b/test/utils/generation/image.py new file mode 100644 index 0000000..c588bf7 --- /dev/null +++ b/test/utils/generation/image.py @@ -0,0 +1,31 @@ +import numpy as np +from PIL import Image + +from image_prediction.info import Info + + +def random_single_color_image_from_metadata(metadata): + image = 
Image.new( + "RGB", (metadata[Info.WIDTH], metadata[Info.HEIGHT]), color=tuple(map(int, np.random.uniform(size=3) * 255)) + ) + return image + + +def gray_image_from_metadata(metadata): + image = Image.new("RGB", (metadata[Info.WIDTH], metadata[Info.HEIGHT]), color=(100, 100, 100)) + return image + + +def array_to_image(array): + assert np.all(array <= 1) + assert np.all(array >= 0) + + if array.shape[-1] == 3: + mode = "RGB" + elif array.shape[-1] == 4: + mode = "RGBA" + else: + raise ValueError(f"Unexpected number of channels {array.shape[-1]}. Expected 3 or 4.") + + # noinspection PyTypeChecker + return Image.fromarray(np.uint8(array * 255), mode=mode) diff --git a/test/utils/generation/pdf.py b/test/utils/generation/pdf.py new file mode 100644 index 0000000..852647e --- /dev/null +++ b/test/utils/generation/pdf.py @@ -0,0 +1,30 @@ +import tempfile +from operator import itemgetter + +import fpdf + +from image_prediction.info import Info + + +def add_image(pdf, image_metadata_pair, suffix="png"): + while fewer_pages_then_required(image_metadata_pair.metadata[Info.PAGE_IDX], pdf): + pdf.add_page() + + add_image_to_last_page(pdf, image_metadata_pair, suffix=suffix) + + +def fewer_pages_then_required(page_idx, pdf): + return page_idx > pdf.page - 1 + + +def pdf_stream(pdf: fpdf.fpdf.FPDF): + return pdf.output(dest="S").encode("latin1") + + +def add_image_to_last_page(pdf: fpdf.fpdf.FPDF, image_metadata_pair, suffix): + image, metadata = image_metadata_pair + x, y, w, h = itemgetter(Info.X1, Info.Y1, Info.WIDTH, Info.HEIGHT)(metadata) + + with tempfile.NamedTemporaryFile(suffix=f".{suffix}") as temp_image: + image.save(temp_image.name) + pdf.image(temp_image.name, x=x, y=y, w=w, h=h, type=suffix) diff --git a/test/utils/label.py b/test/utils/label.py new file mode 100644 index 0000000..b3fe2d9 --- /dev/null +++ b/test/utils/label.py @@ -0,0 +1,2 @@ +def map_labels(numeric_labels, classes): + return [classes[nl] for nl in numeric_labels] diff --git 
a/test/utils/metadata.py b/test/utils/metadata.py new file mode 100644 index 0000000..cf67451 --- /dev/null +++ b/test/utils/metadata.py @@ -0,0 +1,11 @@ +from image_prediction.info import Info + + +def get_base_position_metadata(width, height, page_width, page_height): + return { + Info.WIDTH: width, + Info.HEIGHT: height, + Info.PAGE_IDX: 0, + Info.PAGE_WIDTH: page_width, + Info.PAGE_HEIGHT: page_height, + } diff --git a/test/utils/stitching.py b/test/utils/stitching.py new file mode 100644 index 0000000..46f40a2 --- /dev/null +++ b/test/utils/stitching.py @@ -0,0 +1,80 @@ +import random +from copy import deepcopy +from itertools import chain + +from funcy import rpartial, juxt + +from image_prediction.stitching.split_mapper import SplitMapper, HorizontalSplitMapper, VerticalSplitMapper +from image_prediction.stitching.utils import validate_box + + +class BoxSplitter: + def __init__(self, noise=None): + self.__steps = None + self.__noise = (0, 0) if not noise else noise + + def split_box(self, box, steps=5): + self.__steps = steps + return self.__split_recursively(box, 0) + + def __split_recursively(self, box, step): + return self.__split_and_recurse(box, step) if self.__steps_left(step) else self.__base_case(box) + + def __steps_left(self, step): + return step < self.__steps + + @staticmethod + def __base_case(box): + return [box] + + def __split_and_recurse(self, box, step): + new_boxes = self.__random_split(box) + new_boxes_per_branch = self.__tree_recurse(new_boxes, step + 1) + return chain.from_iterable(new_boxes_per_branch) + + def __random_split(self, box): + splitter = random.choice([self.__split_horizontal, self.__split_vertical]) + new_boxes = splitter(box) + return new_boxes + + def __tree_recurse(self, boxes, step): + return map(rpartial(self.__split_recursively, step + 1), boxes) + + def __split_horizontal(self, box): + return self.__split_if_large_enough(HorizontalSplitMapper(box)) + + def __split_vertical(self, box): + return 
self.__split_if_large_enough(VerticalSplitMapper(box)) + + def __split_if_large_enough(self, wrapped_box: SplitMapper): + return ( + self.__get_child_boxes(wrapped_box) + if self.__large_enough(wrapped_box) + else self.__base_case(wrapped_box.wrapped) + ) + + def noise(self): + return int(round(random.uniform(*self.__noise))) + + @staticmethod + def __large_enough(wrapped_box: SplitMapper): + return wrapped_box.dim >= 10 + + def __get_child_boxes(self, wrapped_box: SplitMapper): + + split_len = random.randint(5, wrapped_box.dim - 5) + split_point = wrapped_box.c1 + split_len + + box_left, box_right = juxt(deepcopy, deepcopy)(wrapped_box) + + noise = -self.noise() + box_left.dim = split_len + noise + box_right.dim = wrapped_box.dim - split_len + + box_left.c2 = split_point + noise + box_right.c1 = split_point + + validate_box(box_left.wrapped) + validate_box(box_right.wrapped) + + return box_left.wrapped, box_right.wrapped