Added hash image encoder that produces image representations by hashing

This commit is contained in:
Matthias Bisping 2022-05-12 11:15:48 +02:00
parent 84a8b0a290
commit 41d94199ed
12 changed files with 97 additions and 7 deletions

View File

@ -3,6 +3,7 @@ from funcy import juxt
from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.compositor.compositor import TransformerCompositor
from image_prediction.encoder.encoders.hash_encoder import HashEncoder
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
from image_prediction.formatter.formatters.enum import EnumFormatter
@ -36,3 +37,7 @@ def get_formatter():
PDFNetCoordinateTransformer(), EnumFormatter(), ResponseTransformer(), Snake2CamelCaseKeyFormatter()
)
return formatter
def get_encoder():
    """Build and return the default image encoder, a `HashEncoder` instance."""
    encoder = HashEncoder()
    return encoder

View File

View File

@ -0,0 +1,13 @@
import abc
from typing import Iterable
from PIL.Image import Image
class Encoder(abc.ABC):
    """Abstract interface for objects that turn a stream of images into encodings."""

    @abc.abstractmethod
    def encode(self, images: Iterable[Image]):
        """Produce the encodings for `images`; concrete subclasses must override this."""
        raise NotImplementedError()

    def __call__(self, images: Iterable[Image], batch_size=16):
        """Yield, one by one, whatever `encode` produces for `images`.

        `batch_size` is part of the signature but is not used by this implementation.
        """
        encodings = self.encode(images)
        yield from encodings

View File

@ -0,0 +1,24 @@
from typing import Iterable
from PIL import Image
from image_prediction.encoder.encoder import Encoder
class HashEncoder(Encoder):
    """Encoder whose representation of an image is the string returned by `hash_image`."""

    def encode(self, images: Iterable[Image.Image]):
        """Lazily yield the `hash_image` result of every image, in input order."""
        for image in images:
            yield hash_image(image)

    def __call__(self, images: Iterable[Image.Image], batch_size=16):
        """Yield the encodings of `images`; `batch_size` is accepted but not used here."""
        encodings = self.encode(images)
        yield from encodings
def hash_image(image: Image.Image) -> str:
    """Compute a short, resize-tolerant hash string for an image.

    See: https://stackoverflow.com/a/49692185/3578468

    The image is shrunk to 10x10 pixels and converted to greyscale. Every pixel
    then contributes one bit: "1" if it is at least as bright as the mean pixel
    value, "0" otherwise. The bit string is rendered as hexadecimal, reversed and
    upper-cased.

    Args:
        image: The PIL image to hash.

    Returns:
        A string of 25 upper-case hexadecimal digits (100 bits).
    """
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter under
    # its current name and is available in older Pillow releases as well.
    image = image.resize((10, 10), Image.LANCZOS)
    image = image.convert("L")
    pixel_data = list(image.getdata())
    avg_pixel = sum(pixel_data) / len(pixel_data)
    bits = "".join("1" if px >= avg_pixel else "0" for px in pixel_data)
    # Going through int() drops leading zero bits, so pad back to a fixed number of
    # hex digits; otherwise images whose first pixels are dark get shorter hashes.
    n_hex_digits = -(-len(bits) // 4)  # ceiling division
    hex_representation = format(int(bits, 2), "X").zfill(n_hex_digits)[::-1]
    return hex_representation

View File

@ -2,11 +2,17 @@ import os
from functools import partial
from itertools import chain, tee
from funcy import rcompose, first, compose, second, chunks, identity
from funcy import rcompose, first, compose, second, chunks, identity, rpartial
from tqdm import tqdm
from image_prediction.config import CONFIG
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
from image_prediction.default_objects import (
get_formatter,
get_mlflow_model_loader,
get_image_classifier,
get_extractor,
get_encoder,
)
from image_prediction.locations import MLRUNS_DIR
from image_prediction.utils.generic import lift, starlift
@ -37,11 +43,12 @@ class Pipeline:
extract = get_extractor(**kwargs)
classifier = get_image_classifier(model_loader, model_identifier)
reformat = get_formatter()
represent = get_encoder()
split = compose(star(parallel(*map(lift, (first, second)))), tee)
split = compose(star(parallel(*map(lift, (first, first, second)))), rpartial(tee, 3))
classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size))
pairwise_apply = compose(star, parallel)
join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), star(zip))
join = compose(starlift(lambda prd, rpr, mdt: {"classification": prd, **mdt, "representation": rpr}), star(zip))
# +>--classify--v
# --extract-->--split--| |--join-->reformat
@ -50,7 +57,7 @@ class Pipeline:
self.pipe = rcompose(
extract, # ... image-metadata-pairs as a stream
split, # ... into an image stream and a metadata stream
pairwise_apply(classify, identity), # ... apply functions to the streams pairwise
pairwise_apply(classify, represent, identity), # ... apply functions to the streams pairwise
join, # ... the streams by zipping
reformat, # ... the items
)

