Compare commits

...

8 Commits

6 changed files with 165 additions and 15 deletions

View File

@ -15,6 +15,12 @@ class DotIndexable:
def __init__(self, x): def __init__(self, x):
self.x = x self.x = x
def get(self, item, default=None):
try:
return _get_item_and_maybe_make_dotindexable(self.x, item)
except KeyError:
return default
def __getattr__(self, item): def __getattr__(self, item):
return _get_item_and_maybe_make_dotindexable(self.x, item) return _get_item_and_maybe_make_dotindexable(self.x, item)

View File

@ -32,3 +32,7 @@ class IntentionalTestException(RuntimeError):
class InvalidBox(Exception): class InvalidBox(Exception):
pass pass
class ParsingError(Exception):
pass

View File

@ -1,14 +1,17 @@
import atexit import atexit
import io import io
import json
import traceback
from functools import partial, lru_cache from functools import partial, lru_cache
from itertools import chain, starmap, filterfalse from itertools import chain, starmap, filterfalse
from operator import itemgetter from operator import itemgetter, truth
from typing import List from typing import List, Iterable, Iterator
import fitz import fitz
from PIL import Image from PIL import Image
from funcy import rcompose, merge, pluck, curry, compose from funcy import rcompose, merge, pluck, curry, compose
from image_prediction.formatter.formatters.enum import EnumFormatter
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.info import Info from image_prediction.info import Info
from image_prediction.stitching.stitching import stitch_pairs from image_prediction.stitching.stitching import stitch_pairs
@ -47,10 +50,28 @@ class ParsablePDFImageExtractor(ImageExtractor):
clear_caches() clear_caches()
image_metadata_pairs = starmap(ImageMetadataPair, filter(all, zip(images, metadata))) image_metadata_pairs = starmap(ImageMetadataPair, filter(all, zip(images, metadata)))
# TODO: In the future, consider to introduce an image validator as a pipeline component rather than doing the
# validation here. Invalid images can then be split into a different stream and joined with the intact images
# again for the formatting step.
image_metadata_pairs = self.__filter_valid_images(image_metadata_pairs)
image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance) image_metadata_pairs = stitch_pairs(list(image_metadata_pairs), tolerance=self.tolerance)
yield from image_metadata_pairs yield from image_metadata_pairs
@staticmethod
def __filter_valid_images(image_metadata_pairs: Iterable[ImageMetadataPair]) -> Iterator[ImageMetadataPair]:
def validate(image: Image.Image, metadata: dict):
try:
# TODO: stand-in heuristic for testing if image is valid => find cleaner solution (RED-5148)
image.resize((100, 100)).convert("RGB")
return ImageMetadataPair(image, metadata)
except (OSError, Exception) as err:
metadata = json.dumps(EnumFormatter()(metadata), indent=2)
logger.warning(f"Invalid image encountered. Image metadata:\n{metadata}\n\n{traceback.format_exc()}")
return None
return filter(truth, starmap(validate, image_metadata_pairs))
def extract_pages(doc, page_range): def extract_pages(doc, page_range):
page_range = range(page_range.start + 1, page_range.stop + 1) page_range = range(page_range.start + 1, page_range.stop + 1)

View File

