146 lines
4.3 KiB
Python
146 lines
4.3 KiB
Python
from copy import deepcopy
|
|
from functools import reduce
|
|
from typing import Iterable
|
|
|
|
from PIL import Image
|
|
from funcy import juxt, first, rest
|
|
|
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
|
from image_prediction.info import Info
|
|
from image_prediction.utils.generic import until
|
|
from test.utils.stitching import HorizontalKeyMapper, VerticalKeyMapper
|
|
|
|
|
|
def make_getter(key):
    """Return a function that reads ``key`` from a pair's ``metadata`` mapping.

    The returned callable raises ``KeyError`` if ``key`` is missing from the metadata.
    """

    def get_metadata_value(pair):
        metadata = pair.metadata
        return metadata[key]

    return get_metadata_value
|
|
|
|
|
|
def make_coord_getter(c):
    """Return a metadata getter for the coordinate named ``c``.

    Args:
        c: One of "x1", "x2", "y1", "y2"; any other value raises ``KeyError``.
    """
    info_key_by_coordinate = {
        "x1": Info.X1,
        "x2": Info.X2,
        "y1": Info.Y1,
        "y2": Info.Y2,
    }
    return make_getter(info_key_by_coordinate[c])
|
|
|
|
|
|
def make_pair_merger(axis):
    """Return the pair-merging function for ``axis`` ("x" -> horizontal, "y" -> vertical).

    Any other ``axis`` raises ``KeyError``.
    """
    merger_by_axis = dict(y=merge_pair_vertically, x=merge_pair_horizontally)
    return merger_by_axis[axis]
|
|
|
|
|
|
def make_group_merger(axis):
    """Return the group-merging function for ``axis`` ("x" -> horizontal, "y" -> vertical).

    Any other ``axis`` raises ``KeyError``.
    """
    merger_by_axis = dict(y=merge_group_vertically, x=merge_group_horizontally)
    return merger_by_axis[axis]
|
|
|
|
|
|
def make_length_getter(dim):
    """Return a metadata getter for the extent named ``dim``.

    Args:
        dim: "width" (reads ``Info.WIDTH``) or "height" (reads ``Info.HEIGHT``);
            any other value raises ``KeyError``.
    """
    info_key_by_dimension = {
        "width": Info.WIDTH,
        "height": Info.HEIGHT,
    }
    return make_getter(info_key_by_dimension[dim])
|
|
|
|
|
|
def merge_metadata_horizontally(m1, m2):
    """Merge two metadata objects through ``HorizontalKeyMapper`` views.

    Both inputs are wrapped in ``HorizontalKeyMapper`` and passed to ``merge_metadata``.
    The merged (unwrapped) metadata is returned.
    """
    return merge_metadata(HorizontalKeyMapper(m1), HorizontalKeyMapper(m2))
|
|
|
|
|
|
def merge_metadata_vertically(m1, m2):
    """Merge two metadata objects through ``VerticalKeyMapper`` views.

    Both inputs are wrapped in ``VerticalKeyMapper`` and passed to ``merge_metadata``.
    The merged (unwrapped) metadata is returned.
    """
    return merge_metadata(VerticalKeyMapper(m1), VerticalKeyMapper(m2))
|
|
|
|
|
|
def merge_metadata(m1, m2):
    """Combine two key-mapped metadata views into one that spans both.

    ``m1`` and ``m2`` must expose ``c1``, ``c2``, ``dim`` and ``wrapped`` (the key-mapper wrappers used
    in this module do).

    The result:
      * starts at the smaller ``c1``,
      * ends at the larger ``c2``,
      * has ``dim`` equal to the sum of both ``dim`` values.

    All changes are made on a deep copy of ``m1``, so every other field comes from ``m1``.

    Returns:
        The ``wrapped`` attribute of the updated copy.
    """
    start = min(m1.c1, m2.c1)
    end = max(m1.c2, m2.c2)
    length = m1.dim + m2.dim

    combined = deepcopy(m1)  # work on a copy so the caller's m1 is left as-is
    combined.dim = length
    combined.c1 = start
    combined.c2 = end
    return combined.wrapped
|
|
|
|
|
|
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge two pairs with the horizontal metadata/image helpers.

    The metadata is merged first, because the merged metadata sizes the canvas that the two
    images are concatenated onto.
    """
    merged_metadata = merge_metadata_horizontally(p1.metadata, p2.metadata)
    return ImageMetadataPair(
        concat_images_horizontally(p1.image, p2.image, merged_metadata),
        merged_metadata,
    )
|
|
|
|
|
|
def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge two pairs with the vertical metadata/image helpers.

    The metadata is merged first, because the merged metadata sizes the canvas that the two
    images are concatenated onto.
    """
    merged_metadata = merge_metadata_vertically(p1.metadata, p2.metadata)
    return ImageMetadataPair(
        concat_images_vertically(p1.image, p2.image, merged_metadata),
        merged_metadata,
    )
