Merge branch 'master' of ssh://git.iqser.com:2222/rr/image-prediction
This commit is contained in:
commit
4efb9c79b1
@ -3,6 +3,7 @@ from funcy import juxt
|
|||||||
from image_prediction.classifier.classifier import Classifier
|
from image_prediction.classifier.classifier import Classifier
|
||||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||||
from image_prediction.compositor.compositor import TransformerCompositor
|
from image_prediction.compositor.compositor import TransformerCompositor
|
||||||
|
from image_prediction.encoder.encoders.hash_encoder import HashEncoder
|
||||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||||
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
|
from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter
|
||||||
from image_prediction.formatter.formatters.enum import EnumFormatter
|
from image_prediction.formatter.formatters.enum import EnumFormatter
|
||||||
@ -36,3 +37,7 @@ def get_formatter():
|
|||||||
PDFNetCoordinateTransformer(), EnumFormatter(), ResponseTransformer(), Snake2CamelCaseKeyFormatter()
|
PDFNetCoordinateTransformer(), EnumFormatter(), ResponseTransformer(), Snake2CamelCaseKeyFormatter()
|
||||||
)
|
)
|
||||||
return formatter
|
return formatter
|
||||||
|
|
||||||
|
|
||||||
|
def get_encoder():
    """Build the default encoder, a ``HashEncoder`` instance."""
    encoder = HashEncoder()
    return encoder
|
||||||
|
|||||||
0
image_prediction/encoder/__init__.py
Normal file
0
image_prediction/encoder/__init__.py
Normal file
13
image_prediction/encoder/encoder.py
Normal file
13
image_prediction/encoder/encoder.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import abc
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from PIL.Image import Image
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(abc.ABC):
    """Abstract base for objects that turn a stream of images into a stream of encodings."""

    @abc.abstractmethod
    def encode(self, images: Iterable[Image]):
        """Produce the encodings for ``images``; concrete subclasses must override this."""
        raise NotImplementedError

    def __call__(self, images: Iterable[Image], batch_size=16):
        """Yield every item produced by ``encode(images)``.

        ``batch_size`` is accepted but not used by this implementation.
        """
        for encoding in self.encode(images):
            yield encoding
|
||||||
0
image_prediction/encoder/encoders/__init__.py
Normal file
0
image_prediction/encoder/encoders/__init__.py
Normal file
24
image_prediction/encoder/encoders/hash_encoder.py
Normal file
24
image_prediction/encoder/encoders/hash_encoder.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from image_prediction.encoder.encoder import Encoder
|
||||||
|
|
||||||
|
|
||||||
|
class HashEncoder(Encoder):
    """Encoder that represents each image by the value of ``hash_image``."""

    def encode(self, images: Iterable[Image.Image]):
        """Lazily yield ``hash_image(image)`` for every image, in input order."""
        for image in images:
            yield hash_image(image)

    def __call__(self, images: Iterable[Image.Image], batch_size=16):
        """Yield the encodings of ``images``; ``batch_size`` is accepted but not used here."""
        encodings = self.encode(images)
        yield from encodings
|
||||||
|
|
||||||
|
|
||||||
|
def hash_image(image: Image.Image):
    """Compute an average hash of ``image`` and return it as an uppercase hex string.

    The image is scaled to 10x10 and converted to mode "L"; every pixel is then mapped to
    one bit: "1" when its value is at least the mean pixel value, "0" otherwise. The bit
    string is read as a base-2 integer, rendered as hex without the "0x" prefix, reversed
    and uppercased.

    See: https://stackoverflow.com/a/49692185/3578468
    """
    # ``Image.ANTIALIAS`` was an alias of ``Image.LANCZOS`` and was removed in Pillow 10;
    # ``Image.LANCZOS`` selects the same filter and exists in old and new releases alike.
    thumbnail = image.resize((10, 10), Image.LANCZOS).convert("L")
    pixel_data = list(thumbnail.getdata())
    avg_pixel = sum(pixel_data) / len(pixel_data)
    bits = "".join("1" if px >= avg_pixel else "0" for px in pixel_data)
    # NOTE(review): ``hex`` drops leading zero digits, so the result is shorter than usual
    # when the first pixels are below the mean. Left unchanged to keep existing hashes valid.
    return hex(int(bits, 2))[2:][::-1].upper()
