from copy import deepcopy
from functools import reduce
from typing import Iterable

from PIL import Image
from funcy import juxt, first, rest

from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.utils.generic import until
from test.utils.stitching import HorizontalKeyMapper, VerticalKeyMapper


def make_getter(key):
    """Return a function that reads ``pair.metadata[key]`` from an image-metadata pair."""

    def getter(pair):
        return pair.metadata[key]

    return getter


def make_coord_getter(c):
    """Return the metadata getter for coordinate name ``c``.

    Args:
        c: One of ``"x1"``, ``"x2"``, ``"y1"``, ``"y2"``.

    Raises:
        KeyError: If ``c`` is not one of the supported coordinate names.
    """
    return {
        "x1": make_getter(Info.X1),
        "x2": make_getter(Info.X2),
        "y1": make_getter(Info.Y1),
        "y2": make_getter(Info.Y2),
    }[c]


def make_pair_merger(axis):
    """Return the pair merger for direction ``axis`` ("x" -> horizontal, "y" -> vertical).

    Raises:
        KeyError: If ``axis`` is neither "x" nor "y".
    """
    return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis]


def make_group_merger(axis):
    """Return the group merger for direction ``axis`` ("x" -> horizontal, "y" -> vertical).

    Raises:
        KeyError: If ``axis`` is neither "x" nor "y".
    """
    return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]


def make_length_getter(dim):
    """Return the metadata getter for dimension ``dim`` ("width" or "height").

    Raises:
        KeyError: If ``dim`` is neither "width" nor "height".
    """
    return {
        "width": make_getter(Info.WIDTH),
        "height": make_getter(Info.HEIGHT),
    }[dim]


def merge_metadata_horizontally(m1, m2):
    """Merge two metadata dicts viewed through ``HorizontalKeyMapper``.

    NOTE(review): presumably the mapper exposes x1/x2/width as c1/c2/dim -- confirm in
    test.utils.stitching.
    """
    m1, m2 = map(HorizontalKeyMapper, [m1, m2])
    return merge_metadata(m1, m2)


def merge_metadata_vertically(m1, m2):
    """Merge two metadata dicts viewed through ``VerticalKeyMapper``.

    NOTE(review): presumably the mapper exposes y1/y2/height as c1/c2/dim -- confirm in
    test.utils.stitching.
    """
    m1, m2 = map(VerticalKeyMapper, [m1, m2])
    return merge_metadata(m1, m2)


def merge_metadata(m1, m2):
    """Combine two key-mapped metadata views into one.

    The result spans both inputs: ``c1`` is the smaller start, ``c2`` the larger end, and
    ``dim`` is the sum of both ``dim`` values. Every other field is taken from ``m1``.

    Returns:
        The ``wrapped`` object of the merged mapper (presumably the underlying metadata
        dict -- confirm against the mapper classes).
    """
    # Deep-copy so that assigning through the mapper below does not mutate m1's data.
    merged = deepcopy(m1)
    merged.dim = m1.dim + m2.dim
    merged.c1 = min(m1.c1, m2.c1)
    merged.c2 = max(m1.c2, m2.c2)
    return merged.wrapped


def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair) -> ImageMetadataPair:
    """Merge two pairs left-to-right: merged metadata plus ``p2.image`` pasted right of ``p1.image``."""
    metadata_merged = merge_metadata_horizontally(p1.metadata, p2.metadata)
    image_concatenated = concat_images_horizontally(p1.image, p2.image, metadata_merged)
    return ImageMetadataPair(image_concatenated, metadata_merged)


def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair) -> ImageMetadataPair:
    """Merge two pairs top-to-bottom: merged metadata plus ``p2.image`` pasted below ``p1.image``."""
    metadata_merged = merge_metadata_vertically(p1.metadata, p2.metadata)
    image_concatenated = concat_images_vertically(p1.image, p2.image, metadata_merged)
    return ImageMetadataPair(image_concatenated, metadata_merged)


