From 13513db5a198f4447c9e7a8dc9187615fc645281 Mon Sep 17 00:00:00 2001 From: Matthias Bisping Date: Thu, 14 Apr 2022 15:22:41 +0200 Subject: [PATCH] refactoring --- image_prediction/stitching/merging.py | 23 +++++++++++++++++++---- image_prediction/stitching/stitching.py | 6 ++++-- image_prediction/utils/generic.py | 6 ++---- test/conftest.py | 11 +++++++++++ test/unit_tests/image_extractor_test.py | 8 +++----- test/unit_tests/image_stitching_test.py | 9 +-------- test/unit_tests/utils_test.py | 2 +- 7 files changed, 41 insertions(+), 24 deletions(-) diff --git a/image_prediction/stitching/merging.py b/image_prediction/stitching/merging.py index f5fe22d..2c3fbc4 100644 --- a/image_prediction/stitching/merging.py +++ b/image_prediction/stitching/merging.py @@ -3,7 +3,7 @@ from functools import reduce from typing import Iterable, Callable, List from PIL import Image -from funcy import juxt, first, rest, rcompose, rpartial +from funcy import juxt, first, rest, rcompose, rpartial, complement, ilen from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.info import Info @@ -13,8 +13,22 @@ from image_prediction.stitching.utils import make_coord_getter, flatten_groups_o from image_prediction.utils.generic import until -def no_new_merges(pairs1, pairs2): - return len(pairs1) == len(pairs2) +def make_merger_sentinel(): + def no_new_mergers(pairs): + nonlocal number_of_pairs_so_far + + number_of_pairs_now = len(pairs) + + if number_of_pairs_now == number_of_pairs_so_far: + return True + + else: + number_of_pairs_so_far = number_of_pairs_now + return False + + number_of_pairs_so_far = -1 + + return no_new_mergers def merge_along_both_axes(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]: @@ -72,7 +86,8 @@ def merge_group_horizontally(group: Iterable[ImageMetadataPair], tolerance=0): def merge_group(group: Iterable[ImageMetadataPair], direction, tolerance=0): reduce_group = make_merger_aggregator(direction, tolerance=tolerance) - return until(no_new_merges, reduce_group, group) + no_new_mergers = make_merger_sentinel() + return until(no_new_mergers, reduce_group, group) def make_merger_aggregator(axis, tolerance=0) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]: diff --git a/image_prediction/stitching/stitching.py b/image_prediction/stitching/stitching.py index 0cf3e9e..9d98bd3 100644 --- a/image_prediction/stitching/stitching.py +++ b/image_prediction/stitching/stitching.py @@ -3,11 +3,13 @@ from typing import Iterable, List from funcy import rpartial from image_prediction.image_extractor.extractor import ImageMetadataPair -from image_prediction.stitching.merging import merge_along_both_axes, no_new_merges +from image_prediction.stitching.merging import merge_along_both_axes, make_merger_sentinel from image_prediction.utils.generic import until def stitch_pairs(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]: """Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent images.""" - return until(no_new_merges, rpartial(merge_along_both_axes, tolerance), pairs) + no_new_mergers = make_merger_sentinel() + merge = rpartial(merge_along_both_axes, tolerance) + return until(no_new_mergers, merge, pairs) diff --git a/image_prediction/utils/generic.py b/image_prediction/utils/generic.py index 2ef1522..98cf612 100644 --- a/image_prediction/utils/generic.py +++ b/image_prediction/utils/generic.py @@ -1,7 +1,5 @@ -from funcy import iterate, chunks +from funcy import iterate, first def until(cond, func, *args, **kwargs): - for a, b in chunks(2, iterate(func, *args, **kwargs)): - if cond(a, b): - return a + return first(filter(cond, iterate(func, *args, **kwargs))) diff --git a/test/conftest.py b/test/conftest.py index a919c38..465a66f 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -7,17 +7,20 @@ import tempfile from functools import partial from itertools import starmap from operator import itemgetter +from typing import Iterable import fpdf import numpy as np import pytest from PIL import Image +from frozendict import frozendict from funcy import rcompose, merge from image_prediction.classifier.classifier import Classifier from image_prediction.classifier.image_classifier import ImageClassifier from image_prediction.estimator.adapter.adapter import EstimatorAdapter from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor +from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor from image_prediction.exceptions import ( UnknownEstimatorAdapter, UnknownImageExtractor, @@ -513,3 +516,11 @@ def random_single_color_image_from_metadata(metadata): def gray_image_from_metadata(metadata): image = Image.new("RGB", (metadata[Info.WIDTH], metadata[Info.HEIGHT]), color=(100, 100, 100)) return image + + +def images_equal(im1: Image, im2: Image, **kwargs): + return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs) + + +def metadata_equal(mdat1: Iterable, mdat2: Iterable): + return set(map(frozendict, mdat1)) == set(map(frozendict, mdat2)) diff --git a/test/unit_tests/image_extractor_test.py b/test/unit_tests/image_extractor_test.py index a6c6078..5bc2b42 100644 --- a/test/unit_tests/image_extractor_test.py +++ b/test/unit_tests/image_extractor_test.py @@ -3,17 +3,15 @@ from operator import itemgetter import fitz import fpdf -import numpy as np import pytest from PIL import Image from funcy import first, rest -from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor from image_prediction.extraction import extract_images_from_pdf from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.image_extractor.extractors.parsable import extract_pages, get_image_infos, has_alpha_channel from image_prediction.info import Info -from test.conftest import add_image, pdf_stream +from test.conftest import add_image, pdf_stream, images_equal, metadata_equal @pytest.mark.parametrize("extractor_type", ["mock"]) @@ -29,8 +27,8 @@ def test_image_extractor_mock(image_extractor, images): def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size, alpha): images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor)) if not alpha: - assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images)) - assert list(metadata_extracted) == metadata + all(any(images_equal(imex, im) for im in images) for imex in images_extracted) + assert metadata_equal(metadata_extracted, metadata) @pytest.mark.parametrize("batch_size", [1, 2, 16]) diff --git a/test/unit_tests/image_stitching_test.py b/test/unit_tests/image_stitching_test.py index fda2a62..9af487e 100644 --- a/test/unit_tests/image_stitching_test.py +++ b/test/unit_tests/image_stitching_test.py @@ -1,20 +1,16 @@ import json import os from copy import deepcopy -from copy import deepcopy from functools import partial from itertools import starmap, repeat from operator import itemgetter from typing import List import fpdf -import numpy as np import pdf2image import pytest -from PIL import Image from funcy import juxt, one, first -from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor from image_prediction.formatter.formatters.enum import ReverseEnumFormatter from image_prediction.image_extractor.extractor import ImageMetadataPair from image_prediction.info import Info @@ -38,6 +34,7 @@ from test.conftest import ( add_image, random_single_color_image_from_metadata, gray_image_from_metadata, + images_equal, ) from test.utils.stitching import BoxSplitter @@ -147,10 +144,6 @@ def test_merge_pairs_vertically(vertical_merge_test_pairs): assert pair_equal(pr_merged, pr_merged_expected) -def images_equal(im1: Image, im2: Image, **kwargs): - return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs) - - @pytest.fixture def horizontal_merge_test_pairs(horizontal_merge_test_metadata): images = map(gray_image_from_metadata, horizontal_merge_test_metadata) diff --git a/test/unit_tests/utils_test.py b/test/unit_tests/utils_test.py index 8ff6b69..5263af8 100644 --- a/test/unit_tests/utils_test.py +++ b/test/unit_tests/utils_test.py @@ -5,4 +5,4 @@ def test_until(): def f(x): return x / 2 - assert until(lambda x, y: x - y == 0, f, 1) == 0 + assert until(lambda x: x == 0, f, 1) == 0