diff --git a/test/unit_tests/image_stitcher_test.py b/test/unit_tests/image_stitcher_test.py index 8e18f67..fa85f69 100644 --- a/test/unit_tests/image_stitcher_test.py +++ b/test/unit_tests/image_stitcher_test.py @@ -38,6 +38,13 @@ def make_coord_getter(c): }[c] +def make_pair_merger(direction): + return { + "y": merge_pair_vertically, + "x": merge_pair_horizontally + }[direction] + + def make_length_getter(dim): return { "width": make_getter(Info.WIDTH), @@ -126,20 +133,24 @@ class Stitcher: groups = map(merge_group, groups) -def merge_group(group): +def merge_group(group, direction="y"): def merge_with(aggregation_pair, pairs): - def aggregate_on_head(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair): - """Keeps the image that is being merged with as the head and aggregates non-mergables in the tail""" + def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair): + """Keeps the image that is being merged with as the head and aggregates non-mergables in the tail.""" aggr, non_aggr = juxt(first, rest)(pairs_aggr) - if y2_getter(aggr) == y1_getter(pair): - aggr = merge_pair_vertically(aggr, pair) + if c2_getter(aggr) == c1_getter(pair): + aggr = pair_merger(aggr, pair) return aggr, *non_aggr else: return aggr, pair, *non_aggr - return list(reduce(aggregate_on_head, pairs, [aggregation_pair])) + return list(reduce(merge_on_head_and_aggregate_in_tail, pairs, [aggregation_pair])) + + c1_getter = make_coord_getter(f"{direction}1") + c2_getter = make_coord_getter(f"{direction}2") + pair_merger = make_pair_merger(direction) pairs = list(group)