@ -1,7 +1,13 @@
import json
import math import math
import os
from functools import lru_cache
from operator import itemgetter from operator import itemgetter
from funcy import juxt, first, rest, compose
from image_prediction.config import CONFIG from image_prediction.config import CONFIG
from image_prediction.exceptions import ParsingError
from image_prediction.transformer.transformer import Transformer from image_prediction.transformer.transformer import Transformer
from image_prediction.utils import get_logger from image_prediction.utils import get_logger
@ -14,6 +20,45 @@ class ResponseTransformer(Transformer):
return build_image_info(data) return build_image_info(data)
def get_class_specific_min_image_to_page_quotient(label, table=None):
return get_class_specific_value(
"REL_IMAGE_SIZE", label, "min", CONFIG.filters.image_to_page_quotient.min, table=table
)
def get_class_specific_max_image_to_page_quotient(label, table=None):
return get_class_specific_value(
"REL_IMAGE_SIZE", label, "max", CONFIG.filters.image_to_page_quotient.max, table=table
)
def get_class_specific_min_image_width_to_height_quotient(label, table=None):
return get_class_specific_value(
"IMAGE_FORMAT", label, "min", CONFIG.filters.image_width_to_height_quotient.min, table=table
)
def get_class_specific_max_image_width_to_height_quotient(label, table=None):
return get_class_specific_value(
"IMAGE_FORMAT", label, "max", CONFIG.filters.image_width_to_height_quotient.max, table=table
)
def get_class_specific_min_classification_confidence(label, table=None):
return get_class_specific_value("CONFIDENCE", label, "min", CONFIG.filters.min_confidence, table=table)
def get_class_specific_value(prefix, label, bound, fallback_value, table=None):
def fallback():
logger.warning(f"Failed to resolve {bound} {prefix.lower().replace('_', '-')} value for class '{label}'.")
return fallback_value
assert bound in ["min", "max"]
threshold_map = parse_env_var(prefix, table=table) or {}
return threshold_map.get(label, {}).get(bound) or fallback()
def build_image_info(data: dict) -> dict: def build_image_info(data: dict) -> dict:
def compute_geometric_quotient(): def compute_geometric_quotient():
page_area_sqrt = math.sqrt(abs(page_width * page_height)) page_area_sqrt = math.sqrt(abs(page_width * page_height))
@ -24,21 +69,29 @@ def build_image_info(data: dict) -> dict:
"page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height", "alpha" "page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height", "alpha"
)(data) )(data)
quotient = round(compute_geometric_quotient(), 4)
min_image_to_page_quotient_breached = bool(quotient < CONFIG.filters.image_to_page_quotient.min)
max_image_to_page_quotient_breached = bool(quotient > CONFIG.filters.image_to_page_quotient.max)
min_image_width_to_height_quotient_breached = bool(
width / height < CONFIG.filters.image_width_to_height_quotient.min
)
max_image_width_to_height_quotient_breached = bool(
width / height > CONFIG.filters.image_width_to_height_quotient.max
)
classification = data["classification"] classification = data["classification"]
label = classification["label"]
representation = data["representation"] representation = data["representation"]
min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence) geometric_quotient = round(compute_geometric_quotient(), 4)
min_image_to_page_quotient_breached = bool(
geometric_quotient < get_class_specific_min_image_to_page_quotient(label)
)
max_image_to_page_quotient_breached = bool(
geometric_quotient > get_class_specific_max_image_to_page_quotient(label)
)
min_image_width_to_height_quotient_breached = bool(
width / height < get_class_specific_min_image_width_to_height_quotient(label)
)
max_image_width_to_height_quotient_breached = bool(
width / height > get_class_specific_max_image_width_to_height_quotient(label)
)
min_confidence_breached = bool(
max(classification["probabilities"].values()) < get_class_specific_min_classification_confidence(label)
)
image_info = { image_info = {
"classification": classification, "classification": classification,
@ -49,7 +102,7 @@ def build_image_info(data: dict) -> dict:
"filters": { "filters": {
"geometry": { "geometry": {
"imageSize": { "imageSize": {
"quotient": quotient, "quotient": geometric_quotient,
"tooLarge": max_image_to_page_quotient_breached, "tooLarge": max_image_to_page_quotient_breached,
"tooSmall": min_image_to_page_quotient_breached, "tooSmall": min_image_to_page_quotient_breached,
}, },
@ -73,3 +126,23 @@ def build_image_info(data: dict) -> dict:
} }
return image_info return image_info
@lru_cache(maxsize=None)
def parse_env_var(prefix, table=None):
table = table or os.environ
head = first(filter(lambda s: s == prefix, table))
if not head:
logger.warning(f"Found no environment variable with prefix '{prefix}'.")
else:
try:
return parse_env_var_value(table[head])
except ParsingError as err:
logger.warning(err)
def parse_env_var_value(env_var_value):
try:
return json.loads(env_var_value)
except Exception as err:
raise ParsingError(f"Failed to parse {env_var_value}") from err

View File

@ -36,3 +36,13 @@ def test_dot_access_key_does_not_exists(config):
def test_access_key_does_not_exists(config): def test_access_key_does_not_exists(config):
assert config["B"] is None assert config["B"] is None
def test_get_method_returns_key_if_key_does_exist(config):
dot_indexable = config.D.E
assert dot_indexable.get("F", "default_value") is True
def test_get_method_returns_default_if_key_does_not_exist(config):
dot_indexable = config.D.E
assert dot_indexable.get("X", "default_value") == "default_value"

View File

@ -0,0 +1,36 @@
import json
import pytest
from frozendict import frozendict
from image_prediction.transformer.transformers.response import (
get_class_specific_min_image_to_page_quotient,
get_class_specific_max_image_to_page_quotient,
get_class_specific_max_image_width_to_height_quotient,
get_class_specific_min_image_width_to_height_quotient,
get_class_specific_min_classification_confidence,
)
@pytest.fixture
def label():
return "signature"
@pytest.fixture
def page_quotient_threshold_map(label):
return frozendict(
{
"REL_IMAGE_SIZE_MAP": json.dumps({label: {"min": 0.1, "max": 0.2}}),
"IMAGE_FORMAT_MAP": json.dumps({label: {"min": 0.5, "max": 0.4}}),
"CONFIDENCE": json.dumps({label: {"min": 0.8}}),
}
)
def test_read_environment_vars_for_thresholds(page_quotient_threshold_map, label):
assert get_class_specific_min_image_to_page_quotient(label, table=page_quotient_threshold_map) == 0.1
assert get_class_specific_max_image_to_page_quotient(label, table=page_quotient_threshold_map) == 0.2
assert get_class_specific_min_image_width_to_height_quotient(label, table=page_quotient_threshold_map) == 0.5
assert get_class_specific_max_image_width_to_height_quotient(label, table=page_quotient_threshold_map) == 0.4
assert get_class_specific_min_classification_confidence(label, table=page_quotient_threshold_map) == 0.8