from copy import deepcopy
from functools import reduce
from typing import Iterable, Callable, List

from PIL import Image
from funcy import juxt, first, rest, rcompose

from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.grouping import CoordGrouper
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once
from image_prediction.utils.generic import until

# NOTE(review): this module imports helpers from the `test` package; consider moving the mappers into the
# library package so the stitching code does not depend on test utilities.
from test.utils.stitching import HorizontalSplitMapper, VerticalMapper


def no_new_merges(pairs1, pairs2) -> bool:
    """Return True if both collections hold the same number of pairs.

    Used as the stop condition passed to `until` in `merge_group`: an unchanged length between two
    collections is taken to mean that no pairs were merged.
    """
    return len(pairs1) == len(pairs2)


def merge_along_both_axes(pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
    """Run one merge pass along the x-axis followed by one along the y-axis and return the result as a list."""
    pairs = merge_along_axis(pairs, "x")
    pairs = list(merge_along_axis(pairs, "y"))
    return pairs


def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis) -> Iterable[ImageMetadataPair]:
    """Partially merges image-metadata pairs of adjacent images along a given axis. Needs to be iterated with
    alternating axes until no more merges happen to merge all adjacent images.

    Explanation:

    Merging algorithm works as follows:
    A dot represents a pair, a bracket a group and a colon a merged pair.
    1) Start with pairs:    (........)
    2) Align on lesser:     ([....] [....])
    3) Align on greater:    ([[..] [..]] [[....]])
    4) Flatten once:        ([..] [..] [....])
    5) Merge orthogonally:  ([:] [..] [:..])
    6) Flatten once:        (:..:..)

    Args:
        pairs: The image-metadata pairs to merge.
        axis: Either "x" or "y"; selects the grouper coordinates and the group merger.

    Returns:
        The output of the final `flatten_groups_once` step of the pipeline.
    """

    def group_pairs_within_groups_by_greater_coordinate(groups):
        return map(CoordGrouper(axis).group_pairs_by_greater_coordinate, groups)

    def merge_groups_along_orthogonal_axis(groups):
        return map(make_group_merger(axis), groups)

    def group_pairs_by_lesser_coordinate(pairs):
        return CoordGrouper(axis).group_pairs_by_lesser_coordinate(pairs)

    # rcompose applies the functions left to right, i.e. in the order of steps 2) to 6) above.
    return rcompose(
        group_pairs_by_lesser_coordinate,
        group_pairs_within_groups_by_greater_coordinate,
        flatten_groups_once,
        merge_groups_along_orthogonal_axis,
        flatten_groups_once,
    )(pairs)


def make_group_merger(axis):
    """Return the group-merging function registered for `axis` ("x" or "y"); raises KeyError otherwise."""
    return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]


def merge_group_vertically(group: Iterable[ImageMetadataPair]):
    """Merge the pairs of `group` in direction "y"."""
    return merge_group(group, "y")


def merge_group_horizontally(group: Iterable[ImageMetadataPair]):
    """Merge the pairs of `group` in direction "x"."""
    return merge_group(group, "x")


def merge_group(group: Iterable[ImageMetadataPair], direction):
    """Apply the merger-aggregator for `direction` to `group` via `until`, with `no_new_merges` as condition.

    NOTE(review): presumably `until` re-applies the reducer until two consecutive results have the same
    length; confirm against `image_prediction.utils.generic.until`.
    """
    reduce_group = make_merger_aggregator(direction)
    return until(no_new_merges, reduce_group, group)


