from copy import deepcopy
from functools import reduce
from typing import Iterable, Callable, List, Union

from PIL import Image
from funcy import juxt, first, rest, rcompose, rpartial, complement, ilen

from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.grouping import CoordGrouper
from image_prediction.stitching.split_mapper import HorizontalSplitMapper, VerticalSplitMapper
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once, validate_box
from image_prediction.utils.generic import until


def make_merger_sentinel():
    """Create a stateful predicate that reports when a merge pass made no progress.

    The returned function remembers how many pairs it saw on its previous call. It returns True when the current
    number of pairs equals that remembered count (i.e. the last pass merged nothing); otherwise it stores the new
    count and returns False. The initial count of -1 guarantees that the first call returns False.

    Because the predicate is stateful, create a fresh one for every merge run.
    """

    def no_new_mergers(pairs):
        nonlocal number_of_pairs_so_far

        number_of_pairs_now = len(pairs)

        if number_of_pairs_now == number_of_pairs_so_far:
            return True
        else:
            number_of_pairs_so_far = number_of_pairs_now
            return False

    number_of_pairs_so_far = -1

    return no_new_mergers


def merge_along_both_axes(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]:
    """Run one merge pass along the x axis followed by one merge pass along the y axis.

    A single call is not guaranteed to merge every adjacent image (see `merge_along_axis`); callers that need a
    full merge have to repeat this until the result stops changing.

    Args:
        pairs: Image-metadata pairs to merge.
        tolerance: Maximum allowed distance between coordinates that are still considered adjacent/aligned.

    Returns:
        The (partially) merged pairs as a list.
    """
    pairs = merge_along_axis(pairs, "x", tolerance=tolerance)
    pairs = list(merge_along_axis(pairs, "y", tolerance=tolerance))
    return pairs


def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis, tolerance=0) -> Iterable[ImageMetadataPair]:
    """Partially merges image-metadata pairs of adjacent images along a given axis.

    Needs to be iterated with alternating axes until no more merges happen to merge all adjacent images.

    Explanation:
        Merging algorithm works as follows:
            A dot represents a pair, a bracket a group and a colon a merged pair.
            1) Start with pairs:            (........)
            2) Align on lesser:             ([....] [....])
            3) Align on greater:            ([[..] [..]] [[....]])
            4) Flatten once:                ([..] [..] [....])
            5) Merge orthogonally:          ([:] [..] [:..])
            6) Flatten once:                (:..:..)

    Args:
        pairs: Image-metadata pairs to merge.
        axis: "x" or "y"; the axis whose coordinates are used for grouping.
        tolerance: Maximum coordinate distance still treated as aligned/adjacent.
    """

    def group_pairs_within_groups_by_greater_coordinate(groups):
        return map(CoordGrouper(axis, tolerance=tolerance).group_pairs_by_greater_coordinate, groups)

    def merge_groups_along_orthogonal_axis(groups):
        return map(rpartial(make_group_merger(axis), tolerance), groups)

    def group_pairs_by_lesser_coordinate(pairs):
        return CoordGrouper(axis, tolerance=tolerance).group_pairs_by_lesser_coordinate(pairs)

    return rcompose(
        group_pairs_by_lesser_coordinate,
        group_pairs_within_groups_by_greater_coordinate,
        flatten_groups_once,
        merge_groups_along_orthogonal_axis,
        flatten_groups_once,
    )(pairs)


def make_group_merger(axis):
    """Return the group merger for `axis`: grouping on "y" merges vertically, grouping on "x" horizontally.

    Raises:
        KeyError: If `axis` is neither "x" nor "y".
    """
    return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]


def merge_group_vertically(group: Iterable[ImageMetadataPair], tolerance=0):
    """Merge the pairs of `group` that are adjacent along the y axis."""
    return merge_group(group, "y", tolerance=tolerance)


def merge_group_horizontally(group: Iterable[ImageMetadataPair], tolerance=0):
    """Merge the pairs of `group` that are adjacent along the x axis."""
    return merge_group(group, "x", tolerance=tolerance)


def merge_group(group: Iterable[ImageMetadataPair], direction, tolerance=0):
    """Merge adjacent pairs of `group` in `direction` ("x" or "y") by repeated aggregation passes.

    NOTE(review): presumably `until` keeps applying `reduce_group` until `no_new_mergers` returns True, i.e. until
    a pass leaves the number of pairs unchanged — confirm against `image_prediction.utils.generic.until`.
    """
    reduce_group = make_merger_aggregator(direction, tolerance=tolerance)
    no_new_mergers = make_merger_sentinel()
    return until(no_new_mergers, reduce_group, group)


