diff --git a/image_prediction/default_objects.py b/image_prediction/default_objects.py
index 97c089e..1c40d56 100644
--- a/image_prediction/default_objects.py
+++ b/image_prediction/default_objects.py
@@ -3,6 +3,7 @@ from funcy import juxt
 from image_prediction.classifier.classifier import Classifier
 from image_prediction.classifier.image_classifier import ImageClassifier
 from image_prediction.compositor.compositor import TransformerCompositor
+from image_prediction.encoder.encoders.hash_encoder import HashEncoder
 from image_prediction.estimator.adapter.adapter import EstimatorAdapter
 from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
 from image_prediction.formatter.formatters.enum import EnumFormatter
@@ -36,3 +37,7 @@ def get_formatter():
         PDFNetCoordinateTransformer(), EnumFormatter(), ResponseTransformer(), Snake2CamelCaseKeyFormatter()
     )
     return formatter
+
+
+def get_encoder():
+    return HashEncoder()
diff --git a/image_prediction/encoder/__init__.py b/image_prediction/encoder/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/image_prediction/encoder/encoder.py b/image_prediction/encoder/encoder.py
new file mode 100644
index 0000000..92f5f63
--- /dev/null
+++ b/image_prediction/encoder/encoder.py
@@ -0,0 +1,13 @@
+import abc
+from typing import Iterable
+
+from PIL.Image import Image
+
+
+class Encoder(abc.ABC):
+    @abc.abstractmethod
+    def encode(self, images: Iterable[Image]):
+        raise NotImplementedError
+
+    def __call__(self, images: Iterable[Image], batch_size=16):
+        yield from self.encode(images)
diff --git a/image_prediction/encoder/encoders/__init__.py b/image_prediction/encoder/encoders/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/image_prediction/encoder/encoders/hash_encoder.py b/image_prediction/encoder/encoders/hash_encoder.py
new file mode 100644
index 0000000..c21bacd
--- /dev/null
+++ b/image_prediction/encoder/encoders/hash_encoder.py
@@ -0,0 +1,24 @@
+from typing import Iterable
+
+from PIL import Image
+
+from image_prediction.encoder.encoder import Encoder
+
+
+class HashEncoder(Encoder):
+    def encode(self, images: Iterable[Image.Image]):
+        yield from map(hash_image, images)
+
+    def __call__(self, images: Iterable[Image.Image], batch_size=16):
+        yield from self.encode(images)
+
+
+def hash_image(image: Image.Image):
+    """See: https://stackoverflow.com/a/49692185/3578468"""
+    image = image.resize((10, 10), Image.ANTIALIAS)
+    image = image.convert("L")
+    pixel_data = list(image.getdata())
+    avg_pixel = sum(pixel_data) / len(pixel_data)
+    bits = "".join(["1" if (px >= avg_pixel) else "0" for px in pixel_data])
+    hex_representation = str(hex(int(bits, 2)))[2:][::-1].upper()
+    return hex_representation
diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py
index 721cbb9..6d29ac7 100644
--- a/image_prediction/pipeline.py
+++ b/image_prediction/pipeline.py
@@ -2,11 +2,17 @@ import os
 from functools import partial
 from itertools import chain, tee
 
-from funcy import rcompose, first, compose, second, chunks, identity
+from funcy import rcompose, first, compose, second, chunks, identity, rpartial
 from tqdm import tqdm
 
 from image_prediction.config import CONFIG
-from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
+from image_prediction.default_objects import (
+    get_formatter,
+    get_mlflow_model_loader,
+    get_image_classifier,
+    get_extractor,
+    get_encoder,
+)
 from image_prediction.locations import MLRUNS_DIR
 from image_prediction.utils.generic import lift, starlift
 
@@ -37,20 +43,21 @@ class Pipeline:
         extract = get_extractor(**kwargs)
         classifier = get_image_classifier(model_loader, model_identifier)
         reformat = get_formatter()
+        represent = get_encoder()
 
-        split = compose(star(parallel(*map(lift, (first, second)))), tee)
+        split = compose(star(parallel(*map(lift, (first, first, second)))), rpartial(tee, 3))
         classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size))
         pairwise_apply = compose(star, parallel)
-        join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), star(zip))
+        join = compose(starlift(lambda prd, rpr, mdt: {"classification": prd, **mdt, "representation": rpr}), star(zip))
 
-        #                      +>--classify--v
-        # --extract-->--split--|             |--join-->reformat
-        #                      +>--identity--^
+        #                      />--classify--\
+        # --extract-->--split--+->--encode---->+--join-->reformat
+        #                      \>--identity--/
         self.pipe = rcompose(
            extract,  # ... image-metadata-pairs as a stream
            split,  # ... into an image stream and a metadata stream
-           pairwise_apply(classify, identity),  # ... apply functions to the streams pairwise
+           pairwise_apply(classify, represent, identity),  # ... apply functions to the streams pairwise
            join,  # ... the streams by zipping
            reformat,  # ... the items
         )