def make_merger_aggregator(axis) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
    """Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs
    on the head H and aggregates non-adjacent in the tail T.

    Args:
        axis: Either "x" or "y"; selects the coordinate getters (`{axis}1`, `{axis}2`) and the pair merger.

    Returns:
        A function mapping an iterable of pairs to a list whose first element is the (possibly merged) head,
        followed by the pairs that could not be merged onto it. An empty input yields an empty list.
    """
    c1_getter = make_coord_getter(f"{axis}1")
    c2_getter = make_coord_getter(f"{axis}2")
    pair_merger = make_pair_merger(axis)

    def merger_aggregator(pairs: Iterable[ImageMetadataPair]):
        def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
            """Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
            aggr, non_aggr = juxt(first, rest)(pairs_aggr)
            # A pair is merged only if it starts exactly where the current head ends.
            if c2_getter(aggr) == c1_getter(pair):
                aggr = pair_merger(aggr, pair)
                return (aggr, *non_aggr)
            else:
                return (aggr, pair, *non_aggr)

        # Requires H to be the least element in image-concatenation direction by c1, since the concatenation happens
        # only in c1 -> c2 direction.
        pairs = sorted(pairs, key=c1_getter)
        if not pairs:
            # Without this guard `first` would yield None and the result would be the bogus group [None].
            return []
        head_pair, pairs = juxt(first, rest)(pairs)
        return list(reduce(merge_on_head_and_aggregate_in_tail, pairs, [head_pair]))

    return merger_aggregator


def make_pair_merger(axis):
    """Return the pair-merging function registered for `axis` ("x" or "y"); raises KeyError otherwise."""
    return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis]


def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge the metadata of `p1` and `p2` vertically and concatenate their images accordingly."""
    metadata_merged = merge_metadata_vertically(p1.metadata, p2.metadata)
    image_concatenated = concat_images_vertically(p1.image, p2.image, metadata_merged)
    return ImageMetadataPair(image_concatenated, metadata_merged)


def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge the metadata of `p1` and `p2` horizontally and concatenate their images accordingly."""
    metadata_merged = merge_metadata_horizontally(p1.metadata, p2.metadata)
    image_concatenated = concat_images_horizontally(p1.image, p2.image, metadata_merged)
    return ImageMetadataPair(image_concatenated, metadata_merged)


def merge_metadata_vertically(m1: dict, m2: dict):
    """Wrap both metadata objects in a `VerticalMapper` and merge them with `merge_metadata`."""
    m1, m2 = map(VerticalMapper, [m1, m2])
    return merge_metadata(m1, m2)


def merge_metadata_horizontally(m1: dict, m2: dict):
    """Wrap both metadata objects in a `HorizontalSplitMapper` and merge them with `merge_metadata`."""
    m1, m2 = map(HorizontalSplitMapper, [m1, m2])
    return merge_metadata(m1, m2)


def merge_metadata(m1, m2):
    """Merge two mapper-wrapped metadata objects into one.

    Args:
        m1: Mapper object exposing `c1`, `c2`, `dim` and `wrapped`; it is deep-copied, not mutated.
        m2: Mapper object exposing `c1`, `c2` and `dim`.

    Returns:
        The `wrapped` attribute of a deep copy of `m1` whose `c1` is the smaller `c1`, whose `c2` is the
        larger `c2` and whose `dim` is the sum of both `dim` values.
    """
    c1 = min(m1.c1, m2.c1)
    c2 = max(m1.c2, m2.c2)
    dim = m1.dim + m2.dim

    merged = deepcopy(m1)
    merged.dim = dim
    merged.c1 = c1
    merged.c2 = c2

    return merged.wrapped


def concat_images_vertically(im1: Image.Image, im2: Image.Image, metadata: dict):
    """Paste `im2` below `im1` on a new canvas sized according to `metadata`."""
    return concat_images(im1, im2, metadata, 1)


def concat_images_horizontally(im1: Image.Image, im2: Image.Image, metadata: dict):
    """Paste `im2` to the right of `im1` on a new canvas sized according to `metadata`."""
    return concat_images(im1, im2, metadata, 0)


def concat_images(im1: Image.Image, im2: Image.Image, metadata: dict, axis):
    """Paste two images next to each other onto a new canvas.

    Args:
        im1: Image placed at the origin of the canvas.
        im2: Image placed directly after `im1` along the concatenation axis.
        metadata: Mapping providing the canvas size under `Info.WIDTH` and `Info.HEIGHT`.
        axis: Index into `Image.size`; 0 concatenates along the width, 1 along the height.

    Returns:
        A new image in the mode of `im1` containing both inputs.
    """
    im_aggr = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))

    # The second image starts where the first one ends along the concatenation axis.
    for im, offset in zip((im1, im2), (0, im1.size[axis])):
        box = (0, offset) if axis else (offset, 0)
        im_aggr.paste(im, box=box)

    return im_aggr