diff --git a/config.yaml b/config.yaml index e2fd18b..6a6111a 100644 --- a/config.yaml +++ b/config.yaml @@ -18,10 +18,6 @@ filters: image_to_page_quotient: # Image size to page size ratio (ratio of geometric means of areas) min: $MIN_REL_IMAGE_SIZE|0.05 # Minimum permissible max: $MAX_REL_IMAGE_SIZE|0.75 # Maximum permissible - # Fixme: Temporary solution, delete - customized: # Customized settings per class (RED-5202) - max: - signature: $MAX_REL_SIGNATURE_SIZE|0.4 image_width_to_height_quotient: # Image width to height ratio min: $MIN_IMAGE_FORMAT|0.1 # Minimum permissible diff --git a/image_prediction/transformer/transformers/response.py b/image_prediction/transformer/transformers/response.py index dc647f8..a304603 100644 --- a/image_prediction/transformer/transformers/response.py +++ b/image_prediction/transformer/transformers/response.py @@ -4,11 +4,12 @@ import os from functools import lru_cache from operator import itemgetter +from funcy import filter, juxt, first, rest, compose + from image_prediction.config import CONFIG from image_prediction.exceptions import ParsingError from image_prediction.transformer.transformer import Transformer from image_prediction.utils import get_logger -from funcy import filter, juxt, first, rest logger = get_logger() @@ -19,6 +20,37 @@ class ResponseTransformer(Transformer): return build_image_info(data) +def get_class_specific_min_image_to_page_quotient(label): + return get_class_specific_quotient("REL_IMAGE_SIZE", label, "min", CONFIG.filters.image_to_page_quotient.min) + + +def get_class_specific_max_image_to_page_quotient(label): + return get_class_specific_quotient("REL_IMAGE_SIZE", label, "max", CONFIG.filters.image_to_page_quotient.max) + + +def get_class_specific_min_image_width_to_height_quotient(label): + return get_class_specific_quotient("IMAGE_FORMAT", label, "min", CONFIG.filters.image_width_to_height_quotient.min) + + +def get_class_specific_max_image_width_to_height_quotient(label): + return get_class_specific_quotient("IMAGE_FORMAT", label, "max", CONFIG.filters.image_width_to_height_quotient.max) + + +def get_class_specific_min_classification_confidence(label): + return get_class_specific_quotient("CONFIDENCE", label, "min", CONFIG.filters.min_confidence) + + +def get_class_specific_quotient(prefix, label, bound, fallback_value): + def fallback(): + logger.warning(f"Failed to resolve {bound} {prefix.lower().replace('_', '-')} value for class '{label}'.") + return fallback_value + + assert bound in ["min", "max"] + + threshold_map = parse_env_var(prefix) or {} + return threshold_map.get(label, {}).get(bound) or fallback() + + def build_image_info(data: dict) -> dict: def compute_geometric_quotient(): page_area_sqrt = math.sqrt(abs(page_width * page_height)) @@ -29,23 +61,29 @@ def build_image_info(data: dict) -> dict: "page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height", "alpha" )(data) - quotient = round(compute_geometric_quotient(), 4) - - min_image_to_page_quotient_breached = bool(quotient < CONFIG.filters.image_to_page_quotient.min) - max_image_to_page_quotient_breached = __is_max_image_to_page_quotient_breached( - quotient, data["classification"]["label"] - ) - min_image_width_to_height_quotient_breached = bool( - width / height < CONFIG.filters.image_width_to_height_quotient.min - ) - max_image_width_to_height_quotient_breached = bool( - width / height > CONFIG.filters.image_width_to_height_quotient.max - ) - classification = data["classification"] + label = classification["label"] representation = data["representation"] - min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence) + geometric_quotient = round(compute_geometric_quotient(), 4) + + min_image_to_page_quotient_breached = bool( + geometric_quotient < get_class_specific_min_image_to_page_quotient(label) + ) + max_image_to_page_quotient_breached = bool( + geometric_quotient > get_class_specific_max_image_to_page_quotient(label) + ) + + min_image_width_to_height_quotient_breached = bool( + width / height < get_class_specific_min_image_width_to_height_quotient(label) + ) + max_image_width_to_height_quotient_breached = bool( + width / height > get_class_specific_max_image_width_to_height_quotient(label) + ) + + min_confidence_breached = bool( + max(classification["probabilities"].values()) < get_class_specific_min_classification_confidence(label) + ) image_info = { "classification": classification, @@ -56,7 +94,7 @@ def build_image_info(data: dict) -> dict: "filters": { "geometry": { "imageSize": { - "quotient": quotient, + "quotient": geometric_quotient, "tooLarge": max_image_to_page_quotient_breached, "tooSmall": min_image_to_page_quotient_breached, }, @@ -82,28 +120,18 @@ def build_image_info(data: dict) -> dict: return image_info -def __is_max_image_to_page_quotient_breached(quotient: float, label: str) -> bool: - default_max_quotient = CONFIG.filters.image_to_page_quotient.max - customized_entries = CONFIG.filters.image_to_page_quotient.customized.max - max_quotient = customized_entries.get(label, default_max_quotient) - max_quotient = max_quotient if max_quotient else default_max_quotient - return bool(quotient > max_quotient) - - @lru_cache(maxsize=None) -def parse_env_var(prefix, fallback_value): - head, tail = juxt(first, rest)(filter(prefix, os.environ)) - if not head or tail: - logger.warning( - f"Found multiple candidates for environment variable with prefix '{prefix}', falling back to default value." - ) - return fallback_value +def parse_env_var(prefix): + head, tail = juxt(first, compose(list, rest))(filter(prefix, os.environ)) + if not head: + logger.warning(f"Found no environment variable with prefix '{prefix}'.") + elif tail: + logger.warning(f"Found multiple candidates for environment variable with prefix '{prefix}'.") else: try: return parse_env_var_value(os.environ[head]) except ParsingError as err: - logger.warning(f"{err}, falling back to default value.") - return fallback_value + logger.warning(err) def parse_env_var_value(env_var_value): diff --git a/test/unit_tests/response_transformer_test.py b/test/unit_tests/response_transformer_test.py index d8c58a7..fde12b8 100644 --- a/test/unit_tests/response_transformer_test.py +++ b/test/unit_tests/response_transformer_test.py @@ -1,21 +1,37 @@ +import json +import os + import pytest -from image_prediction.transformer.transformers.response import __is_max_image_to_page_quotient_breached +from image_prediction.transformer.transformers.response import ( + get_class_specific_min_image_to_page_quotient, + get_class_specific_max_image_to_page_quotient, + get_class_specific_max_image_width_to_height_quotient, + get_class_specific_min_image_width_to_height_quotient, + get_class_specific_min_classification_confidence, +) @pytest.fixture -def expected_is_breached(quotient, label): - if label == "signature" and quotient < 0.4: - return False - elif label == "signature" and quotient >= 0.4: - return True - elif quotient < 0.7: - return False - else: - return True +def label(): + return "signature" -@pytest.mark.parametrize("quotient", [0.1, 0.5]) -@pytest.mark.parametrize("label", ["logo", "signature"]) -def test_customized_per_label_ratio_breach(quotient, label, expected_is_breached): - assert __is_max_image_to_page_quotient_breached(quotient, label) == expected_is_breached +@pytest.fixture +def page_quotient_threshold_map(label): + # TODO: suboptimal, as actual environment is used + os.environ["REL_IMAGE_SIZE_MAP"] = json.dumps({label: {"min": 0.1, "max": 0.2}}) + os.environ["IMAGE_FORMAT_MAP"] = json.dumps({label: {"min": 0.5, "max": 0.4}}) + os.environ["CONFIDENCE"] = json.dumps({label: {"min": 0.8}}) + yield + for env_var in ("REL_IMAGE_SIZE_MAP", "IMAGE_FORMAT_MAP", "CONFIDENCE"): + os.environ.pop(env_var) + + +# FIXME: Runs correctly in isolation, but fails when other tests are run before +def test_read_environment_vars_for_thresholds(page_quotient_threshold_map, label): + assert get_class_specific_min_image_to_page_quotient(label) == 0.1 + assert get_class_specific_max_image_to_page_quotient(label) == 0.2 + assert get_class_specific_min_image_width_to_height_quotient(label) == 0.5 + assert get_class_specific_max_image_width_to_height_quotient(label) == 0.4 + assert get_class_specific_min_classification_confidence(label) == 0.8