import json
import math
import os
from functools import lru_cache
from operator import itemgetter

from funcy import first

from image_prediction.config import CONFIG
from image_prediction.exceptions import ParsingError
from image_prediction.transformer.transformer import Transformer
from image_prediction.utils import get_logger

logger = get_logger()


class ResponseTransformer(Transformer):
    """Transformer that turns a raw prediction record into the image-info response."""

    def transform(self, data):
        """Log a debug message and return ``build_image_info(data)``."""
        logger.debug("ResponseTransformer.transform")
        return build_image_info(data)


def get_class_specific_min_image_to_page_quotient(label, table=None):
    """Return the "min" value of ``REL_IMAGE_SIZE`` for ``label``.

    Falls back to ``CONFIG.filters.image_to_page_quotient.min``.
    """
    return get_class_specific_value(
        "REL_IMAGE_SIZE", label, "min", CONFIG.filters.image_to_page_quotient.min, table=table
    )


def get_class_specific_max_image_to_page_quotient(label, table=None):
    """Return the "max" value of ``REL_IMAGE_SIZE`` for ``label``.

    Falls back to ``CONFIG.filters.image_to_page_quotient.max``.
    """
    return get_class_specific_value(
        "REL_IMAGE_SIZE", label, "max", CONFIG.filters.image_to_page_quotient.max, table=table
    )


def get_class_specific_min_image_width_to_height_quotient(label, table=None):
    """Return the "min" value of ``IMAGE_FORMAT`` for ``label``.

    Falls back to ``CONFIG.filters.image_width_to_height_quotient.min``.
    """
    return get_class_specific_value(
        "IMAGE_FORMAT", label, "min", CONFIG.filters.image_width_to_height_quotient.min, table=table
    )


def get_class_specific_max_image_width_to_height_quotient(label, table=None):
    """Return the "max" value of ``IMAGE_FORMAT`` for ``label``.

    Falls back to ``CONFIG.filters.image_width_to_height_quotient.max``.
    """
    return get_class_specific_value(
        "IMAGE_FORMAT", label, "max", CONFIG.filters.image_width_to_height_quotient.max, table=table
    )


def get_class_specific_min_classification_confidence(label, table=None):
    """Return the "min" value of ``CONFIDENCE`` for ``label``.

    Falls back to ``CONFIG.filters.min_confidence``.
    """
    return get_class_specific_value("CONFIDENCE", label, "min", CONFIG.filters.min_confidence, table=table)


