152 lines
5.2 KiB
Python
152 lines
5.2 KiB
Python
from copy import deepcopy
|
|
from functools import reduce
|
|
from typing import Iterable, Callable
|
|
|
|
from PIL import Image
|
|
from funcy import juxt, first, rest, rcompose
|
|
|
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
|
from image_prediction.info import Info
|
|
from image_prediction.stitching.grouping import CoordGrouper
|
|
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once
|
|
from image_prediction.utils.generic import until
|
|
from test.utils.stitching import HorizontalKeyMapper, VerticalKeyMapper
|
|
|
|
|
|
def no_new_merges(pairs1, pairs2):
    """Report whether two successive merge results hold the same number of pairs.

    ``merge_group`` passes this to ``until`` as the stopping predicate: an unchanged
    pair count means the latest merging pass combined nothing.
    """
    count_before, count_after = len(pairs1), len(pairs2)
    return count_before == count_after
|
|
|
|
|
|
def merge_along_both_axes(pairs: Iterable[ImageMetadataPair]):
    """Merge adjacent image/metadata pairs along the "x" axis first, then along "y".

    Returns:
        A list with the resulting (possibly merged) pairs.
    """
    after_x = merge_along_axis(pairs, "x")
    after_y = merge_along_axis(after_x, "y")
    return list(after_y)
|
|
|
|
|
|
def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis):
    """Group pairs that share coordinates on ``axis`` and merge the members of each group.

    Steps, applied in order:
      1. group the pairs by their lesser coordinate on ``axis`` (aligned on one edge);
      2. split each such group by the greater coordinate (aligned on both edges);
      3. flatten the nested groups by one level (groups of groups -> groups);
      4. merge every group with the axis-specific group merger;
      5. flatten once more (groups of pairs -> pairs).

    Steps 2 and 4 are built with ``map``. The concrete return type is whatever
    ``flatten_groups_once`` produces.
    """
    # Step 1: pairs -> groups of pairs aligned on one edge.
    edge_groups = CoordGrouper(axis).group_pairs_by_lesser_coordinate(pairs)

    # Step 2: each edge group -> sub-groups aligned on both edges.
    splitter = CoordGrouper(axis)
    aligned_subgroups = map(splitter.group_pairs_by_greater_coordinate, edge_groups)

    # Step 3: groups of groups of pairs -> groups of pairs.
    aligned_groups = flatten_groups_once(aligned_subgroups)

    # Steps 4 and 5: merge inside each group, then flatten back to pairs.
    merged_groups = map(make_group_merger(axis), aligned_groups)
    return flatten_groups_once(merged_groups)
|
|
|
|
|
|
def make_group_merger(axis):
    """Select the group-merging function for ``axis``.

    "x" selects ``merge_group_horizontally`` and "y" selects ``merge_group_vertically``.

    Raises:
        KeyError: if ``axis`` is not "x" or "y".
    """
    group_merger_by_axis = {
        "x": merge_group_horizontally,
        "y": merge_group_vertically,
    }
    return group_merger_by_axis[axis]
|
|
|
|
|
|
def merge_group_vertically(group: Iterable[ImageMetadataPair]):
    """Merge the members of ``group`` along the "y" axis."""
    return merge_group(group, direction="y")
|
|
|
|
|
|
def merge_group_horizontally(group: Iterable[ImageMetadataPair]):
    """Merge the members of ``group`` along the "x" axis."""
    return merge_group(group, direction="x")
|
|
|
|
|
|
def merge_group(group: Iterable[ImageMetadataPair], direction):
    """Merge adjacent pairs of ``group`` along ``direction`` ("x" or "y").

    Delegates to ``until``, using the direction-specific merger-aggregator as the step
    and ``no_new_merges`` as the predicate. Presumably ``until`` re-applies the step
    until the predicate holds; confirm against ``image_prediction.utils.generic.until``.
    """
    return until(no_new_merges, make_merger_aggregator(direction), group)
|
|
|
|
|
|
def make_merger_aggregator(axis) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
    """Produce f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] for ``axis`` ("x" or "y").

    f sorts the pairs by their ``{axis}1`` coordinate and takes the least one as the head H.
    It then walks the remaining pairs in that order:
      - a pair whose ``{axis}1`` equals the current head's ``{axis}2`` is merged into the head;
      - any other pair is kept, unmerged, in the tail.
    The result is a list whose first element is the (possibly merged) head. An empty input
    yields an empty list.

    NOTE(review): only the head absorbs neighbours. Two tail pairs that are adjacent to each
    other, but not to the head, are not merged by this call. Re-applying f presumably picks
    the same head again, so ``until`` re-application may not merge such chains either.
    Confirm this is intended.
    """
    c1_getter = make_coord_getter(f"{axis}1")
    c2_getter = make_coord_getter(f"{axis}2")
    pair_merger = make_pair_merger(axis)

    def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
        """Merge ``pair`` into the head if it starts where the head ends; otherwise add it to the tail."""
        aggr, *non_aggr = pairs_aggr
        if c2_getter(aggr) == c1_getter(pair):
            return pair_merger(aggr, pair), *non_aggr
        return aggr, pair, *non_aggr

    def merger_aggregator(pairs: Iterable[ImageMetadataPair]):
        # Requires H to be the least element by c1 in the concatenation direction,
        # since concatenation only happens in the c1 -> c2 direction.
        ordered = sorted(pairs, key=c1_getter)
        if not ordered:
            # Without this guard the head would be None and the result would be [None].
            return []
        head, tail = ordered[0], ordered[1:]
        return list(reduce(merge_on_head_and_aggregate_in_tail, tail, [head]))

    return merger_aggregator