View File

@ -36,11 +36,13 @@ def build_image_info(data: dict) -> dict:
)
classification = data["classification"]
representation = data["representation"]
min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence)
image_info = {
"classification": classification,
"representation": representation,
"position": {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": data["page_idx"] + 1},
"geometry": {"width": width, "height": height},
"alpha": alpha,

View File

@ -17,6 +17,7 @@ pytest_plugins = [
"test.fixtures.parameters",
"test.fixtures.pdf",
"test.fixtures.target",
"test.unit_tests.image_stitching_test"
]

View File

@ -9,6 +9,7 @@
"signature": 0.0
}
},
"representation": "FFFEF0C7033648170F3EFFFFF",
"position": {
"x1": 321,
"x2": 515,

View File

@ -6,8 +6,9 @@ from operator import itemgetter
import numpy as np
import pytest
from funcy import rcompose
from funcy import rcompose, lmap
from image_prediction.encoder.encoders.hash_encoder import hash_image
from image_prediction.exceptions import UnknownLabelFormat
from image_prediction.label_mapper.mappers.probability import ProbabilityMapperKeys
from image_prediction.locations import TEST_DATA_DIR
@ -54,7 +55,7 @@ def batch_of_expected_label_to_probability_mappings(batch_of_expected_probabilit
return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: lbl2prob}
rounder = rcompose(partial(np.round, decimals=4), float)
return list(map(map_probabilities, batch_of_expected_probability_arrays))
return lmap(map_probabilities, batch_of_expected_probability_arrays)
@pytest.fixture
@ -89,3 +90,8 @@ def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
def real_expected_service_response():
    """Yield the parsed content of the stored predictions JSON file in TEST_DATA_DIR.

    The file stays open until the consumer resumes the generator after the yield.
    """
    path = os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json")
    # JSON text is UTF-8; be explicit instead of relying on the platform's locale default.
    with open(path, "r", encoding="utf-8") as f:
        yield json.load(f)
@pytest.fixture
def hashed_images(images):
    """Provide the list of `hash_image` results for the `images` fixture, in order."""
    return [hash_image(image) for image in images]

View File

View File

@ -0,0 +1,31 @@
import random
from itertools import starmap
from operator import __eq__
import pytest
from PIL.Image import Image
from funcy import compose, first
from image_prediction.encoder.encoders.hash_encoder import HashEncoder, hash_image
from image_prediction.utils.generic import lift
def resize(image: Image):
    """Return `image` rescaled by a single random factor between 0.3 and 2.

    Both dimensions are multiplied by the same factor and truncated to integers.
    """
    factor = random.uniform(0.3, 2)
    scaled_size = tuple(int(side * factor) for side in image.size)
    return image.resize(scaled_size)
def close(a: str, b: str):
    """Tell whether two equally long strings agree in at least 75% of their positions.

    Raises:
        AssertionError: If the strings differ in length.
    """
    assert len(a) == len(b)
    matching_positions = sum(1 for left, right in zip(a, b) if left == right)
    agreement = matching_positions / len(a)
    return agreement >= 0.75
@pytest.mark.xfail(reason="Stochastic test, may fail some amount of the time.")
def test_hash_encoder(images, hashed_images, base_patch_image):
    """Check `HashEncoder` against `hash_image`, on a batch and on a randomly resized image."""
    encoder = HashEncoder()

    # Encoding the whole batch must match hashing each image individually.
    encoded = list(encoder(images))
    assert encoded == hashed_images

    # The hash of a rescaled image should stay close to the hash of the original.
    # The scale factor is random (see the xfail reason above).
    resized_images = lift(resize)([base_patch_image])
    hashed_resized = first(encoder(resized_images))
    hashed = hash_image(base_patch_image)
    assert close(hashed_resized, hashed)