refactoring
This commit is contained in:
parent
03e7b00cfd
commit
13513db5a1
@ -3,7 +3,7 @@ from functools import reduce
|
||||
from typing import Iterable, Callable, List
|
||||
|
||||
from PIL import Image
|
||||
from funcy import juxt, first, rest, rcompose, rpartial
|
||||
from funcy import juxt, first, rest, rcompose, rpartial, complement, ilen
|
||||
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.info import Info
|
||||
@ -13,8 +13,22 @@ from image_prediction.stitching.utils import make_coord_getter, flatten_groups_o
|
||||
from image_prediction.utils.generic import until
|
||||
|
||||
|
||||
def no_new_merges(pairs1, pairs2):
    """Report whether two consecutive merge results hold the same number of pairs, i.e. nothing was merged."""
    size_before, size_after = len(pairs1), len(pairs2)
    return size_before == size_after
|
||||
def make_merger_sentinel():
    """Build a stateful predicate that reports when the number of pairs stops changing between calls.

    The returned function remembers the length of the collection it saw last. It returns True when the
    collection it is given has that same length, and False otherwise (always False on the first call).
    """
    # No collection has length -1, so the very first call can never report "no new mergers".
    previous_count = -1

    def no_new_mergers(pairs):
        nonlocal previous_count

        current_count = len(pairs)
        unchanged = current_count == previous_count
        # Re-assigning on the "unchanged" path is a no-op, so the remembered count is the same as before.
        previous_count = current_count
        return unchanged

    return no_new_mergers
|
||||
|
||||
|
||||
def merge_along_both_axes(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]:
|
||||
@ -72,7 +86,8 @@ def merge_group_horizontally(group: Iterable[ImageMetadataPair], tolerance=0):
|
||||
|
||||
def merge_group(group: Iterable[ImageMetadataPair], direction, tolerance=0):
    """Repeatedly merge the pairs of a group along one direction until a pass no longer changes their number.

    Args:
        group: Image-metadata pairs to merge. The sentinel calls len() on every intermediate result, so the
            group and the aggregator's outputs must be sized collections.
        direction: Axis specification handed to make_merger_aggregator.
        tolerance: Passed through to make_merger_aggregator.

    Returns:
        The first intermediate result whose length equals that of the result before it.
    """
    reduce_group = make_merger_aggregator(direction, tolerance=tolerance)
    # The sentinel remembers the previous length, so each merge run needs its own fresh instance.
    no_new_mergers = make_merger_sentinel()
    return until(no_new_mergers, reduce_group, group)
|
||||
|
||||
|
||||
def make_merger_aggregator(axis, tolerance=0) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
|
||||
|
||||
@ -3,11 +3,13 @@ from typing import Iterable, List
|
||||
from funcy import rpartial
|
||||
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.stitching.merging import merge_along_both_axes, no_new_merges
|
||||
from image_prediction.stitching.merging import merge_along_both_axes, make_merger_sentinel
|
||||
from image_prediction.utils.generic import until
|
||||
|
||||
|
||||
def stitch_pairs(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]:
    """Given a collection of image-metadata pairs from the same pages, combines all pairs that constitute adjacent
    images.

    merge_along_both_axes is applied repeatedly (with the given tolerance) until a pass leaves the number of
    pairs unchanged. The sentinel calls len() on every intermediate result, so `pairs` must be a sized
    collection.
    """
    # The sentinel remembers the previous length, so each stitching run needs its own fresh instance.
    no_new_mergers = make_merger_sentinel()
    merge = rpartial(merge_along_both_axes, tolerance)
    return until(no_new_mergers, merge, pairs)
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
from funcy import iterate, chunks
|
||||
from funcy import iterate, first
|
||||
|
||||
|
||||
def until(cond, func, *args, **kwargs):
    """Return the first value of the sequence x, func(x), func(func(x)), ... for which cond holds.

    Args:
        cond: Single-argument predicate evaluated on each value of the sequence, starting with the seed.
        func: Function applied to the current value to produce the next one.
        *args, **kwargs: Must supply exactly one seed value (positionally, or as keyword ``x``).

    Returns:
        The first value satisfying cond. Never returns if no value of the sequence satisfies it.
    """

    def _iterate(f, x):
        # Infinite sequence x, f(x), f(f(x)), ...
        while True:
            yield x
            x = f(x)

    # The sequence is infinite, so next() cannot run out of values.
    return next(filter(cond, _iterate(func, *args, **kwargs)))