|
|
|
|
|
|
def make_pair_merger(axis):
    """Select the two-pair merging function for ``axis``.

    "x" selects ``merge_pair_horizontally`` and "y" selects ``merge_pair_vertically``.

    Raises:
        KeyError: if ``axis`` is not "x" or "y".
    """
    pair_merger_by_axis = {
        "x": merge_pair_horizontally,
        "y": merge_pair_vertically,
    }
    return pair_merger_by_axis[axis]
|
|
|
|
|
|
def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Combine two pairs vertically into a new ``ImageMetadataPair``.

    The metadata are merged first. The merged metadata then sizes the canvas on which
    ``p2.image`` is pasted directly after ``p1.image`` along the vertical axis.
    """
    combined_metadata = merge_metadata_vertically(p1.metadata, p2.metadata)
    return ImageMetadataPair(
        concat_images_vertically(p1.image, p2.image, combined_metadata),
        combined_metadata,
    )
|
|
|
|
|
|
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Combine two pairs horizontally into a new ``ImageMetadataPair``.

    The metadata are merged first. The merged metadata then sizes the canvas on which
    ``p2.image`` is pasted directly after ``p1.image`` along the horizontal axis.
    """
    combined_metadata = merge_metadata_horizontally(p1.metadata, p2.metadata)
    return ImageMetadataPair(
        concat_images_horizontally(p1.image, p2.image, combined_metadata),
        combined_metadata,
    )
|
|
|
|
|
|
def merge_metadata_vertically(m1: dict, m2: dict):
    """Merge two metadata dicts, viewing each through a ``VerticalKeyMapper``."""
    return merge_metadata(VerticalKeyMapper(m1), VerticalKeyMapper(m2))
|
|
|
|
|
|
def merge_metadata_horizontally(m1: dict, m2: dict):
    """Merge two metadata dicts, viewing each through a ``HorizontalKeyMapper``."""
    return merge_metadata(HorizontalKeyMapper(m1), HorizontalKeyMapper(m2))
|
|
|
|
|
|
def merge_metadata(m1: dict, m2: dict):
    """Build metadata spanning both inputs along the mapped axis.

    Despite the ``dict`` annotations, the arguments must expose ``c1``, ``c2``, ``dim`` and
    ``wrapped`` attributes. The callers in this module pass key-mapper objects.

    The result is a deep copy of ``m1`` with:
      - ``dim`` set to the sum of both dims;
      - ``c1`` set to the smaller start coordinate;
      - ``c2`` set to the larger end coordinate.
    The copy's ``wrapped`` value is returned. ``m1`` itself is not modified.
    """
    merged = deepcopy(m1)
    merged.dim = m1.dim + m2.dim
    merged.c1 = min(m1.c1, m2.c1)
    merged.c2 = max(m1.c2, m2.c2)
    return merged.wrapped
|
|
|
|
|
|
def concat_images_vertically(im1: Image, im2: Image, metadata: dict):
    """Paste ``im2`` directly below ``im1`` on a canvas sized from ``metadata`` (axis 1)."""
    return concat_images(im1, im2, metadata, axis=1)
|
|
|
|
|
|
def concat_images_horizontally(im1: Image, im2: Image, metadata: dict):
    """Paste ``im2`` directly right of ``im1`` on a canvas sized from ``metadata`` (axis 0)."""
    return concat_images(im1, im2, metadata, axis=0)
|
|
|
|
|
|
def concat_images(im1: Image, im2: Image, metadata: dict, axis):
    """Paste two images one after the other on a fresh canvas.

    Args:
        im1: placed at the canvas origin; its mode is used for the canvas.
        im2: placed immediately after ``im1`` along ``axis``.
        metadata: supplies the canvas size through its ``Info.WIDTH`` and ``Info.HEIGHT`` entries.
        axis: 0 shifts ``im2`` right by ``im1``'s width; any truthy value (1) shifts it
            down by ``im1``'s height.

    Returns:
        The new canvas with both images pasted.
    """
    canvas = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))

    # im2 starts where im1 ends along the concatenation axis.
    shift = im1.size[axis]

    for image, offset in ((im1, 0), (im2, shift)):
        position = (0, offset) if axis else (offset, 0)
        canvas.paste(image, box=position)

    return canvas
|