refactoring

This commit is contained in:
Matthias Bisping 2022-04-14 15:22:41 +02:00
parent 03e7b00cfd
commit 13513db5a1
7 changed files with 41 additions and 24 deletions

View File

@ -3,7 +3,7 @@ from functools import reduce
from typing import Iterable, Callable, List
from PIL import Image
from funcy import juxt, first, rest, rcompose, rpartial
from funcy import juxt, first, rest, rcompose, rpartial, complement, ilen
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
@ -13,8 +13,22 @@ from image_prediction.stitching.utils import make_coord_getter, flatten_groups_o
from image_prediction.utils.generic import until
def no_new_merges(pairs1, pairs2):
return len(pairs1) == len(pairs2)
def make_merger_sentinel():
def no_new_mergers(pairs):
nonlocal number_of_pairs_so_far
number_of_pairs_now = len(pairs)
if number_of_pairs_now == number_of_pairs_so_far:
return True
else:
number_of_pairs_so_far = number_of_pairs_now
return False
number_of_pairs_so_far = -1
return no_new_mergers
def merge_along_both_axes(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]:
@ -72,7 +86,8 @@ def merge_group_horizontally(group: Iterable[ImageMetadataPair], tolerance=0):
def merge_group(group: Iterable[ImageMetadataPair], direction, tolerance=0):
reduce_group = make_merger_aggregator(direction, tolerance=tolerance)
return until(no_new_merges, reduce_group, group)
no_new_mergers = make_merger_sentinel()
return until(no_new_mergers, reduce_group, group)
def make_merger_aggregator(axis, tolerance=0) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:

View File

@ -3,11 +3,13 @@ from typing import Iterable, List
from funcy import rpartial
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.stitching.merging import merge_along_both_axes, no_new_merges
from image_prediction.stitching.merging import merge_along_both_axes, make_merger_sentinel
from image_prediction.utils.generic import until
def stitch_pairs(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]:
"""Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent
images."""
return until(no_new_merges, rpartial(merge_along_both_axes, tolerance), pairs)
no_new_mergers = make_merger_sentinel()
merge = rpartial(merge_along_both_axes, tolerance)
return until(no_new_mergers, merge, pairs)

View File

@ -1,7 +1,5 @@
from funcy import iterate, chunks
from funcy import iterate, first
def until(cond, func, *args, **kwargs):
for a, b in chunks(2, iterate(func, *args, **kwargs)):
if cond(a, b):
return a
return first(filter(cond, iterate(func, *args, **kwargs)))

View File

@ -7,17 +7,20 @@ import tempfile
from functools import partial
from itertools import starmap
from operator import itemgetter
from typing import Iterable
import fpdf
import numpy as np
import pytest
from PIL import Image
from frozendict import frozendict
from funcy import rcompose, merge
from image_prediction.classifier.classifier import Classifier
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
from image_prediction.exceptions import (
UnknownEstimatorAdapter,
UnknownImageExtractor,
@ -513,3 +516,11 @@ def random_single_color_image_from_metadata(metadata):
def gray_image_from_metadata(metadata):
image = Image.new("RGB", (metadata[Info.WIDTH], metadata[Info.HEIGHT]), color=(100, 100, 100))
return image
def images_equal(im1: Image, im2: Image, **kwargs):
return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs)
def metadata_equal(mdat1: Iterable, mdat2: Iterable):
return set(map(frozendict, mdat1)) == set(map(frozendict, mdat2))

View File

@ -3,17 +3,15 @@ from operator import itemgetter
import fitz
import fpdf
import numpy as np
import pytest
from PIL import Image
from funcy import first, rest
from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor
from image_prediction.extraction import extract_images_from_pdf
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.image_extractor.extractors.parsable import extract_pages, get_image_infos, has_alpha_channel
from image_prediction.info import Info
from test.conftest import add_image, pdf_stream
from test.conftest import add_image, pdf_stream, images_equal, metadata_equal
@pytest.mark.parametrize("extractor_type", ["mock"])
@ -29,8 +27,8 @@ def test_image_extractor_mock(image_extractor, images):
def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size, alpha):
images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor))
if not alpha:
assert np.allclose(images_to_batch_tensor(images_extracted), images_to_batch_tensor(images))
assert list(metadata_extracted) == metadata
all(any(images_equal(imex, im) for im in images) for imex in images_extracted)
assert metadata_equal(metadata_extracted, metadata)
@pytest.mark.parametrize("batch_size", [1, 2, 16])

View File

@ -1,20 +1,16 @@
import json
import os
from copy import deepcopy
from copy import deepcopy
from functools import partial
from itertools import starmap, repeat
from operator import itemgetter
from typing import List
import fpdf
import numpy as np
import pdf2image
import pytest
from PIL import Image
from funcy import juxt, one, first
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
from image_prediction.formatter.formatters.enum import ReverseEnumFormatter
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
@ -38,6 +34,7 @@ from test.conftest import (
add_image,
random_single_color_image_from_metadata,
gray_image_from_metadata,
images_equal,
)
from test.utils.stitching import BoxSplitter
@ -147,10 +144,6 @@ def test_merge_pairs_vertically(vertical_merge_test_pairs):
assert pair_equal(pr_merged, pr_merged_expected)
def images_equal(im1: Image, im2: Image, **kwargs):
return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs)
@pytest.fixture
def horizontal_merge_test_pairs(horizontal_merge_test_metadata):
images = map(gray_image_from_metadata, horizontal_merge_test_metadata)

View File

@ -5,4 +5,4 @@ def test_until():
def f(x):
return x / 2
assert until(lambda x, y: x - y == 0, f, 1) == 0
assert until(lambda x: x == 0, f, 1) == 0