def concat_images_horizontally(im1: Image.Image, im2: Image.Image, metadata: dict) -> Image.Image:
    """Paste ``im2`` to the right of ``im1`` on a canvas sized by ``metadata``."""
    return concat_images(im1, im2, metadata, 0)


def concat_images_vertically(im1: Image.Image, im2: Image.Image, metadata: dict) -> Image.Image:
    """Paste ``im2`` below ``im1`` on a canvas sized by ``metadata``."""
    return concat_images(im1, im2, metadata, 1)


def concat_images(im1: Image.Image, im2: Image.Image, metadata: dict, axis: int) -> Image.Image:
    """Concatenate two images along ``axis`` onto a new canvas.

    Args:
        im1: First image; pasted at the origin. Its mode is used for the canvas.
        im2: Second image; pasted directly after ``im1`` along ``axis`` and at offset 0 on
            the other axis.
        metadata: Supplies the canvas size via ``Info.WIDTH`` and ``Info.HEIGHT``.
        axis: 0 to concatenate horizontally, 1 to concatenate vertically (index into PIL's
            ``(width, height)`` size tuple).

    Returns:
        The new canvas image.
    """
    canvas = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))
    # im2 starts exactly where im1 ends along the concatenation axis.
    offset = im1.size[axis]
    canvas.paste(im1, box=(0, 0))
    canvas.paste(im2, box=(offset, 0) if not axis else (0, offset))
    return canvas


def make_merger_aggregator(direction):
    """Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on
    the head H and aggregates non-adjacent in the tail T.

    Args:
        direction: "x" (merge horizontally) or "y" (merge vertically).

    Returns:
        A function taking an iterable of ``ImageMetadataPair`` and returning a list. The
        first element is the head with every pair adjacent to it merged in. The remaining
        elements are the pairs that could not be merged onto the head. Empty input yields
        an empty list.

    NOTE: within one call only pairs adjacent to the (growing) head are merged. Tail pairs
    that are adjacent only to each other are not merged with one another.

    Raises:
        KeyError: If ``direction`` is neither "x" nor "y" (raised when this factory is called).
    """
    c1_getter = make_coord_getter(f"{direction}1")
    c2_getter = make_coord_getter(f"{direction}2")
    pair_merger = make_pair_merger(direction)

    def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
        """Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
        head, tail = juxt(first, rest)(pairs_aggr)
        if c2_getter(head) == c1_getter(pair):
            # pair starts exactly where the head ends -> grow the head.
            return pair_merger(head, pair), *tail
        # Not adjacent: keep the head in front and stash the pair in the tail.
        return head, pair, *tail

    def merger_aggregator(pairs):
        # Requires H to be the least element in image-concatenation direction by c1, since the concatenation happens
        # only in c1 -> c2 direction.
        pairs = sorted(pairs, key=c1_getter)
        if not pairs:
            # Without this guard first([]) would be None and the result would be [None].
            return []
        head, tail = juxt(first, rest)(pairs)
        return list(reduce(merge_on_head_and_aggregate_in_tail, tail, [head]))

    return merger_aggregator


def merge_group(group, direction):
    """Repeatedly merge adjacent pairs of ``group`` along ``direction`` ("x" or "y").

    Each step applies the merger-aggregator from ``make_merger_aggregator``. The stop
    condition is that two successive pair lists have the same length, i.e. a pass merged
    nothing.

    NOTE(review): assumes ``until(f, cond, x)`` applies ``f`` until
    ``cond(previous, current)`` holds -- confirm against
    image_prediction.utils.generic.until.
    """

    def break_condition(pairs1, pairs2):
        return len(pairs1) == len(pairs2)

    reduce_group = make_merger_aggregator(direction)
    return until(reduce_group, break_condition, group)


def merge_group_horizontally(group):
    """Merge horizontally adjacent pairs of ``group`` (see ``merge_group``)."""
    return merge_group(group, "x")


def merge_group_vertically(group):
    """Merge vertically adjacent pairs of ``group`` (see ``merge_group``)."""
    return merge_group(group, "y")