diff --git a/image_prediction/transformer/transformers/response.py b/image_prediction/transformer/transformers/response.py index 7c2b326..378fe7b 100644 --- a/image_prediction/transformer/transformers/response.py +++ b/image_prediction/transformer/transformers/response.py @@ -20,45 +20,6 @@ class ResponseTransformer(Transformer): return build_image_info(data) -def get_class_specific_min_image_to_page_quotient(label, table=None): - return get_class_specific_value( - "REL_IMAGE_SIZE", label, "min", CONFIG.filters.image_to_page_quotient.min, table=table - ) - - -def get_class_specific_max_image_to_page_quotient(label, table=None): - return get_class_specific_value( - "REL_IMAGE_SIZE", label, "max", CONFIG.filters.image_to_page_quotient.max, table=table - ) - - -def get_class_specific_min_image_width_to_height_quotient(label, table=None): - return get_class_specific_value( - "IMAGE_FORMAT", label, "min", CONFIG.filters.image_width_to_height_quotient.min, table=table - ) - - -def get_class_specific_max_image_width_to_height_quotient(label, table=None): - return get_class_specific_value( - "IMAGE_FORMAT", label, "max", CONFIG.filters.image_width_to_height_quotient.max, table=table - ) - - -def get_class_specific_min_classification_confidence(label, table=None): - return get_class_specific_value("CONFIDENCE", label, "min", CONFIG.filters.min_confidence, table=table) - - -def get_class_specific_value(prefix, label, bound, fallback_value, table=None): - def fallback(): - logger.warning(f"Failed to resolve {bound} {prefix.lower().replace('_', '-')} value for class '{label}'.") - return fallback_value - - assert bound in ["min", "max"] - - threshold_map = parse_env_var(prefix, table=table) or {} - return threshold_map.get(label, {}).get(bound) or fallback() - - def build_image_info(data: dict) -> dict: def compute_geometric_quotient(): page_area_sqrt = math.sqrt(abs(page_width * page_height)) @@ -128,6 +89,50 @@ def build_image_info(data: dict) -> dict: return image_info +def get_class_specific_min_image_to_page_quotient(label, table=None): + return get_class_specific_value( + "REL_IMAGE_SIZE", label, "min", CONFIG.filters.image_to_page_quotient.min, table=table + ) + + +def get_class_specific_max_image_to_page_quotient(label, table=None): + return get_class_specific_value( + "REL_IMAGE_SIZE", label, "max", CONFIG.filters.image_to_page_quotient.max, table=table + ) + + +def get_class_specific_min_image_width_to_height_quotient(label, table=None): + return get_class_specific_value( + "IMAGE_FORMAT", label, "min", CONFIG.filters.image_width_to_height_quotient.min, table=table + ) + + +def get_class_specific_max_image_width_to_height_quotient(label, table=None): + return get_class_specific_value( + "IMAGE_FORMAT", label, "max", CONFIG.filters.image_width_to_height_quotient.max, table=table + ) + + +def get_class_specific_min_classification_confidence(label, table=None): + return get_class_specific_value("CONFIDENCE", label, "min", CONFIG.filters.min_confidence, table=table) + + +def get_class_specific_value(prefix, label, bound, fallback_value, table=None): + def fallback(): + return fallback_value + + def success(): + threshold_map = parse_env_var(prefix, table=table) or {} + value = threshold_map.get(label, {}).get(bound) + if value: + logger.debug(f"Using class '{label}' specific {bound} {prefix.lower().replace('_', '-')} value.") + return value + + assert bound in ["min", "max"] + + return success() or fallback() + + @lru_cache(maxsize=None) def parse_env_var(prefix, table=None): table = table or os.environ diff --git a/test/unit_tests/response_transformer_test.py b/test/unit_tests/response_transformer_test.py index ac8d822..c93e1af 100644 --- a/test/unit_tests/response_transformer_test.py +++ b/test/unit_tests/response_transformer_test.py @@ -21,8 +21,8 @@ def label(): def page_quotient_threshold_map(label): return frozendict( { - "REL_IMAGE_SIZE_MAP": json.dumps({label: {"min": 0.1, "max": 0.2}}), - "IMAGE_FORMAT_MAP": json.dumps({label: {"min": 0.5, "max": 0.4}}), + "REL_IMAGE_SIZE": json.dumps({label: {"min": 0.1, "max": 0.2}}), + "IMAGE_FORMAT": json.dumps({label: {"min": 0.5, "max": 0.4}}), "CONFIDENCE": json.dumps({label: {"min": 0.8}}), } )