refactoring
This commit is contained in:
parent
e276a5ec27
commit
51793d19e9
@ -7,7 +7,8 @@ from funcy import rcompose
|
|||||||
from image_prediction.classifier.classifier import Classifier
|
from image_prediction.classifier.classifier import Classifier
|
||||||
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
||||||
from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor
|
from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor
|
||||||
from image_prediction.utils import chunk_iterable, get_logger
|
from image_prediction.utils import get_logger
|
||||||
|
from image_prediction.utils.generic import chunk_iterable
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from typing import Iterable
|
|||||||
|
|
||||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||||
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
||||||
from image_prediction.utils import chunk_iterable
|
from image_prediction.utils.generic import chunk_iterable
|
||||||
|
|
||||||
|
|
||||||
class ExtractorClassifier:
|
class ExtractorClassifier:
|
||||||
|
|||||||
0
image_prediction/stitcher/__init__.py
Normal file
0
image_prediction/stitcher/__init__.py
Normal file
53
image_prediction/stitcher/stitcher.py
Normal file
53
image_prediction/stitcher/stitcher.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from itertools import groupby, chain
|
||||||
|
from typing import Iterable, List
|
||||||
|
|
||||||
|
from funcy import compose, second
|
||||||
|
|
||||||
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
|
from image_prediction.stitcher.utils import make_coord_getter, make_group_merger
|
||||||
|
from image_prediction.utils.generic import until_convergence
|
||||||
|
|
||||||
|
|
||||||
|
class Stitcher:
|
||||||
|
@staticmethod
|
||||||
|
def groupby(pairs, coord_getter):
|
||||||
|
pairs = sorted(pairs, key=coord_getter)
|
||||||
|
return map(compose(list, second), groupby(pairs, coord_getter))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def other_axis(axis):
|
||||||
|
return "y" if axis == "x" else "x"
|
||||||
|
|
||||||
|
def merge_along_axis(self, pairs, axis):
|
||||||
|
def group_pairs_by_c1(pairs):
|
||||||
|
return self.groupby(pairs, c1_getter)
|
||||||
|
|
||||||
|
def group_by_c2(pairs):
|
||||||
|
return self.groupby(pairs, c2_getter)
|
||||||
|
|
||||||
|
def group_pairs_within_groups_by_c2(groups):
|
||||||
|
return map(group_by_c2, groups)
|
||||||
|
|
||||||
|
def merge_groups_along_orthogonal_axis(groups):
|
||||||
|
return map(group_merger, groups)
|
||||||
|
|
||||||
|
c1_getter = make_coord_getter(f"{self.other_axis(axis)}1")
|
||||||
|
c2_getter = make_coord_getter(f"{self.other_axis(axis)}2")
|
||||||
|
group_merger = make_group_merger(axis)
|
||||||
|
|
||||||
|
groups_of_pairs_with_same_c1 = group_pairs_by_c1(pairs)
|
||||||
|
groups_of_groups_of_pairs_with_same_c1_and_c2 = group_pairs_within_groups_by_c2(groups_of_pairs_with_same_c1)
|
||||||
|
groups_of_pairs_with_matching_c1_and_c2 = chain(*groups_of_groups_of_pairs_with_same_c1_and_c2)
|
||||||
|
groups_of_merged_pairs = merge_groups_along_orthogonal_axis(groups_of_pairs_with_matching_c1_and_c2)
|
||||||
|
pairs = chain(*groups_of_merged_pairs)
|
||||||
|
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
def merge_along_both_axes(self, pairs):
|
||||||
|
pairs = self.merge_along_axis(pairs, "x")
|
||||||
|
pairs = list(self.merge_along_axis(pairs, "y"))
|
||||||
|
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
|
||||||
|
return until_convergence(self.merge_along_both_axes, pairs)
|
||||||
142
image_prediction/stitcher/utils.py
Normal file
142
image_prediction/stitcher/utils.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
from functools import reduce
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from funcy import juxt, first, rest
|
||||||
|
|
||||||
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
|
from image_prediction.info import Info
|
||||||
|
from image_prediction.utils.generic import until_convergence
|
||||||
|
from test.utils.stitching import HorizontalKeyMapper, VerticalKeyMapper
|
||||||
|
|
||||||
|
|
||||||
|
def make_getter(key):
|
||||||
|
def getter(pair):
|
||||||
|
return pair.metadata[key]
|
||||||
|
|
||||||
|
return getter
|
||||||
|
|
||||||
|
|
||||||
|
def make_coord_getter(c):
|
||||||
|
return {
|
||||||
|
"x1": make_getter(Info.X1),
|
||||||
|
"x2": make_getter(Info.X2),
|
||||||
|
"y1": make_getter(Info.Y1),
|
||||||
|
"y2": make_getter(Info.Y2),
|
||||||
|
}[c]
|
||||||
|
|
||||||
|
|
||||||
|
def make_pair_merger(axis):
|
||||||
|
return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis]
|
||||||
|
|
||||||
|
|
||||||
|
def make_group_merger(axis):
|
||||||
|
return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]
|
||||||
|
|
||||||
|
|
||||||
|
def make_length_getter(dim):
|
||||||
|
return {
|
||||||
|
"width": make_getter(Info.WIDTH),
|
||||||
|
"height": make_getter(Info.HEIGHT),
|
||||||
|
}[dim]
|
||||||
|
|
||||||
|
|
||||||
|
def merge_metadata_horizontally(m1, m2):
|
||||||
|
m1, m2 = map(HorizontalKeyMapper, [m1, m2])
|
||||||
|
return merge_metadata(m1, m2)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_metadata_vertically(m1, m2):
|
||||||
|
m1, m2 = map(VerticalKeyMapper, [m1, m2])
|
||||||
|
return merge_metadata(m1, m2)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_metadata(m1, m2):
|
||||||
|
|
||||||
|
c1 = min(m1.c1, m2.c1)
|
||||||
|
c2 = max(m1.c2, m2.c2)
|
||||||
|
dim = m1.dim + m2.dim
|
||||||
|
|
||||||
|
merged = deepcopy(m1)
|
||||||
|
merged.dim = dim
|
||||||
|
merged.c1 = c1
|
||||||
|
merged.c2 = c2
|
||||||
|
|
||||||
|
return merged.wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
|
||||||
|
metadata_merged = merge_metadata_horizontally(p1.metadata, p2.metadata)
|
||||||
|
image_concatenated = concat_images_horizontally(p1.image, p2.image, metadata_merged)
|
||||||
|
return ImageMetadataPair(image_concatenated, metadata_merged)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
|
||||||
|
metadata_merged = merge_metadata_vertically(p1.metadata, p2.metadata)
|
||||||
|
image_concatenated = concat_images_vertically(p1.image, p2.image, metadata_merged)
|
||||||
|
return ImageMetadataPair(image_concatenated, metadata_merged)
|
||||||
|
|
||||||
|
|
||||||
|
def concat_images_horizontally(im1: Image, im2: Image, metadata: dict):
|
||||||
|
return concat_images(im1, im2, metadata, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def concat_images_vertically(im1: Image, im2: Image, metadata: dict):
|
||||||
|
return concat_images(im1, im2, metadata, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def concat_images(im1: Image, im2: Image, metadata: dict, axis):
|
||||||
|
|
||||||
|
im_aggr = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))
|
||||||
|
|
||||||
|
images = [im1, im2]
|
||||||
|
|
||||||
|
offsets = [0, *[im.size[axis] for im in images]]
|
||||||
|
|
||||||
|
for im, offset in zip(images, offsets):
|
||||||
|
box = (offset, 0) if not axis else (0, offset)
|
||||||
|
im_aggr.paste(im, box=box)
|
||||||
|
|
||||||
|
return im_aggr
|
||||||
|
|
||||||
|
|
||||||
|
def make_merger_aggregator(direction):
|
||||||
|
"""Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
|
||||||
|
head H and aggregates non-adjacent in the tail T.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def merger_aggregator(pairs):
|
||||||
|
def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
|
||||||
|
"""Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
|
||||||
|
aggr, non_aggr = juxt(first, rest)(pairs_aggr)
|
||||||
|
if c2_getter(aggr) == c1_getter(pair):
|
||||||
|
aggr = pair_merger(aggr, pair)
|
||||||
|
return aggr, *non_aggr
|
||||||
|
else:
|
||||||
|
return aggr, pair, *non_aggr
|
||||||
|
|
||||||
|
# Requires H to be the least element in image-concatenation direction by c1, since the concatenation happens
|
||||||
|
# only in c1 -> c2 direction.
|
||||||
|
pairs = sorted(pairs, key=c1_getter)
|
||||||
|
aggregation_pair, pairs = juxt(first, rest)(pairs)
|
||||||
|
return list(reduce(merge_on_head_and_aggregate_in_tail, pairs, [aggregation_pair]))
|
||||||
|
|
||||||
|
c1_getter = make_coord_getter(f"{direction}1")
|
||||||
|
c2_getter = make_coord_getter(f"{direction}2")
|
||||||
|
pair_merger = make_pair_merger(direction)
|
||||||
|
|
||||||
|
return merger_aggregator
|
||||||
|
|
||||||
|
|
||||||
|
def merge_group(group, direction):
|
||||||
|
reduce_group = make_merger_aggregator(direction)
|
||||||
|
return until_convergence(reduce_group, group)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_group_horizontally(group):
|
||||||
|
return merge_group(group, "x")
|
||||||
|
|
||||||
|
|
||||||
|
def merge_group_vertically(group):
|
||||||
|
return merge_group(group, "y")
|
||||||
@ -1,8 +1,3 @@
|
|||||||
from itertools import takewhile, starmap, islice, repeat
|
|
||||||
from operator import truth
|
|
||||||
|
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
|
|
||||||
|
|
||||||
def chunk_iterable(iterable, chunk_size):
|
|
||||||
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
|
|
||||||
|
|||||||
14
image_prediction/utils/generic.py
Normal file
14
image_prediction/utils/generic.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from itertools import takewhile, starmap, islice, repeat
|
||||||
|
from operator import truth
|
||||||
|
|
||||||
|
from funcy import iterate
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_iterable(iterable, chunk_size):
|
||||||
|
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
|
||||||
|
|
||||||
|
|
||||||
|
def until_convergence(func, *args, **kwargs):
|
||||||
|
for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2):
|
||||||
|
if len(a) == len(b):
|
||||||
|
return a
|
||||||
@ -1,6 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from image_prediction.utils import chunk_iterable
|
from image_prediction.utils.generic import chunk_iterable
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
||||||
|
|||||||
@ -1,229 +1,46 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import partial, reduce
|
from functools import partial
|
||||||
from itertools import groupby
|
from itertools import starmap, repeat
|
||||||
from itertools import starmap, chain, repeat
|
from typing import List
|
||||||
from typing import Iterable, List
|
|
||||||
|
|
||||||
import fpdf
|
import fpdf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pdf2image
|
import pdf2image
|
||||||
import pytest
|
import pytest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from funcy import merge, second, compose, rpartial, juxt, rest, first, one, iterate
|
from funcy import merge, juxt, one
|
||||||
|
|
||||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
||||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||||
from image_prediction.info import Info
|
from image_prediction.info import Info
|
||||||
from image_prediction.utils import chunk_iterable
|
from image_prediction.stitcher.stitcher import Stitcher
|
||||||
|
from image_prediction.stitcher.utils import (
|
||||||
|
make_coord_getter,
|
||||||
|
make_length_getter,
|
||||||
|
merge_metadata_horizontally,
|
||||||
|
merge_metadata_vertically,
|
||||||
|
merge_pair_horizontally,
|
||||||
|
merge_pair_vertically,
|
||||||
|
concat_images_horizontally,
|
||||||
|
concat_images_vertically,
|
||||||
|
merge_group_horizontally,
|
||||||
|
merge_group_vertically,
|
||||||
|
)
|
||||||
from test.conftest import (
|
from test.conftest import (
|
||||||
get_base_position_metadata,
|
get_base_position_metadata,
|
||||||
add_image,
|
add_image,
|
||||||
random_single_color_image_from_metadata,
|
random_single_color_image_from_metadata,
|
||||||
random_size_gray_image_from_metadata,
|
random_size_gray_image_from_metadata,
|
||||||
)
|
)
|
||||||
from test.utils.stitching import BoxSplitter, VerticalKeyMapper, HorizontalKeyMapper
|
from test.utils.stitching import BoxSplitter
|
||||||
|
|
||||||
|
|
||||||
def make_getter(key):
|
|
||||||
def getter(pair):
|
|
||||||
return pair.metadata[key]
|
|
||||||
|
|
||||||
return getter
|
|
||||||
|
|
||||||
|
|
||||||
def make_coord_getter(c):
|
|
||||||
return {
|
|
||||||
"x1": make_getter(Info.X1),
|
|
||||||
"x2": make_getter(Info.X2),
|
|
||||||
"y1": make_getter(Info.Y1),
|
|
||||||
"y2": make_getter(Info.Y2),
|
|
||||||
}[c]
|
|
||||||
|
|
||||||
|
|
||||||
def make_pair_merger(axis):
|
|
||||||
return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis]
|
|
||||||
|
|
||||||
|
|
||||||
def make_group_merger(axis):
|
|
||||||
return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]
|
|
||||||
|
|
||||||
|
|
||||||
def make_length_getter(dim):
|
|
||||||
return {
|
|
||||||
"width": make_getter(Info.WIDTH),
|
|
||||||
"height": make_getter(Info.HEIGHT),
|
|
||||||
}[dim]
|
|
||||||
|
|
||||||
|
|
||||||
x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1", "x2", "y2"))
|
x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1", "x2", "y2"))
|
||||||
width_getter, height_getter = map(make_length_getter, ("width", "height"))
|
width_getter, height_getter = map(make_length_getter, ("width", "height"))
|
||||||
|
|
||||||
|
|
||||||
def merge_metadata_horizontally(m1, m2):
|
|
||||||
m1, m2 = map(HorizontalKeyMapper, [m1, m2])
|
|
||||||
return merge_metadata(m1, m2)
|
|
||||||
|
|
||||||
|
|
||||||
def merge_metadata_vertically(m1, m2):
|
|
||||||
m1, m2 = map(VerticalKeyMapper, [m1, m2])
|
|
||||||
return merge_metadata(m1, m2)
|
|
||||||
|
|
||||||
|
|
||||||
def merge_metadata(m1, m2):
|
|
||||||
|
|
||||||
c1 = min(m1.c1, m2.c1)
|
|
||||||
c2 = max(m1.c2, m2.c2)
|
|
||||||
dim = m1.dim + m2.dim
|
|
||||||
|
|
||||||
merged = deepcopy(m1)
|
|
||||||
merged.dim = dim
|
|
||||||
merged.c1 = c1
|
|
||||||
merged.c2 = c2
|
|
||||||
|
|
||||||
return merged.wrapped
|
|
||||||
|
|
||||||
|
|
||||||
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
|
|
||||||
metadata_merged = merge_metadata_horizontally(p1.metadata, p2.metadata)
|
|
||||||
image_concatenated = concat_images_horizontally(p1.image, p2.image, metadata_merged)
|
|
||||||
return ImageMetadataPair(image_concatenated, metadata_merged)
|
|
||||||
|
|
||||||
|
|
||||||
def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
|
|
||||||
metadata_merged = merge_metadata_vertically(p1.metadata, p2.metadata)
|
|
||||||
image_concatenated = concat_images_vertically(p1.image, p2.image, metadata_merged)
|
|
||||||
return ImageMetadataPair(image_concatenated, metadata_merged)
|
|
||||||
|
|
||||||
|
|
||||||
# def merge_pair(p1, p2):
|
|
||||||
# assert p1.metadata[Info.PAGE_IDX] == p2.metadta[Info.PAGE_IDX]
|
|
||||||
|
|
||||||
|
|
||||||
def concat_images_horizontally(im1: Image, im2: Image, metadata: dict):
|
|
||||||
return concat_images(im1, im2, metadata, 0)
|
|
||||||
|
|
||||||
|
|
||||||
def concat_images_vertically(im1: Image, im2: Image, metadata: dict):
|
|
||||||
return concat_images(im1, im2, metadata, 1)
|
|
||||||
|
|
||||||
|
|
||||||
def concat_images(im1: Image, im2: Image, metadata: dict, axis):
|
|
||||||
|
|
||||||
im_aggr = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))
|
|
||||||
|
|
||||||
images = [im1, im2]
|
|
||||||
|
|
||||||
offsets = [0, *[im.size[axis] for im in images]]
|
|
||||||
|
|
||||||
for im, offset in zip(images, offsets):
|
|
||||||
box = (offset, 0) if not axis else (0, offset)
|
|
||||||
im_aggr.paste(im, box=box)
|
|
||||||
|
|
||||||
return im_aggr
|
|
||||||
|
|
||||||
|
|
||||||
def make_merger_aggregator(direction):
|
|
||||||
"""Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
|
|
||||||
head H and aggregates non-adjacent in the tail T.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def merger_aggregator(pairs):
|
|
||||||
def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
|
|
||||||
"""Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
|
|
||||||
aggr, non_aggr = juxt(first, rest)(pairs_aggr)
|
|
||||||
if c2_getter(aggr) == c1_getter(pair):
|
|
||||||
aggr = pair_merger(aggr, pair)
|
|
||||||
return aggr, *non_aggr
|
|
||||||
else:
|
|
||||||
return aggr, pair, *non_aggr
|
|
||||||
|
|
||||||
# Requires H to be the least element in image-concatenation direction by c1, since the concatenation happens
|
|
||||||
# only in c1 -> c2 direction.
|
|
||||||
pairs = sorted(pairs, key=c1_getter)
|
|
||||||
aggregation_pair, pairs = juxt(first, rest)(pairs)
|
|
||||||
return list(reduce(merge_on_head_and_aggregate_in_tail, pairs, [aggregation_pair]))
|
|
||||||
|
|
||||||
c1_getter = make_coord_getter(f"{direction}1")
|
|
||||||
c2_getter = make_coord_getter(f"{direction}2")
|
|
||||||
pair_merger = make_pair_merger(direction)
|
|
||||||
|
|
||||||
return merger_aggregator
|
|
||||||
|
|
||||||
|
|
||||||
def merge_group(group, direction):
|
|
||||||
reduce_group = make_merger_aggregator(direction)
|
|
||||||
return until_convergence(reduce_group, group)
|
|
||||||
|
|
||||||
|
|
||||||
def merge_group_horizontally(group):
|
|
||||||
return merge_group(group, "x")
|
|
||||||
|
|
||||||
|
|
||||||
def merge_group_vertically(group):
|
|
||||||
return merge_group(group, "y")
|
|
||||||
|
|
||||||
|
|
||||||
class Stitcher:
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def groupby(pairs, coord_getter):
|
|
||||||
pairs = sorted(pairs, key=coord_getter)
|
|
||||||
return map(compose(list, second), groupby(pairs, coord_getter))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def other_axis(axis):
|
|
||||||
return "y" if axis == "x" else "x"
|
|
||||||
|
|
||||||
def merge_along_axis(self, pairs, axis):
|
|
||||||
|
|
||||||
def group_pairs_by_c1(pairs):
|
|
||||||
return self.groupby(pairs, c1_getter)
|
|
||||||
|
|
||||||
def group_by_c2(pairs):
|
|
||||||
return self.groupby(pairs, c2_getter)
|
|
||||||
|
|
||||||
def group_pairs_within_groups_by_c2(groups):
|
|
||||||
return map(group_by_c2, groups)
|
|
||||||
|
|
||||||
def merge_groups_along_orthogonal_axis(groups):
|
|
||||||
return map(group_merger, groups)
|
|
||||||
|
|
||||||
c1_getter = make_coord_getter(f"{self.other_axis(axis)}1")
|
|
||||||
c2_getter = make_coord_getter(f"{self.other_axis(axis)}2")
|
|
||||||
group_merger = make_group_merger(axis)
|
|
||||||
|
|
||||||
groups_of_pairs_with_same_c1 = group_pairs_by_c1(pairs)
|
|
||||||
groups_of_groups_of_pairs_with_same_c1_and_c2 = group_pairs_within_groups_by_c2(groups_of_pairs_with_same_c1)
|
|
||||||
groups_of_pairs_with_matching_c1_and_c2 = chain(*groups_of_groups_of_pairs_with_same_c1_and_c2)
|
|
||||||
groups_of_merged_pairs = merge_groups_along_orthogonal_axis(groups_of_pairs_with_matching_c1_and_c2)
|
|
||||||
pairs = chain(*groups_of_merged_pairs)
|
|
||||||
|
|
||||||
return pairs
|
|
||||||
|
|
||||||
def merge_along_both_axes(self, pairs):
|
|
||||||
|
|
||||||
pairs = self.merge_along_axis(pairs, "x")
|
|
||||||
pairs = list(self.merge_along_axis(pairs, "y"))
|
|
||||||
|
|
||||||
return pairs
|
|
||||||
|
|
||||||
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
|
|
||||||
return until_convergence(self.merge_along_both_axes, pairs)
|
|
||||||
|
|
||||||
|
|
||||||
def until_convergence(func, *args, **kwargs):
|
|
||||||
for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2):
|
|
||||||
if len(a) == len(b):
|
|
||||||
return a
|
|
||||||
|
|
||||||
|
|
||||||
#####################################
|
#####################################
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.parametrize("width", [160])
|
|
||||||
# @pytest.mark.parametrize("height", [90])
|
|
||||||
# @pytest.mark.parametrize("page_width", [int(160 * 1.1)])
|
|
||||||
# @pytest.mark.parametrize("page_height", [int(90 * 1.1)])
|
|
||||||
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
|
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
|
||||||
pair_stitched = Stitcher().stitch(patch_image_metadata_pairs)[0]
|
pair_stitched = Stitcher().stitch(patch_image_metadata_pairs)[0]
|
||||||
assert pair_stitched.metadata == base_patch_metadata
|
assert pair_stitched.metadata == base_patch_metadata
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user