chore(prediction filters): adapt class specific filter logic
This commit is contained in:
parent
a024ddfcf7
commit
150d0d64e5
@ -11,9 +11,9 @@ mlflow_run_id = "fabfb1f192c745369b88cab34471aba7"
|
||||
# The filter result values are reported in the service responses. For convenience the response to a request contains a
|
||||
# "filters.allPassed" field, which is set to false if any of the values returned by the filters did not meet its
|
||||
# specified required value.
|
||||
[filters]
|
||||
[filters.confidence]
|
||||
# Minimum permissible prediction confidence
|
||||
min_confidence = 0.5
|
||||
min = 0.5
|
||||
|
||||
# Image size to page size ratio (ratio of geometric means of areas)
|
||||
[filters.image_to_page_quotient]
|
||||
@ -25,4 +25,8 @@ max = 0.75
|
||||
min = 0.1
|
||||
max = 10
|
||||
|
||||
# put class specific filters here ['signature', 'formula', 'logo']
|
||||
[filters.overrides.signature.image_to_page_quotient]
|
||||
max = 0.4
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "image-classification-service"
|
||||
version = "2.0.0"
|
||||
version = "2.2.0"
|
||||
description = ""
|
||||
authors = ["Team Research <research@knecon.com>"]
|
||||
readme = "README.md"
|
||||
|
||||
@ -1,13 +1,8 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from dynaconf import Dynaconf
|
||||
from operator import itemgetter
|
||||
|
||||
from funcy import first
|
||||
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.exceptions import ParsingError
|
||||
from image_prediction.transformer.transformer import Transformer
|
||||
from image_prediction.utils import get_logger
|
||||
|
||||
@ -32,21 +27,22 @@ def build_image_info(data: dict) -> dict:
|
||||
geometric_quotient = round(compute_geometric_quotient(page_width, page_height, x2, x1, y2, y1), 4)
|
||||
|
||||
min_image_to_page_quotient_breached = bool(
|
||||
geometric_quotient < get_class_specific_min_image_to_page_quotient(label)
|
||||
geometric_quotient < get_class_specific_filter_value(label, CONFIG, "image_to_page_quotient", "min")
|
||||
)
|
||||
max_image_to_page_quotient_breached = bool(
|
||||
geometric_quotient > get_class_specific_max_image_to_page_quotient(label)
|
||||
geometric_quotient > get_class_specific_filter_value(label, CONFIG, "image_to_page_quotient", "max")
|
||||
)
|
||||
|
||||
min_image_width_to_height_quotient_breached = bool(
|
||||
width / height < get_class_specific_min_image_width_to_height_quotient(label)
|
||||
width / height < get_class_specific_filter_value(label, CONFIG, "image_width_to_height_quotient", "min")
|
||||
)
|
||||
max_image_width_to_height_quotient_breached = bool(
|
||||
width / height > get_class_specific_max_image_width_to_height_quotient(label)
|
||||
width / height > get_class_specific_filter_value(label, CONFIG, "image_width_to_height_quotient", "max")
|
||||
)
|
||||
|
||||
min_confidence_breached = bool(
|
||||
max(classification["probabilities"].values()) < get_class_specific_min_classification_confidence(label)
|
||||
max(classification["probabilities"].values())
|
||||
< get_class_specific_filter_value(label, CONFIG, "confidence", "min")
|
||||
)
|
||||
|
||||
image_info = {
|
||||
@ -90,65 +86,15 @@ def compute_geometric_quotient(page_width, page_height, x2, x1, y2, y1):
|
||||
return image_area_sqrt / page_area_sqrt
|
||||
|
||||
|
||||
def get_class_specific_min_image_to_page_quotient(label, table=None):
|
||||
return get_class_specific_value(
|
||||
"REL_IMAGE_SIZE", label, "min", CONFIG.filters.image_to_page_quotient.min, table=table
|
||||
)
|
||||
|
||||
|
||||
def get_class_specific_max_image_to_page_quotient(label, table=None):
|
||||
return get_class_specific_value(
|
||||
"REL_IMAGE_SIZE", label, "max", CONFIG.filters.image_to_page_quotient.max, table=table
|
||||
)
|
||||
|
||||
|
||||
def get_class_specific_min_image_width_to_height_quotient(label, table=None):
|
||||
return get_class_specific_value(
|
||||
"IMAGE_FORMAT", label, "min", CONFIG.filters.image_width_to_height_quotient.min, table=table
|
||||
)
|
||||
|
||||
|
||||
def get_class_specific_max_image_width_to_height_quotient(label, table=None):
|
||||
return get_class_specific_value(
|
||||
"IMAGE_FORMAT", label, "max", CONFIG.filters.image_width_to_height_quotient.max, table=table
|
||||
)
|
||||
|
||||
|
||||
def get_class_specific_min_classification_confidence(label, table=None):
|
||||
return get_class_specific_value("CONFIDENCE", label, "min", CONFIG.filters.min_confidence, table=table)
|
||||
|
||||
|
||||
def get_class_specific_value(prefix, label, bound, fallback_value, table=None):
|
||||
def fallback():
|
||||
return fallback_value
|
||||
|
||||
def success():
|
||||
threshold_map = parse_env_var(prefix, table=table) or {}
|
||||
value = threshold_map.get(label, {}).get(bound)
|
||||
if value:
|
||||
logger.debug(f"Using class '{label}' specific {bound} {prefix.lower().replace('_', '-')} value.")
|
||||
return value
|
||||
|
||||
assert bound in ["min", "max"]
|
||||
|
||||
return success() or fallback()
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def parse_env_var(prefix, table=None):
|
||||
table = table or os.environ
|
||||
head = first(filter(lambda s: s == prefix, table))
|
||||
if head:
|
||||
try:
|
||||
return parse_env_var_value(table[head])
|
||||
except ParsingError as err:
|
||||
logger.warning(err)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def parse_env_var_value(env_var_value):
|
||||
def get_class_specific_filter_value(label: str, settings: Dynaconf, filter_type: str, bound: str = None):
|
||||
try:
|
||||
return json.loads(env_var_value)
|
||||
except Exception as err:
|
||||
raise ParsingError(f"Failed to parse {env_var_value}") from err
|
||||
value = (
|
||||
settings.filters.overrides[label][filter_type][bound]
|
||||
if bound
|
||||
else settings.filters.overrides[label][filter_type]
|
||||
)
|
||||
logger.warning(f"Using {label=} specific {bound=} {filter_type=} {value=}.")
|
||||
except KeyError:
|
||||
value = settings.filters[filter_type][bound]
|
||||
|
||||
return value
|
||||
|
||||
@ -1,15 +1,7 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from frozendict import frozendict
|
||||
|
||||
from image_prediction.transformer.transformers.response import (
|
||||
get_class_specific_min_image_to_page_quotient,
|
||||
get_class_specific_max_image_to_page_quotient,
|
||||
get_class_specific_max_image_width_to_height_quotient,
|
||||
get_class_specific_min_image_width_to_height_quotient,
|
||||
get_class_specific_min_classification_confidence,
|
||||
)
|
||||
from image_prediction.config import CONFIG
|
||||
from image_prediction.transformer.transformers.response import get_class_specific_filter_value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -17,20 +9,9 @@ def label():
|
||||
return "signature"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def page_quotient_threshold_map(label):
|
||||
return frozendict(
|
||||
{
|
||||
"REL_IMAGE_SIZE": json.dumps({label: {"min": 0.1, "max": 0.2}}),
|
||||
"IMAGE_FORMAT": json.dumps({label: {"min": 0.5, "max": 0.4}}),
|
||||
"CONFIDENCE": json.dumps({label: {"min": 0.8}}),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_read_environment_vars_for_thresholds(page_quotient_threshold_map, label):
|
||||
assert get_class_specific_min_image_to_page_quotient(label, table=page_quotient_threshold_map) == 0.1
|
||||
assert get_class_specific_max_image_to_page_quotient(label, table=page_quotient_threshold_map) == 0.2
|
||||
assert get_class_specific_min_image_width_to_height_quotient(label, table=page_quotient_threshold_map) == 0.5
|
||||
assert get_class_specific_max_image_width_to_height_quotient(label, table=page_quotient_threshold_map) == 0.4
|
||||
assert get_class_specific_min_classification_confidence(label, table=page_quotient_threshold_map) == 0.8
|
||||
def test_read_environment_vars_for_thresholds(label):
|
||||
assert get_class_specific_filter_value(label, CONFIG, "image_to_page_quotient", "min") == 0.05
|
||||
assert get_class_specific_filter_value(label, CONFIG, "image_to_page_quotient", "max") == 0.4
|
||||
assert get_class_specific_filter_value(label, CONFIG, "image_width_to_height_quotient", "min") == 0.1
|
||||
assert get_class_specific_filter_value(label, CONFIG, "image_width_to_height_quotient", "max") == 10
|
||||
assert get_class_specific_filter_value(label, CONFIG, "confidence", "min") == 0.5
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user