def _is_usable_threshold(value):
    """Return True if ``value`` is an int or float (but not a bool)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def get_class_specific_value(prefix, label, bound, fallback_value, table=None):
    """Resolve a per-class threshold, falling back to a default value.

    The variable named ``prefix`` is looked up via ``parse_env_var`` and is
    expected to hold a JSON object of the form
    ``{"<label>": {"min": <number>, "max": <number>}}``.

    Args:
        prefix: Name of the variable holding the JSON threshold map.
        label: Class label whose threshold is requested.
        bound: Either ``"min"`` or ``"max"``.
        fallback_value: Value returned when no numeric threshold is found.
        table: Optional mapping to read the variable from; when falsy,
            ``os.environ`` is used.

    Returns:
        The configured numeric threshold, or ``fallback_value`` (after logging
        a warning) when the variable, label or bound is missing or the stored
        value is not a number.
    """
    assert bound in ["min", "max"]

    threshold_map = parse_env_var(prefix, table=table) or {}
    # `or {}` also covers a label that is explicitly mapped to null.
    value = (threshold_map.get(label) or {}).get(bound)

    # A configured threshold of 0 is valid, so test the type rather than the
    # truthiness of the value.
    if _is_usable_threshold(value):
        return value

    logger.warning(f"Failed to resolve {bound} {prefix.lower().replace('_', '-')} value for class '{label}'.")
    return fallback_value


def build_image_info(data: dict) -> dict:
    """Build the image-info response for one classified image.

    Args:
        data: Mapping that must contain the keys ``page_width``,
            ``page_height``, ``x1``, ``x2``, ``y1``, ``y2``, ``width``,
            ``height``, ``alpha``, ``page_idx``, ``representation`` and
            ``classification`` (itself holding ``label`` and
            ``probabilities``).

    Returns:
        A dict with the keys ``classification``, ``representation``,
        ``position``, ``geometry``, ``alpha`` and ``filters``. ``filters``
        reports, per threshold, whether it was breached, and ``allPassed`` is
        True only when none was.

    Raises:
        KeyError: If a required key is missing from ``data``.
        ZeroDivisionError: If ``height`` is 0 or the page area is 0.
    """
    page_width, page_height, x1, x2, y1, y2, width, height, alpha = itemgetter(
        "page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height", "alpha"
    )(data)

    classification = data["classification"]
    label = classification["label"]
    representation = data["representation"]

    # Ratio of the square roots of the image and page areas.
    page_area_sqrt = math.sqrt(abs(page_width * page_height))
    image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1))
    geometric_quotient = round(image_area_sqrt / page_area_sqrt, 4)

    # The unrounded ratio is compared against the thresholds; only the
    # reported value is rounded.
    width_to_height_quotient = width / height

    # bool() ensures plain Python booleans in the response regardless of the
    # numeric types found in `data`.
    min_image_to_page_quotient_breached = bool(
        geometric_quotient < get_class_specific_min_image_to_page_quotient(label)
    )
    max_image_to_page_quotient_breached = bool(
        geometric_quotient > get_class_specific_max_image_to_page_quotient(label)
    )
    min_image_width_to_height_quotient_breached = bool(
        width_to_height_quotient < get_class_specific_min_image_width_to_height_quotient(label)
    )
    max_image_width_to_height_quotient_breached = bool(
        width_to_height_quotient > get_class_specific_max_image_width_to_height_quotient(label)
    )
    min_confidence_breached = bool(
        max(classification["probabilities"].values()) < get_class_specific_min_classification_confidence(label)
    )

    breaches = [
        max_image_to_page_quotient_breached,
        min_image_to_page_quotient_breached,
        min_image_width_to_height_quotient_breached,
        max_image_width_to_height_quotient_breached,
        min_confidence_breached,
    ]

    return {
        "classification": classification,
        "representation": representation,
        "position": {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": data["page_idx"] + 1},
        "geometry": {"width": width, "height": height},
        "alpha": alpha,
        "filters": {
            "geometry": {
                "imageSize": {
                    "quotient": geometric_quotient,
                    "tooLarge": max_image_to_page_quotient_breached,
                    "tooSmall": min_image_to_page_quotient_breached,
                },
                "imageFormat": {
                    "quotient": round(width_to_height_quotient, 4),
                    "tooTall": min_image_width_to_height_quotient_breached,
                    "tooWide": max_image_width_to_height_quotient_breached,
                },
            },
            "probability": {"unconfident": min_confidence_breached},
            "allPassed": not any(breaches),
        },
    }


def _lookup_and_parse(prefix, table):
    """Return the parsed JSON stored under ``prefix`` in ``table``, or None.

    None is returned when ``prefix`` is empty or absent from ``table``, and
    also when the stored value cannot be parsed (a warning is logged then).
    """
    if not prefix or prefix not in table:
        return None
    try:
        return parse_env_var_value(table[prefix])
    except ParsingError as err:
        logger.warning(err)
        return None


@lru_cache(maxsize=None)
def _parse_from_environment(prefix):
    """Cached lookup of ``prefix`` in ``os.environ``."""
    return _lookup_and_parse(prefix, os.environ)


def parse_env_var(prefix, table=None):
    """Read the variable named ``prefix`` and parse its value as JSON.

    Args:
        prefix: Exact name of the variable to read.
        table: Optional mapping to read from. When falsy, ``os.environ`` is
            used and the result is cached per ``prefix``.

    Returns:
        The parsed JSON value, or None if the variable is not set or its
        value could not be parsed.
    """
    if not table:
        return _parse_from_environment(prefix)
    # Explicit tables are not cached: a plain dict is unhashable and cannot
    # be part of an lru_cache key.
    return _lookup_and_parse(prefix, table)


# Keep the cache-control handles that the lru_cache-decorated function exposed.
parse_env_var.cache_clear = _parse_from_environment.cache_clear
parse_env_var.cache_info = _parse_from_environment.cache_info


def parse_env_var_value(env_var_value):
    """Parse ``env_var_value`` as JSON.

    Raises:
        ParsingError: If ``json.loads`` fails for any reason; the original
            exception is chained as the cause.
    """
    try:
        return json.loads(env_var_value)
    except Exception as err:
        raise ParsingError(f"Failed to parse {env_var_value}") from err