From 5caa9807e2cefd4fd98ec22045351d3da80b4d7d Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Thu, 31 Mar 2022 19:01:32 +0200 Subject: [PATCH] added response formatter and pipeline test --- {image_prediction => deprecated}/response.py | 0 image_prediction/default_objects.py | 4 +- .../extractor_classifier.py | 2 +- .../formatter/formatters/response.py | 72 ++++++++++++++++++ image_prediction/locations.py | 2 + image_prediction/pipeline.py | 19 +++-- scripts/run_pipeline.py | 5 +- .../f2dc689ca794fccb8cd38b95f2bf6ba9.pdf | Bin ...ca794fccb8cd38b95f2bf6ba9_predictions.json | 42 ++++++++++ test/unit_tests/conftest.py | 4 +- test/unit_tests/extractor_classifier_test.py | 2 +- test/unit_tests/pipeline_test.py | 18 +++++ 12 files changed, 156 insertions(+), 14 deletions(-) rename {image_prediction => deprecated}/response.py (100%) create mode 100644 image_prediction/formatter/formatters/response.py rename test/{test_data => data}/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf (100%) create mode 100644 test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json create mode 100644 test/unit_tests/pipeline_test.py diff --git a/image_prediction/response.py b/deprecated/response.py similarity index 100% rename from image_prediction/response.py rename to deprecated/response.py diff --git a/image_prediction/default_objects.py b/image_prediction/default_objects.py index d968fef..d138d5e 100644 --- a/image_prediction/default_objects.py +++ b/image_prediction/default_objects.py @@ -11,10 +11,10 @@ def get_mlflow_model_loader(mlruns_dir): return model_loader -def load_pipeline(): +def load_pipeline(**kwargs): model_loader = get_mlflow_model_loader(MLRUNS_DIR) model_identifier = CONFIG.service.run_id - pipeline = Pipeline(model_loader, model_identifier) + pipeline = Pipeline(model_loader, model_identifier, **kwargs) return pipeline diff --git a/image_prediction/extractor_classifier/extractor_classifier.py b/image_prediction/extractor_classifier/extractor_classifier.py index 1206756..b08330a 100644 --- a/image_prediction/extractor_classifier/extractor_classifier.py +++ b/image_prediction/extractor_classifier/extractor_classifier.py @@ -18,7 +18,7 @@ class ExtractorClassifier: images, metadata = zip(*batch) predictions = self.classifier(images) - responses = ({"prediction": prd, **mdt} for prd, mdt in zip(predictions, metadata)) + responses = ({"classification": prd, **mdt} for prd, mdt in zip(predictions, metadata)) return responses def __call__(self, obj) -> Iterable[ImageMetadataPair]: diff --git a/image_prediction/formatter/formatters/response.py b/image_prediction/formatter/formatters/response.py new file mode 100644 index 0000000..b345101 --- /dev/null +++ b/image_prediction/formatter/formatters/response.py @@ -0,0 +1,72 @@ +import math +from operator import itemgetter + +from image_prediction.config import CONFIG +from image_prediction.transformer.transformer import Transformer + + +class ResponseTransformer(Transformer): + + def transform(self, data): + try: + return build_image_info(data) + except TypeError: + return map(build_image_info, data) + + +def build_image_info(data: dict) -> dict: + def compute_geometric_quotient(): + page_area_sqrt = math.sqrt(abs(page_width * page_height)) + image_area_sqrt = math.sqrt(abs(x2 - x1) * abs(y2 - y1)) + return image_area_sqrt / page_area_sqrt + + page_width, page_height, x1, x2, y1, y2, width, height = itemgetter( + "page_width", "page_height", "x1", "x2", "y1", "y2", "width", "height" + )(data) + + quotient = round(compute_geometric_quotient(), 4) + + min_image_to_page_quotient_breached = bool(quotient < CONFIG.filters.image_to_page_quotient.min) + max_image_to_page_quotient_breached = bool(quotient > CONFIG.filters.image_to_page_quotient.max) + min_image_width_to_height_quotient_breached = bool( + width / height < CONFIG.filters.image_width_to_height_quotient.min + ) + max_image_width_to_height_quotient_breached = bool( + width / height > CONFIG.filters.image_width_to_height_quotient.max + ) + + classification = data["classification"] + + min_confidence_breached = bool(max(classification["probabilities"].values()) < CONFIG.filters.min_confidence) + + image_info = { + "classification": classification, + "position": {"x1": x1, "x2": x2, "y1": y1, "y2": y2, "pageNumber": data["page_idx"] + 1}, + "geometry": {"width": width, "height": height}, + "filters": { + "geometry": { + "imageSize": { + "quotient": quotient, + "tooLarge": max_image_to_page_quotient_breached, + "tooSmall": min_image_to_page_quotient_breached, + }, + "imageFormat": { + "quotient": round(width / height, 4), + "tooTall": min_image_width_to_height_quotient_breached, + "tooWide": max_image_width_to_height_quotient_breached, + }, + }, + "probability": {"unconfident": min_confidence_breached}, + "allPassed": not any( + [ + max_image_to_page_quotient_breached, + min_image_to_page_quotient_breached, + min_image_width_to_height_quotient_breached, + max_image_width_to_height_quotient_breached, + min_confidence_breached, + ] + ), + }, + } + + return image_info diff --git a/image_prediction/locations.py b/image_prediction/locations.py index 9dfe2f4..8d62079 100644 --- a/image_prediction/locations.py +++ b/image_prediction/locations.py @@ -7,3 +7,5 @@ CONFIG_FILE = path.join(PACKAGE_ROOT_DIR, "config.yaml") DATA_DIR = path.join(PACKAGE_ROOT_DIR, "data") MLRUNS_DIR = path.join(DATA_DIR, "mlruns") + +TEST_DATA_DIR = path.join(PACKAGE_ROOT_DIR, "test", "data") diff --git a/image_prediction/pipeline.py b/image_prediction/pipeline.py index 3f964d3..2498dcf 100644 --- a/image_prediction/pipeline.py +++ b/image_prediction/pipeline.py @@ -4,9 +4,12 @@ from funcy import rcompose, juxt from image_prediction.classifier.classifier import Classifier from image_prediction.classifier.image_classifier import ImageClassifier +from image_prediction.compositor.compositor import TransformerCompositor from image_prediction.estimator.adapter.adapter import EstimatorAdapter from image_prediction.extractor_classifier.extractor_classifier import ExtractorClassifier +from image_prediction.formatter.formatters.camel_case import Snake2CamelCaseKeyFormatter from image_prediction.formatter.formatters.enum import EnumFormatter +from image_prediction.formatter.formatters.response import ResponseTransformer from image_prediction.image_extractor.extractors.parsable import ParsablePDFImageExtractor from image_prediction.label_mapper.mappers.probability import ProbabilityMapper @@ -18,27 +21,29 @@ def get_image_classifier(model_loader, model_identifier): return ImageClassifier(Classifier(EstimatorAdapter(model), ProbabilityMapper(classes))) -def get_extractor(): - image_extractor = ParsablePDFImageExtractor(verbose=True) +def get_extractor(**kwargs): + image_extractor = ParsablePDFImageExtractor(**kwargs) return image_extractor -def get_extractor_classifier(model_loader, model_identifier): - extractor_classifier = ExtractorClassifier(get_extractor(), get_image_classifier(model_loader, model_identifier)) +def get_extractor_classifier(model_loader, model_identifier, **kwargs): + extractor_classifier = ExtractorClassifier( + get_extractor(**kwargs), get_image_classifier(model_loader, model_identifier) + ) return extractor_classifier def get_formatter(): - formatter = EnumFormatter() + formatter = TransformerCompositor(EnumFormatter(), ResponseTransformer(), Snake2CamelCaseKeyFormatter()) return formatter class Pipeline: - def __init__(self, model_loader, model_identifier): - self.pipe = rcompose(get_extractor_classifier(model_loader, model_identifier), get_formatter()) + def __init__(self, model_loader, model_identifier, **kwargs): + self.pipe = rcompose(get_extractor_classifier(model_loader, model_identifier, **kwargs), get_formatter()) def __call__(self, pdf: bytes): yield from self.pipe(pdf) diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 90717d9..62ab529 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -19,8 +19,11 @@ def main(args): with open(args.pdf, "rb") as f: predictions = pipeline(f.read()) + with open("/tmp/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json", "w") as f: + json.dump(list(predictions), f, indent=2) + for prd in predictions: - print(json.dumps(prd, indent=1)) + print(json.dumps(prd, indent=2)) if __name__ == "__main__": diff --git a/test/test_data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf b/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf similarity index 100% rename from test/test_data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf rename to test/data/f2dc689ca794fccb8cd38b95f2bf6ba9.pdf diff --git a/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json b/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json new file mode 100644 index 0000000..76fbcb8 --- /dev/null +++ b/test/data/f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json @@ -0,0 +1,42 @@ +[ + { + "classification": { + "label": "formula", + "probabilities": { + "formula": 1.0, + "logo": 0.0, + "other": 0.0, + "signature": 0.0 + } + }, + "position": { + "x1": 321, + "x2": 515, + "y1": 300, + "y2": 494, + "pageNumber": 2 + }, + "geometry": { + "width": 389, + "height": 389 + }, + "filters": { + "geometry": { + "imageSize": { + "quotient": 0.2741, + "tooLarge": false, + "tooSmall": false + }, + "imageFormat": { + "quotient": 1.0, + "tooTall": false, + "tooWide": false + } + }, + "probability": { + "unconfident": false + }, + "allPassed": true + } + } +] \ No newline at end of file diff --git a/test/unit_tests/conftest.py b/test/unit_tests/conftest.py index dc31663..869e004 100644 --- a/test/unit_tests/conftest.py +++ b/test/unit_tests/conftest.py @@ -227,13 +227,13 @@ def map_labels(numeric_labels, classes): @pytest.fixture def metadata_plus_mapped_prediction(expected_predictions_mapped, metadata): - return [{"prediction": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped, metadata)] + return [{"classification": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped, metadata)] @pytest.fixture def metadata_formatted_plus_mapped_prediction_formatted(expected_predictions_mapped_and_formatted, metadata_formatted): return [ - {"prediction": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped_and_formatted, metadata_formatted) + {"classification": epm, **mdt} for epm, mdt in zip(expected_predictions_mapped_and_formatted, metadata_formatted) ] diff --git a/test/unit_tests/extractor_classifier_test.py b/test/unit_tests/extractor_classifier_test.py index 22c800e..3153ef4 100644 --- a/test/unit_tests/extractor_classifier_test.py +++ b/test/unit_tests/extractor_classifier_test.py @@ -10,5 +10,5 @@ from image_prediction.extractor_classifier.extractor_classifier import Extractor def test_extractor_classifier(image_extractor, image_classifier, images, batch_of_expected_string_labels): extractor_classifier = ExtractorClassifier(image_extractor, image_classifier) results = extractor_classifier(images) - labels = list(map(itemgetter("prediction"), results)) + labels = list(map(itemgetter("classification"), results)) assert labels == batch_of_expected_string_labels diff --git a/test/unit_tests/pipeline_test.py b/test/unit_tests/pipeline_test.py new file mode 100644 index 0000000..f66ca6c --- /dev/null +++ b/test/unit_tests/pipeline_test.py @@ -0,0 +1,18 @@ +import json +import os + +from image_prediction.default_objects import load_pipeline +from image_prediction.locations import TEST_DATA_DIR + + +def test_pipeline(): + + pipeline = load_pipeline(verbose=False) + + with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9.pdf"), "rb") as f: + predictions = list(pipeline(f.read())) + + with open(os.path.join(TEST_DATA_DIR, "f2dc689ca794fccb8cd38b95f2bf6ba9_predictions.json"), "r") as f: + expectations = json.load(f) + + assert predictions == expectations