|
|
|
|
|
|
def concat_images_horizontally(im1: Image.Image, im2: Image.Image, metadata: dict):
    """Concatenate ``im1`` and ``im2`` along axis 0 (x) via ``concat_images``.

    ``metadata`` supplies the canvas size (``Info.WIDTH`` / ``Info.HEIGHT``).
    """
    # Annotations use ``Image.Image``: ``Image`` alone is the PIL module, not the image class.
    return concat_images(im1, im2, metadata, 0)
|
|
|
|
|
|
def concat_images_vertically(im1: Image.Image, im2: Image.Image, metadata: dict):
    """Concatenate ``im1`` and ``im2`` along axis 1 (y) via ``concat_images``.

    ``metadata`` supplies the canvas size (``Info.WIDTH`` / ``Info.HEIGHT``).
    """
    # Annotations use ``Image.Image``: ``Image`` alone is the PIL module, not the image class.
    return concat_images(im1, im2, metadata, 1)
|
|
|
|
|
|
def concat_images(im1: Image.Image, im2: Image.Image, metadata: dict, axis):
    """Paste ``im1`` and then ``im2`` one after the other onto a new canvas.

    Args:
        im1: Image pasted first, at offset 0 along ``axis``.
        im2: Image pasted directly after ``im1`` along ``axis``.
        metadata: Merged metadata. ``metadata[Info.WIDTH]`` and ``metadata[Info.HEIGHT]`` give the
            canvas size.
        axis: 0 to concatenate along x (PIL ``size[0]``), 1 to concatenate along y (``size[1]``).

    Returns:
        A new PIL image in ``im1``'s mode containing both inputs.
    """
    im_aggr = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))

    # Running offset along ``axis``: each image starts where the previous one ended. This is correct
    # for any number of images, unlike a list of raw sizes zipped against the images.
    # The perpendicular coordinate is always 0.
    offset = 0
    for im in (im1, im2):
        box = (offset, 0) if not axis else (0, offset)
        im_aggr.paste(im, box=box)
        offset += im.size[axis]

    return im_aggr
|
|
|
|
|
|
def make_merger_aggregator(direction):
    """Build one merge pass over image-metadata pairs along ``direction``.

    Produces a function f : [P1, ..., Pn] -> [R1, ..., Rk].

    * The pairs are sorted by their start coordinate c1 and visited in that order.
    * Each pair is concatenated onto the first already-aggregated entry whose end coordinate c2
      equals the pair's c1. The head H (the lowest-c1 entry) is checked first, then the remaining
      entries in list order.
    * A pair that extends no entry is inserted directly behind H.

    Unlike a head-only merge, adjacent pairs that are not chained to the lowest-c1 pair are merged
    as well. Only c1/c2 along ``direction`` are compared; the perpendicular extent of the pairs is
    not checked.

    Args:
        direction: "x" or "y"; selects the coordinate getters and the pair merger.

    Returns:
        The merge-pass function. It returns a new list, and ``[]`` for empty input.
    """
    c1_getter = make_coord_getter(f"{direction}1")
    c2_getter = make_coord_getter(f"{direction}2")
    pair_merger = make_pair_merger(direction)

    def merger_aggregator(pairs):
        # Concatenation only happens in c1 -> c2 direction. Visiting pairs by ascending c1 therefore
        # guarantees that a pair's predecessor has already been placed when the pair is considered.
        pairs = sorted(pairs, key=c1_getter)
        if not pairs:
            return []

        head, *tail = pairs
        aggregated = [head]
        for pair in tail:
            for index, candidate in enumerate(aggregated):
                if c2_getter(candidate) == c1_getter(pair):
                    aggregated[index] = pair_merger(candidate, pair)
                    break
            else:
                # No entry ends where this pair starts: keep it separate, right behind the head.
                aggregated.insert(1, pair)
        return aggregated

    return merger_aggregator
|
|
|
|
|
|
def merge_group(group, direction):
    """Run merge passes from ``make_merger_aggregator(direction)`` on ``group`` via ``until``.

    ``until`` receives a break condition that compares the lengths of two pair lists. Presumably it
    stops once a pass leaves the number of pairs unchanged; confirm against
    ``image_prediction.utils.generic.until``.
    """

    def pair_count_unchanged(before, after):
        return len(before) == len(after)

    merge_pass = make_merger_aggregator(direction)
    return until(merge_pass, pair_count_unchanged, group)
|
|
|
|
|
|
def merge_group_horizontally(group):
    """Merge the pairs of ``group`` along the x direction (see ``merge_group``)."""
    return merge_group(group, direction="x")
|
|
|
|
|
|
def merge_group_vertically(group):
    """Merge the pairs of ``group`` along the y direction (see ``merge_group``)."""
    return merge_group(group, direction="y")
|