def make_merger_aggregator(axis, tolerance=0) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
    """Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs
    on the head H and aggregates non-adjacent ones in the tail T.

    A pair counts as adjacent to the head when the distance between the head's `{axis}2` coordinate and the pair's
    `{axis}1` coordinate is at most `tolerance`. An empty input produces an empty list.

    Note:
        When tolerance > 0, the bounding box of the merged image no longer matches the bounding box of the merged
        metadata. This is intended behaviour, but might not be expected by the caller.

    Raises:
        AssertionError: If `tolerance` is negative.
    """

    def merger_aggregator(pairs: Iterable[ImageMetadataPair]):
        def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
            """Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
            aggr, non_aggr = juxt(first, rest)(pairs_aggr)

            if abs(c2_getter(aggr) - c1_getter(pair)) <= tolerance:
                aggr = pair_merger(aggr, pair)
                return aggr, *non_aggr
            else:
                return aggr, pair, *non_aggr

        # Requires H to be the least element in image-concatenation direction by c1, since the concatenation happens
        # only in c1 -> c2 direction.
        pairs = sorted(pairs, key=c1_getter)

        # Without this guard an empty group would produce [None] (funcy's `first` returns None for an empty
        # sequence), and any further pass over that result would call the coordinate getters on None.
        if not pairs:
            return []

        head_pair, pairs = juxt(first, rest)(pairs)

        return list(reduce(merge_on_head_and_aggregate_in_tail, pairs, [head_pair]))

    assert tolerance >= 0

    c1_getter = make_coord_getter(f"{axis}1")
    c2_getter = make_coord_getter(f"{axis}2")
    pair_merger = make_pair_merger(axis)

    return merger_aggregator


def make_pair_merger(axis):
    """Return the pair merger for `axis` ("y" -> vertical, "x" -> horizontal).

    Raises:
        KeyError: If `axis` is neither "x" nor "y".
    """
    return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis]


def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge two pairs by stacking `p2` below `p1`; returns a new pair."""
    metadata_merged = merge_metadata_vertically(p1.metadata, p2.metadata)
    image_concatenated = concat_images_vertically(p1.image, p2.image, metadata_merged)
    return ImageMetadataPair(image_concatenated, metadata_merged)


def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge two pairs by placing `p2` to the right of `p1`; returns a new pair."""
    metadata_merged = merge_metadata_horizontally(p1.metadata, p2.metadata)
    image_concatenated = concat_images_horizontally(p1.image, p2.image, metadata_merged)
    return ImageMetadataPair(image_concatenated, metadata_merged)


def merge_metadata_vertically(m1: dict, m2: dict):
    """Merge two metadata dicts along the vertical axis (via `VerticalSplitMapper`)."""
    m1, m2 = map(VerticalSplitMapper, [m1, m2])
    return merge_metadata(m1, m2)


def merge_metadata_horizontally(m1: dict, m2: dict):
    """Merge two metadata dicts along the horizontal axis (via `HorizontalSplitMapper`)."""
    m1, m2 = map(HorizontalSplitMapper, [m1, m2])
    return merge_metadata(m1, m2)


def merge_metadata(
    m1: Union[HorizontalSplitMapper, VerticalSplitMapper],
    m2: Union[HorizontalSplitMapper, VerticalSplitMapper],
):
    """Combine two split-mapped metadata records into one spanning both along the mapper's axis.

    The merged record is a deep copy of `m1` whose `c1`/`c2` become the union of both ranges and whose `dim` is
    the resulting extent; all other fields are taken from `m1`. The result is validated with `validate_box` and
    the underlying wrapped metadata is returned.
    """
    c1 = min(m1.c1, m2.c1)
    c2 = max(m1.c2, m2.c2)
    dim = abs(c2 - c1)

    # Deep copy so that the inputs' metadata are not mutated by the assignments below.
    merged = deepcopy(m1)
    merged.dim = dim
    merged.c1 = c1
    merged.c2 = c2

    validate_box(merged.wrapped)

    return merged.wrapped


def concat_images_vertically(im1: Image.Image, im2: Image.Image, metadata: dict):
    """Paste `im2` below `im1` on a canvas sized by `metadata`."""
    return concat_images(im1, im2, metadata, 1)


def concat_images_horizontally(im1: Image.Image, im2: Image.Image, metadata: dict):
    """Paste `im2` to the right of `im1` on a canvas sized by `metadata`."""
    return concat_images(im1, im2, metadata, 0)


def concat_images(im1: Image.Image, im2: Image.Image, metadata: dict, axis):
    """Paste two images one after the other along `axis` onto a new canvas.

    The canvas has `im1`'s mode and the size `(metadata[Info.WIDTH], metadata[Info.HEIGHT])`. `im1` is pasted at
    offset 0 and `im2` directly after it at offset `im1.size[axis]`.

    Args:
        axis: 0 concatenates along the width (horizontal), 1 along the height (vertical).

    Note:
        Placement is derived from `im1`'s size, not from the metadata coordinates, so when the canvas size differs
        from the summed image sizes (e.g. with a merge tolerance > 0), the canvas may have blank space at the end
        or `im2` may be clipped.
    """
    im_aggr = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))

    images = [im1, im2]
    offsets = [0, im1.size[axis]]

    for im, offset in zip(images, offsets):
        box = (offset, 0) if not axis else (0, offset)
        im_aggr.paste(im, box=box)

    return im_aggr