diff --git a/image_prediction/transformer/transformers/response.py b/image_prediction/transformer/transformers/response.py
index ca8ce99..3e35104 100644
--- a/image_prediction/transformer/transformers/response.py
+++ b/image_prediction/transformer/transformers/response.py
@@ -36,11 +36,13 @@ def build_image_info(data: dict) -> dict:
     )
 
     classification = data["classification"]
+    representation = data["representation"]
     min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence)
 
     image_info = {
         "classification": classification,
+        "representation": representation,
         "position": {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": data["page_idx"] + 1},
         "geometry": {"width": width, "height": height},
         "alpha": alpha,
diff --git a/test/conftest.py b/test/conftest.py
index ee2b3d5..65298b0 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -17,6 +17,7 @@ pytest_plugins = [
     "test.fixtures.parameters",
     "test.fixtures.pdf",
     "test.fixtures.target",
+    "test.unit_tests.image_stitching_test"
 ]
 
 
diff --git a/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json b/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json
index a2171bb..1a1b3f5 100644
--- a/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json
+++ b/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json
@@ -9,6 +9,7 @@
       "signature": 0.0
     }
   },
+  "representation": "FFFEF0C7033648170F3EFFFFF",
   "position": {
     "x1": 321,
     "x2": 515,
diff --git a/test/fixtures/target.py b/test/fixtures/target.py
index bcc0e75..23f23bd 100644
--- a/test/fixtures/target.py
+++ b/test/fixtures/target.py
@@ -6,8 +6,9 @@ from operator import itemgetter
 
 import numpy as np
 import pytest
-from funcy import rcompose
+from funcy import rcompose, lmap
 
+from image_prediction.encoder.encoders.hash_encoder import hash_image
 from image_prediction.exceptions import UnknownLabelFormat
 from image_prediction.label_mapper.mappers.probability import ProbabilityMapperKeys
 from image_prediction.locations import TEST_DATA_DIR
@@ -54,7 +55,7 @@ def batch_of_expected_label_to_probability_mappings(batch_of_expected_probabilit
         return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: lbl2prob}
 
     rounder = rcompose(partial(np.round, decimals=4), float)
-    return list(map(map_probabilities, batch_of_expected_probability_arrays))
+    return lmap(map_probabilities, batch_of_expected_probability_arrays)
 
 
@@ -89,3 +90,8 @@
 def real_expected_service_response():
     with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json"), "r") as f:
         yield json.load(f)
+
+
+@pytest.fixture
+def hashed_images(images):
+    return lmap(hash_image, images)
diff --git a/test/integration_tests/actual_server_test.py b/test/integration_tests/actual_server_test.py
index d19a6b6..9983acf 100644
--- a/test/integration_tests/actual_server_test.py
+++ b/test/integration_tests/actual_server_test.py
@@ -77,7 +77,6 @@ def server_process(server, host_and_port, url):
 
 
 @pytest.mark.parametrize("server_type", ["actual"])
-@pytest.mark.skip()
 def test_server_predict(url, real_pdf, real_expected_service_response):
     response = requests.post(f"{url}/predict", data=real_pdf)
     response.raise_for_status()
diff --git a/test/unit_tests/__init__.py b/test/unit_tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/unit_tests/encoder_test.py b/test/unit_tests/encoder_test.py
new file mode 100644
index 0000000..5102aca
--- /dev/null
+++ b/test/unit_tests/encoder_test.py
@@ -0,0 +1,31 @@
+import random
+from itertools import starmap
+from operator import __eq__
+
+import pytest
+from PIL.Image import Image
+from funcy import compose, first
+
+from image_prediction.encoder.encoders.hash_encoder import HashEncoder, hash_image
+from image_prediction.utils.generic import lift
+
+
+def resize(image: Image):
+    factor = random.uniform(0.3, 2)
+    new_size = map(lambda x: int(x * factor), image.size)
+    return image.resize(new_size)
+
+
+def close(a: str, b: str):
+    assert len(a) == len(b)
+    return sum(starmap(__eq__, zip(a, b))) / len(a) >= 0.75
+
+
+@pytest.mark.xfail(reason="Stochastic test, may fail some amount of the time.")
+def test_hash_encoder(images, hashed_images, base_patch_image):
+    encoder = HashEncoder()
+    assert list(encoder(images)) == hashed_images
+
+    hashed_resized = compose(first, encoder, lift(resize))([base_patch_image])
+    hashed = hash_image(base_patch_image)
+    assert close(hashed_resized, hashed)