|
||||||
@ -2,11 +2,17 @@ import os
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from itertools import chain, tee
|
from itertools import chain, tee
|
||||||
|
|
||||||
from funcy import rcompose, first, compose, second, chunks, identity
|
from funcy import rcompose, first, compose, second, chunks, identity, rpartial
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from image_prediction.config import CONFIG
|
from image_prediction.config import CONFIG
|
||||||
from image_prediction.default_objects import get_formatter, get_mlflow_model_loader, get_image_classifier, get_extractor
|
from image_prediction.default_objects import (
|
||||||
|
get_formatter,
|
||||||
|
get_mlflow_model_loader,
|
||||||
|
get_image_classifier,
|
||||||
|
get_extractor,
|
||||||
|
get_encoder,
|
||||||
|
)
|
||||||
from image_prediction.locations import MLRUNS_DIR
|
from image_prediction.locations import MLRUNS_DIR
|
||||||
from image_prediction.utils.generic import lift, starlift
|
from image_prediction.utils.generic import lift, starlift
|
||||||
|
|
||||||
@ -37,20 +43,21 @@ class Pipeline:
|
|||||||
extract = get_extractor(**kwargs)
|
extract = get_extractor(**kwargs)
|
||||||
classifier = get_image_classifier(model_loader, model_identifier)
|
classifier = get_image_classifier(model_loader, model_identifier)
|
||||||
reformat = get_formatter()
|
reformat = get_formatter()
|
||||||
|
represent = get_encoder()
|
||||||
|
|
||||||
split = compose(star(parallel(*map(lift, (first, second)))), tee)
|
split = compose(star(parallel(*map(lift, (first, first, second)))), rpartial(tee, 3))
|
||||||
classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size))
|
classify = compose(chain.from_iterable, lift(classifier), partial(chunks, batch_size))
|
||||||
pairwise_apply = compose(star, parallel)
|
pairwise_apply = compose(star, parallel)
|
||||||
join = compose(starlift(lambda prd, mdt: {"classification": prd, **mdt}), star(zip))
|
join = compose(starlift(lambda prd, rpr, mdt: {"classification": prd, **mdt, "representation": rpr}), star(zip))
|
||||||
|
|
||||||
# +>--classify--v
|
# />--classify--\
|
||||||
# --extract-->--split--| |--join-->reformat
|
# --extract-->--split--+->--encode---->+--join-->reformat
|
||||||
# +>--identity--^
|
# \>--identity--/
|
||||||
|
|
||||||
self.pipe = rcompose(
|
self.pipe = rcompose(
|
||||||
extract, # ... image-metadata-pairs as a stream
|
extract, # ... image-metadata-pairs as a stream
|
||||||
split, # ... into an image stream and a metadata stream
|
split, # ... into an image stream and a metadata stream
|
||||||
pairwise_apply(classify, identity), # ... apply functions to the streams pairwise
|
pairwise_apply(classify, represent, identity), # ... apply functions to the streams pairwise
|
||||||
join, # ... the streams by zipping
|
join, # ... the streams by zipping
|
||||||
reformat, # ... the items
|
reformat, # ... the items
|
||||||
)
|
)
|
||||||
|
|||||||
@ -36,11 +36,13 @@ def build_image_info(data: dict) -> dict:
|
|||||||
)
|
)
|
||||||
|
|
||||||
classification = data["classification"]
|
classification = data["classification"]
|
||||||
|
representation = data["representation"]
|
||||||
|
|
||||||
min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence)
|
min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence)
|
||||||
|
|
||||||
image_info = {
|
image_info = {
|
||||||
"classification": classification,
|
"classification": classification,
|
||||||
|
"representation": representation,
|
||||||
"position": {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": data["page_idx"] + 1},
|
"position": {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": data["page_idx"] + 1},
|
||||||
"geometry": {"width": width, "height": height},
|
"geometry": {"width": width, "height": height},
|
||||||
"alpha": alpha,
|
"alpha": alpha,
|
||||||
|
|||||||
@ -17,6 +17,7 @@ pytest_plugins = [
|
|||||||
"test.fixtures.parameters",
|
"test.fixtures.parameters",
|
||||||
"test.fixtures.pdf",
|
"test.fixtures.pdf",
|
||||||
"test.fixtures.target",
|
"test.fixtures.target",
|
||||||
|
"test.unit_tests.image_stitching_test"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,7 @@
|
|||||||
"signature": 0.0
|
"signature": 0.0
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"representation": "FFFEF0C7033648170F3EFFFFF",
|
||||||
"position": {
|
"position": {
|
||||||
"x1": 321,
|
"x1": 321,
|
||||||
"x2": 515,
|
"x2": 515,
|
||||||
|
|||||||
10
test/fixtures/target.py
vendored
10
test/fixtures/target.py
vendored
@ -6,8 +6,9 @@ from operator import itemgetter
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from funcy import rcompose
|
from funcy import rcompose, lmap
|
||||||
|
|
||||||
|
from image_prediction.encoder.encoders.hash_encoder import hash_image
|
||||||
from image_prediction.exceptions import UnknownLabelFormat
|
from image_prediction.exceptions import UnknownLabelFormat
|
||||||
from image_prediction.label_mapper.mappers.probability import ProbabilityMapperKeys
|
from image_prediction.label_mapper.mappers.probability import ProbabilityMapperKeys
|
||||||
from image_prediction.locations import TEST_DATA_DIR
|
from image_prediction.locations import TEST_DATA_DIR
|
||||||
@ -54,7 +55,7 @@ def batch_of_expected_label_to_probability_mappings(batch_of_expected_probabilit
|
|||||||
return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: lbl2prob}
|
return {ProbabilityMapperKeys.LABEL: most_likely, ProbabilityMapperKeys.PROBABILITIES: lbl2prob}
|
||||||
|
|
||||||
rounder = rcompose(partial(np.round, decimals=4), float)
|
rounder = rcompose(partial(np.round, decimals=4), float)
|
||||||
return list(map(map_probabilities, batch_of_expected_probability_arrays))
|
return lmap(map_probabilities, batch_of_expected_probability_arrays)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -89,3 +90,8 @@ def expected_predictions_mapped_and_formatted(expected_predictions_mapped):
|
|||||||
def real_expected_service_response():
|
def real_expected_service_response():
|
||||||
with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json"), "r") as f:
|
with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json"), "r") as f:
|
||||||
yield json.load(f)
|
yield json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def hashed_images(images):
    """List of ``hash_image`` results for the ``images`` fixture, in the same order."""
    return [hash_image(image) for image in images]
|
||||||
|
|||||||
@ -77,7 +77,6 @@ def server_process(server, host_and_port, url):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("server_type", ["actual"])
|
@pytest.mark.parametrize("server_type", ["actual"])
|
||||||
@pytest.mark.skip()
|
|
||||||
def test_server_predict(url, real_pdf, real_expected_service_response):
|
def test_server_predict(url, real_pdf, real_expected_service_response):
|
||||||
response = requests.post(f"{url}/predict", data=real_pdf)
|
response = requests.post(f"{url}/predict", data=real_pdf)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|||||||
0
test/unit_tests/__init__.py
Normal file
0
test/unit_tests/__init__.py
Normal file
31
test/unit_tests/encoder_test.py
Normal file
31
test/unit_tests/encoder_test.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import random
|
||||||
|
from itertools import starmap
|
||||||
|
from operator import __eq__
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from PIL.Image import Image
|
||||||
|
from funcy import compose, first
|
||||||
|
|
||||||
|
from image_prediction.encoder.encoders.hash_encoder import HashEncoder, hash_image
|
||||||
|
from image_prediction.utils.generic import lift
|
||||||
|
|
||||||
|
|
||||||
|
def resize(image: Image):
    """Resize ``image`` by a random factor drawn with ``random.uniform(0.3, 2)``.

    Both sides are scaled by the same factor and truncated to integers; the result of
    ``image.resize`` for that size is returned.
    """
    factor = random.uniform(0.3, 2)
    # Pass a concrete tuple rather than a one-shot ``map`` iterator, and keep each side at
    # least 1 so that small inputs cannot yield a zero-sized target.
    new_size = tuple(max(1, int(side * factor)) for side in image.size)
    return image.resize(new_size)
|
||||||
|
|
||||||
|
|
||||||
|
def close(a: str, b: str):
    """Tell whether at least 75% of the positions of ``a`` and ``b`` hold equal characters.

    Both strings must have the same length; this is enforced with an ``assert``.
    """
    assert len(a) == len(b)
    matching = sum(left == right for left, right in zip(a, b))
    ratio = matching / len(a)
    return ratio >= 0.75
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="Stochastic test, may fail some amount of the time.")
def test_hash_encoder(images, hashed_images, base_patch_image):
    """Check the encoder against ``hash_image`` and against a randomly resized image."""
    encoder = HashEncoder()

    # The encoder must produce exactly the per-image hashes supplied by the fixture.
    encoded = list(encoder(images))
    assert encoded == hashed_images

    # Hash of a randomly resized copy versus hash of the untouched image: the two are
    # only required to be "close", which is why the test is marked as stochastic.
    encode_first_resized = compose(first, encoder, lift(resize))
    hashed_resized = encode_first_resized([base_patch_image])
    reference_hash = hash_image(base_patch_image)
    assert close(hashed_resized, reference_hash)
|
||||||
Loading…
x
Reference in New Issue
Block a user