Merge branch 'master' of ssh://git.iqser.com:2222/rr/image-prediction

This commit is contained in:
Matthias Bisping 2022-05-12 11:51:30 +02:00
commit 4efb9c79b1
13 changed files with 100 additions and 11 deletions

View File

@ -3,6 +3,7 @@ from funcy import juxt
from image_prediction.classifier.classifier import Classifier from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.compositor.compositor import TransformerCompositor from image_prediction.compositor.compositor import TransformerCompositor
from image_prediction.encoder.encoders.hash_encoder import HashEncoder
from image_prediction.estimator.adapter.adapter import EstimatorAdapter from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
from image_prediction.formatter.formatters.enum import EnumFormatter from image_prediction.formatter.formatters.enum import EnumFormatter
@ -36,3 +37,7 @@ def get_formatter():
PDFNetCoordinateTransformer(), EnumFormatter(), ResponseTransformer(), Snake2CamelCaseKeyFormatter() PDFNetCoordinateTransformer(), EnumFormatter(), ResponseTransformer(), Snake2CamelCaseKeyFormatter()
) )
return formatter return formatter
def get_encoder():
    """Build the default image encoder.

    Returns:
        A newly constructed ``HashEncoder``.
    """
    encoder = HashEncoder()
    return encoder

View File

View File

@ -0,0 +1,13 @@
import abc
from typing import Iterable
from PIL.Image import Image
class Encoder(abc.ABC):
    """Abstract base class for callables that turn a stream of images into encodings."""

    @abc.abstractmethod
    def encode(self, images: Iterable[Image]):
        """Encode ``images``; every concrete subclass has to override this method."""
        raise NotImplementedError

    def __call__(self, images: Iterable[Image], batch_size=16):
        """Lazily yield each item produced by ``encode(images)``.

        ``batch_size`` is part of the call signature but is not used here.
        """
        for encoding in self.encode(images):
            yield encoding

View File

@ -0,0 +1,24 @@
from typing import Iterable
from PIL import Image
from image_prediction.encoder.encoder import Encoder
class HashEncoder(Encoder):
    """Encoder that represents every image by the string ``hash_image`` returns for it."""

    def encode(self, images: Iterable[Image.Image]):
        """Lazily yield ``hash_image(image)`` for each image, preserving input order."""
        for image in images:
            yield hash_image(image)

    def __call__(self, images: Iterable[Image.Image], batch_size=16):
        """Lazily yield the encodings of ``images``.

        ``batch_size`` is part of the call signature but is not used here.
        """
        for encoded in self.encode(images):
            yield encoded
def hash_image(image: Image.Image):
    """Compute an average-threshold hash of an image as an uppercase hex string.

    The image is resized to 10x10 and converted to mode ``"L"`` (greyscale). Each of
    the 100 pixels then contributes one bit: ``1`` if its value is greater than or
    equal to the mean pixel value, ``0`` otherwise. The bit string is read as a
    base-2 integer, written in hexadecimal, reversed and upper-cased.

    See: https://stackoverflow.com/a/49692185/3578468

    Args:
        image: The image to hash.

    Returns:
        The hash as a string of uppercase hexadecimal digits. NOTE: leading zero
        bits vanish in the integer conversion, so the string may be shorter than
        the 25 characters that 100 bits would otherwise occupy.
    """
    # ``Image.ANTIALIAS`` was an alias of ``Image.LANCZOS``; it was deprecated in
    # Pillow 9.1 and removed in Pillow 10, so name the filter directly.
    image = image.resize((10, 10), Image.LANCZOS)
    image = image.convert("L")
    pixel_data = list(image.getdata())
    avg_pixel = sum(pixel_data) / len(pixel_data)
    bits = "".join("1" if px >= avg_pixel else "0" for px in pixel_data)
    # Uppercase hex without the "0x" prefix, then reversed (same output as before).
    hex_representation = format(int(bits, 2), "X")[::-1]
    return hex_representation

View File

@ -2,11 +2,17 @@ import os
from functools import partial from functools import partial
from itertools import chain, tee from itertools import chain, tee
from funcy import rcompose, first, compose, second, chunks, identity from funcy import rcompose, first, compose, second, chunks, identity, rpartial
from tqdm import tqdm from tqdm import tqdm
from image_prediction.config import CONFIG from image_prediction.config import CONFIG
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor from image_prediction.default_objects import (
get_formatter,
get_mlflow_model_loader,
get_image_classifier,
get_extractor,
get_encoder,
)
from image_prediction.locations import MLRUNS_DIR from image_prediction.locations import MLRUNS_DIR
from image_prediction.utils.generic import lift, starlift from image_prediction.utils.generic import lift, starlift
@ -37,20 +43,21 @@ class Pipeline:
extract = get_extractor(**kwargs) extract = get_extractor(**kwargs)
classifier = get_image_classifier(model_loader, model_identifier) classifier = get_image_classifier(model_loader, model_identifier)
reformat = get_formatter() reformat = get_formatter()
represent = get_encoder()
split = compose(star(parallel(*map(lift, (first, second)))), tee) split = compose(star(parallel(*map(lift, (first, first, second)))), rpartial(tee, 3))
classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size)) classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size))
pairwise_apply = compose(star, parallel) pairwise_apply = compose(star, parallel)
join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), star(zip)) join = compose(starlift(lambda prd, rpr, mdt: {"classification": prd, **mdt, "representation": rpr}), star(zip))
# +>--classify--v # />--classify--\
# --extract-->--split--| |--join-->reformat # --extract-->--split--+->--encode---->+--join-->reformat
# +>--identity--^ # \>--identity--/
self.pipe = rcompose( self.pipe = rcompose(
extract, # ... image-metadata-pairs as a stream extract, # ... image-metadata-pairs as a stream
split, # ... into an image stream and a metadata stream split, # ... into an image stream and a metadata stream
pairwise_apply(classify, identity), # ... apply functions to the streams pairwise pairwise_apply(classify, represent, identity), # ... apply functions to the streams pairwise
join, # ... the streams by zipping join, # ... the streams by zipping
reformat, # ... the items reformat, # ... the items
) )

