From b3a58d6777633227b5fbe706c107e8c7e56e8539 Mon Sep 17 00:00:00 2001 From: Julius Unverfehrt Date: Fri, 30 Aug 2024 10:55:50 +0200 Subject: [PATCH] chore: add tests to ensure no regression happens ever again --- .../image_classification_test.py | 21 +++++++++++++++++++ test/regressions_tests/image_hashing_test.py | 18 ++++++++++++++++ test/unit_tests/encoder_test.py | 16 -------------- 3 files changed, 39 insertions(+), 16 deletions(-) create mode 100644 test/regressions_tests/image_classification_test.py create mode 100644 test/regressions_tests/image_hashing_test.py diff --git a/test/regressions_tests/image_classification_test.py b/test/regressions_tests/image_classification_test.py new file mode 100644 index 0000000..46c1265 --- /dev/null +++ b/test/regressions_tests/image_classification_test.py @@ -0,0 +1,21 @@ +from pathlib import Path + +from funcy import first + +from image_prediction.config import CONFIG +from image_prediction.pipeline import load_pipeline + + +def test_image_classification_does_not_regress(): + """See RED-9948: the predictions unexpectedly changed for some images. In the end the issue is the tensorflow + version. We ensure that the prediction of the image with the hash FA30F080F0C031CE17E8CF237 is inconclusive, + and that the flag all_passed is false.""" + pdf_path = Path(__file__).parents[1] / "data" / "RED-9948" / "SYNGENTA_EFSA_sanitisation_GFL_v2.pdf" + pdf_bytes = pdf_path.read_bytes() + + pipeline = load_pipeline(verbose=True, batch_size=CONFIG.service.batch_size) + predictions = list(pipeline(pdf_bytes)) + predictions = first([x for x in predictions if x["representation"] == "FA30F080F0C031CE17E8CF237"]) + + assert predictions["filters"]["allPassed"] is False + assert predictions["filters"]["probability"]["unconfident"] is True diff --git a/test/regressions_tests/image_hashing_test.py b/test/regressions_tests/image_hashing_test.py new file mode 100644 index 0000000..b327352 --- /dev/null +++ b/test/regressions_tests/image_hashing_test.py @@ -0,0 +1,18 @@ +from pathlib import Path + +from image_prediction.encoder.encoders.hash_encoder import HashEncoder +from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor + + +def test_all_hashes_have_length_of_twentyfive(): + """See RED-3814: all hashes should have 25 characters.""" + pdf_path = Path(__file__).parents[1] / "data" / "RED-3814" / "similarImages2.pdf" + pdf_bytes = pdf_path.read_bytes() + image_extractor = ParsablePDFImageExtractor() + image_metadata_pairs = list(image_extractor.extract(pdf_bytes)) + images = [image for image, _ in image_metadata_pairs] + + hash_encoder = HashEncoder() + hashes = list(hash_encoder.encode(images)) + + assert all(len(h) == 25 for h in hashes) diff --git a/test/unit_tests/encoder_test.py b/test/unit_tests/encoder_test.py index edabbba..c38834a 100644 --- a/test/unit_tests/encoder_test.py +++ b/test/unit_tests/encoder_test.py @@ -1,7 +1,6 @@ import random from itertools import starmap from operator import __eq__ -from pathlib import Path import pytest from PIL.Image import Image @@ -9,7 +8,6 @@ from funcy import compose, first from image_prediction.encoder.encoders.hash_encoder import HashEncoder from image_prediction.encoder.encoders.hash_encoder import hash_image -from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.utils.generic import lift @@ -32,17 +30,3 @@ def test_hash_encoder(images, hashed_images, base_patch_image): hashed_resized = compose(first, encoder, lift(resize))([base_patch_image]) hashed = hash_image(base_patch_image) assert close(hashed_resized, hashed) - - -def test_all_hashes_have_length_of_twentyfive(): - """See RED-3814: all hashes should have 25 characters.""" - pdf_path = Path(__file__).parents[1] / "data" / "RED-3814" / "similarImages2.pdf" - pdf_bytes = pdf_path.read_bytes() - image_extractor = ParsablePDFImageExtractor() - image_metadata_pairs = list(image_extractor.extract(pdf_bytes)) - images = [image for image, _ in image_metadata_pairs] - - hash_encoder = HashEncoder() - hashes = list(hash_encoder.encode(images)) - - assert all(len(h) == 25 for h in hashes)