|
||||
|
||||
@ -7,17 +7,20 @@ import tempfile
|
||||
from functools import partial
|
||||
from itertools import starmap
|
||||
from operator import itemgetter
|
||||
from typing import Iterable
|
||||
|
||||
import fpdf
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from frozendict import frozendict
|
||||
from funcy import rcompose, merge
|
||||
|
||||
from image_prediction.classifier.classifier import Classifier
|
||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||
from image_prediction.estimator.adapter.adapter import EstimatorAdapter
|
||||
from image_prediction.estimator.preprocessor.preprocessors.basic import BasicPreprocessor
|
||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
||||
from image_prediction.exceptions import (
|
||||
UnknownEstimatorAdapter,
|
||||
UnknownImageExtractor,
|
||||
@ -513,3 +516,11 @@ def random_single_color_image_from_metadata(metadata):
|
||||
def gray_image_from_metadata(metadata):
    """Create an RGB image filled with (100, 100, 100), sized by the metadata's width and height entries."""
    size = (metadata[Info.WIDTH], metadata[Info.HEIGHT])
    gray = (100, 100, 100)
    return Image.new("RGB", size, color=gray)
|
||||
|
||||
|
||||
def images_equal(im1: Image.Image, im2: Image.Image, **kwargs):
    """Compare two images via np.allclose on their normalized tensor representations.

    Args:
        im1: First image.
        im2: Second image.
        **kwargs: Forwarded to np.allclose (e.g. rtol, atol).

    Returns:
        The result of np.allclose on the two tensors produced by image_to_normalized_tensor.
    """
    # `Image` is the PIL.Image module; the image class used for annotation is `Image.Image`.
    return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs)
|
||||
|
||||
|
||||
def metadata_equal(mdat1: Iterable, mdat2: Iterable):
    """Check that two metadata collections contain the same entries, ignoring order and duplicates.

    Each entry is wrapped in a frozendict so it can be placed in a set for the comparison.
    """
    frozen_first = {frozendict(entry) for entry in mdat1}
    frozen_second = {frozendict(entry) for entry in mdat2}
    return frozen_first == frozen_second
|
||||
|
||||
@ -3,17 +3,15 @@ from operator import itemgetter
|
||||
|
||||
import fitz
|
||||
import fpdf
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from funcy import first, rest
|
||||
|
||||
from image_prediction.estimator.preprocessor.utils import images_to_batch_tensor
|
||||
from image_prediction.extraction import extract_images_from_pdf
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.image_extractor.extractors.parsable import extract_pages, get_image_infos, has_alpha_channel
|
||||
from image_prediction.info import Info
|
||||
from test.conftest import add_image, pdf_stream
|
||||
from test.conftest import add_image, pdf_stream, images_equal, metadata_equal
|
||||
|
||||
|
||||
@pytest.mark.parametrize("extractor_type", ["mock"])
|
||||
@ -29,8 +27,8 @@ def test_image_extractor_mock(image_extractor, images):
|
||||
def test_parsable_pdf_image_extractor(image_extractor, pdf, images, metadata, input_size, alpha):
    """Check that images and metadata extracted from a PDF match the expected ones, independent of order."""
    images_extracted, metadata_extracted = map(list, extract_images_from_pdf(pdf, image_extractor))
    if not alpha:
        # Materialize once: the expected images are scanned again for every extracted image.
        expected_images = list(images)
        # Without `assert` this expression is evaluated and discarded, so the image check would never fail.
        assert all(any(images_equal(imex, im) for im in expected_images) for imex in images_extracted)
    assert metadata_equal(metadata_extracted, metadata)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 16])
|
||||
|
||||
@ -1,20 +1,16 @@
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from itertools import starmap, repeat
|
||||
from operator import itemgetter
|
||||
from typing import List
|
||||
|
||||
import fpdf
|
||||
import numpy as np
|
||||
import pdf2image
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from funcy import juxt, one, first
|
||||
|
||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
||||
from image_prediction.formatter.formatters.enum import ReverseEnumFormatter
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.info import Info
|
||||
@ -38,6 +34,7 @@ from test.conftest import (
|
||||
add_image,
|
||||
random_single_color_image_from_metadata,
|
||||
gray_image_from_metadata,
|
||||
images_equal,
|
||||
)
|
||||
from test.utils.stitching import BoxSplitter
|
||||
|
||||
@ -147,10 +144,6 @@ def test_merge_pairs_vertically(vertical_merge_test_pairs):
|
||||
assert pair_equal(pr_merged, pr_merged_expected)
|
||||
|
||||
|
||||
def images_equal(im1: Image.Image, im2: Image.Image, **kwargs):
    """Compare two images via np.allclose on their normalized tensor representations.

    Args:
        im1: First image.
        im2: Second image.
        **kwargs: Forwarded to np.allclose (e.g. rtol, atol).

    Returns:
        The result of np.allclose on the two tensors produced by image_to_normalized_tensor.
    """
    # `Image` is the PIL.Image module; the image class used for annotation is `Image.Image`.
    return np.allclose(image_to_normalized_tensor(im1), image_to_normalized_tensor(im2), **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def horizontal_merge_test_pairs(horizontal_merge_test_metadata):
|
||||
images = map(gray_image_from_metadata, horizontal_merge_test_metadata)
|
||||
|
||||
@ -5,4 +5,4 @@ def test_until():
|
||||
def f(x):
|
||||
return x / 2
|
||||
|
||||
assert until(lambda x, y: x - y == 0, f, 1) == 0
|
||||
assert until(lambda x: x == 0, f, 1) == 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user