View File

@ -36,11 +36,13 @@ def build_image_info(data: dict) -> dict:
) )
classification = data["classification"] classification = data["classification"]
representation = data["representation"]
min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence) min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence)
image_info = { image_info = {
"classification": classification, "classification": classification,
"representation": representation,
"position": {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": data["page_idx"] + 1}, "position": {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": data["page_idx"] + 1},
"geometry": {"width": width, "height": height}, "geometry": {"width": width, "height": height},
"alpha": alpha, "alpha": alpha,

View File

@ -17,6 +17,7 @@ pytest_plugins = [
"test.fixtures.parameters", "test.fixtures.parameters",
"test.fixtures.pdf", "test.fixtures.pdf",
"test.fixtures.target", "test.fixtures.target",
"test.unit_tests.image_stitching_test"
] ]

View File

@ -9,6 +9,7 @@
"signature": 0.0 "signature": 0.0
} }
}, },
"representation": "FFFEF0C7033648170F3EFFFFF",
"position": { "position": {
"x1": 321, "x1": 321,
"x2": 515, "x2": 515,

View File

@ -6,8 +6,9 @@ from operator import itemgetter
import numpy as np import numpy as np
import pytest import pytest
from funcy import rcompose from funcy import rcompose, lmap
from image_prediction.encoder.encoders.hash_encoder import hash_image
from image_prediction.exceptions import UnknownLabelFormat from image_prediction.exceptions import UnknownLabelFormat
from image_prediction.label_mapper.mappers.probability import ProbabilityMapperKeys from image_prediction.label_mapper.mappers.probability import ProbabilityMapperKeys
from image_prediction.locations import TEST_DATA_DIR from image_prediction.locations import TEST_DATA_DIR
@ -54,7 +55,7 @@ def batch_of_expected_label_to_probability_mappings(batch_of_expected_probabilit
return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: lbl2prob} return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: lbl2prob}
rounder = rcompose(partial(np.round, decimals=4), float) rounder = rcompose(partial(np.round, decimals=4), float)
return list(map(map_probabilities, batch_of_expected_probability_arrays)) return lmap(map_probabilities, batch_of_expected_probability_arrays)
@pytest.fixture @pytest.fixture
@ -89,3 +90,8 @@ def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
def real_expected_service_response(): def real_expected_service_response():
with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json"), "r") as f: with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json"), "r") as f:
yield json.load(f) yield json.load(f)
@pytest.fixture
def hashed_images(images):
    """List of ``hash_image`` results, one per item of the ``images`` fixture, in order."""
    return [hash_image(image) for image in images]

View File

@ -77,7 +77,6 @@ def server_process(server, host_and_port, url):
@pytest.mark.parametrize("server_type", ["actual"]) @pytest.mark.parametrize("server_type", ["actual"])
@pytest.mark.skip()
def test_server_predict(url, real_pdf, real_expected_service_response): def test_server_predict(url, real_pdf, real_expected_service_response):
response = requests.post(f"{url}/predict", data=real_pdf) response = requests.post(f"{url}/predict", data=real_pdf)
response.raise_for_status() response.raise_for_status()

View File

View File

@ -0,0 +1,31 @@
import random
from itertools import starmap
from operator import __eq__
import pytest
from PIL.Image import Image
from funcy import compose, first
from image_prediction.encoder.encoders.hash_encoder import HashEncoder, hash_image
from image_prediction.utils.generic import lift
def resize(image: Image):
    """Return ``image`` resized by a single random factor between 0.3 and 2.

    Width and height are multiplied by the same factor and truncated to integers.

    Args:
        image: The image to rescale.

    Returns:
        The result of ``image.resize`` for the scaled dimensions.
    """
    factor = random.uniform(0.3, 2)
    width, height = image.size
    # ``Image.resize`` documents its size argument as a 2-tuple, so build one
    # explicitly instead of handing over a lazy ``map`` object. Clamp to 1 because
    # truncation can yield 0 for very small images, which ``resize`` rejects.
    new_size = (max(1, int(width * factor)), max(1, int(height * factor)))
    return image.resize(new_size)
def close(a: str, b: str):
    """Tell whether two equal-length strings agree in at least 75% of their positions.

    Args:
        a: First string.
        b: Second string; must have the same length as ``a``.

    Returns:
        ``True`` if the share of positions holding equal characters is >= 0.75.

    Raises:
        AssertionError: If the two strings differ in length.
    """
    assert len(a) == len(b)
    matching_positions = sum(left == right for left, right in zip(a, b))
    return matching_positions / len(a) >= 0.75
@pytest.mark.xfail(reason="Stochastic test, may fail some amount of the time.")
def test_hash_encoder(images, hashed_images, base_patch_image):
    """Check ``HashEncoder`` against ``hash_image`` and against a randomly resized image."""
    encoder = HashEncoder()

    # Encoding the fixture images must reproduce the precomputed hashes exactly.
    assert list(encoder(images)) == hashed_images

    # Resize the base image by a random factor, encode it and take the first
    # (and only) encoding of that one-element stream.
    resized_images = lift(resize)([base_patch_image])
    hashed_resized = first(encoder(resized_images))

    # The hash of the rescaled image only has to be close to the original hash.
    hashed = hash_image(base_patch_image)
    assert close(hashed_resized, hashed)