173 lines
6.2 KiB
Python
173 lines
6.2 KiB
Python
from copy import deepcopy
|
|
from functools import reduce
|
|
from typing import Iterable, Callable, List
|
|
|
|
from PIL import Image
|
|
from funcy import juxt, first, rest, rcompose, rpartial
|
|
|
|
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
|
from image_prediction.info import Info
|
|
from image_prediction.stitching.grouping import CoordGrouper
|
|
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once
|
|
from image_prediction.utils.generic import until
|
|
from image_prediction.stitching.split_mapper import HorizontalSplitMapper, VerticalSplitMapper
|
|
|
|
|
|
def no_new_merges(pairs1, pairs2):
    """Return True when two successive merge results hold the same number of pairs.

    ``merge_group`` passes this to ``until`` as its stop condition: a merge pass never adds pairs,
    so an unchanged count means the last pass did not merge anything.
    """
    count_before, count_after = len(pairs1), len(pairs2)
    return count_before == count_after
|
|
|
|
|
|
def merge_along_both_axes(pairs: Iterable[ImageMetadataPair], tolerance=0) -> List[ImageMetadataPair]:
    """Run one merge pass along the x axis, then one along the y axis.

    Args:
        pairs: Image-metadata pairs to merge.
        tolerance: Adjacency tolerance, forwarded unchanged to both ``merge_along_axis`` calls.

    Returns:
        The resulting pairs as a list. A single call is only a partial merge; see ``merge_along_axis``.
    """
    merged_along_x = merge_along_axis(pairs, "x", tolerance=tolerance)
    merged_along_y = merge_along_axis(merged_along_x, "y", tolerance=tolerance)
    return list(merged_along_y)
|
|
|
|
|
|
def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis, tolerance=0) -> Iterable[ImageMetadataPair]:
    """Partially merge image-metadata pairs of adjacent images along ``axis``.

    One call does not merge every adjacent image. To merge them all, the call has to be repeated
    with alternating axes until no more merges happen.

    Explanation:

        The merging algorithm works as follows.
        A dot represents a pair, a bracket a group and a colon a merged pair.

        1) Start with pairs:      (........)
        2) Align on lesser:       ([....] [....])
        3) Align on greater:      ([[..] [..]] [[....]])
        4) Flatten once:          ([..] [..] [....])
        5) Merge orthogonally:    ([:] [..] [:..])
        6) Flatten once:          (:..:..)

    Args:
        pairs: Image-metadata pairs to merge.
        axis: ``"x"`` or ``"y"``; selects the grouping coordinates and the group merger.
        tolerance: Forwarded to ``CoordGrouper`` and to the group merger.

    Returns:
        The flattened result of ``flatten_groups_once`` (may be lazy).
    """
    # Step 2: group the pairs by their lesser coordinate.
    lesser_groups = CoordGrouper(axis, tolerance=tolerance).group_pairs_by_lesser_coordinate(pairs)

    # Step 3: inside each of those groups, sub-group by the greater coordinate.
    nested_groups = map(CoordGrouper(axis, tolerance=tolerance).group_pairs_by_greater_coordinate, lesser_groups)

    # Step 4: drop one level of nesting so every sub-group becomes a group of its own.
    aligned_groups = flatten_groups_once(nested_groups)

    # Step 5: merge the adjacent pairs inside every group.
    merged_groups = map(rpartial(make_group_merger(axis), tolerance), aligned_groups)

    # Step 6: flatten once more to get a plain sequence of pairs.
    return flatten_groups_once(merged_groups)
|
|
|
|
|
|
def make_group_merger(axis):
    """Return the group-merging function for ``axis``.

    ``"x"`` maps to ``merge_group_horizontally`` and ``"y"`` to ``merge_group_vertically``.

    Raises:
        KeyError: if ``axis`` is neither ``"x"`` nor ``"y"``.
    """
    mergers_by_axis = {"x": merge_group_horizontally, "y": merge_group_vertically}
    return mergers_by_axis[axis]
|
|
|
|
|
|
def merge_group_vertically(group: Iterable[ImageMetadataPair], tolerance=0):
    """Merge the pairs of ``group`` that are adjacent along the y axis."""
    vertical = "y"
    return merge_group(group, vertical, tolerance=tolerance)
|
|
|
|
|
|
def merge_group_horizontally(group: Iterable[ImageMetadataPair], tolerance=0):
    """Merge the pairs of ``group`` that are adjacent along the x axis."""
    horizontal = "x"
    return merge_group(group, horizontal, tolerance=tolerance)
|
|
|
|
|
|
def merge_group(group: Iterable[ImageMetadataPair], direction, tolerance=0):
    """Merge adjacent pairs of ``group`` along ``direction`` (``"x"`` or ``"y"``).

    The aggregator from ``make_merger_aggregator`` is handed to ``until`` together with
    ``no_new_merges``. Presumably ``until`` re-applies it until a pass leaves the pair count
    unchanged; the exact semantics of ``until`` are defined in ``image_prediction.utils.generic``.
    """
    return until(no_new_merges, make_merger_aggregator(direction, tolerance=tolerance), group)
|
|
|
|
|
|
def make_merger_aggregator(axis, tolerance=0) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
    """Produce a function that merges adjacent image-metadata pairs along ``axis``.

    The returned function sorts its input by the lesser coordinate (``{axis}1``) and walks it once:

        f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl]

    The least pair H is the head. Each following pair is merged onto the first already-aggregated
    pair whose greater coordinate (``{axis}2``) lies within ``tolerance`` of the pair's lesser
    coordinate. The head is checked first, then the tail. A pair that touches nothing aggregated so
    far is kept in the tail, directly after the head.

    Because tail entries are merge targets too, pairs that abut each other but not the head are
    merged as well. For example, extents [0, 10], [20, 30], [30, 40] end up as [0, 10] and [20, 40].
    An empty input yields an empty list.

    Args:
        axis: ``"x"`` or ``"y"``.
        tolerance: Maximum allowed ``|c2 - c1|`` between touching edges. Must be >= 0.

    Note:
        When tolerance > 0, the bounding box of the merged image no longer matches the bounding box of the merged
        metadata. This is intended behaviour, but might not be expected by the caller.
    """

    def is_adjacent(left, right) -> bool:
        # ``left`` ends where ``right`` starts along ``axis``, up to ``tolerance``.
        return abs(c2_getter(left) - c1_getter(right)) <= tolerance

    def merger_aggregator(pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
        # Requires H to be the least element in image-concatenation direction by c1, since the concatenation
        # happens only in c1 -> c2 direction. Every merge target therefore has c1 <= the incoming pair's c1.
        pairs = sorted(pairs, key=c1_getter)
        if not pairs:
            # Without this guard the head would be ``None`` and the next pass would crash on it.
            return []

        aggregated = [pairs[0]]
        for pair in pairs[1:]:
            for index, candidate in enumerate(aggregated):
                if is_adjacent(candidate, pair):
                    aggregated[index] = pair_merger(candidate, pair)
                    break
            else:
                # Adjacent to nothing yet: keep it in the tail right after the head,
                # giving the layout [head, newest tail pair, ..., oldest tail pair].
                aggregated.insert(1, pair)
        return aggregated

    assert tolerance >= 0

    c1_getter = make_coord_getter(f"{axis}1")
    c2_getter = make_coord_getter(f"{axis}2")
    pair_merger = make_pair_merger(axis)

    return merger_aggregator
|
|
|
|
|
|
def make_pair_merger(axis):
    """Return the function that merges two pairs along ``axis``.

    ``"x"`` maps to ``merge_pair_horizontally`` and ``"y"`` to ``merge_pair_vertically``.

    Raises:
        KeyError: if ``axis`` is neither ``"x"`` nor ``"y"``.
    """
    pair_mergers = {"x": merge_pair_horizontally, "y": merge_pair_vertically}
    return pair_mergers[axis]
|
|
|
|
|
|
def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Stack ``p2`` below ``p1`` and return the combined image with its merged metadata."""
    combined_metadata = merge_metadata_vertically(p1.metadata, p2.metadata)
    return ImageMetadataPair(
        concat_images_vertically(p1.image, p2.image, combined_metadata),
        combined_metadata,
    )
|
|
|
|
|
|
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Place ``p2`` to the right of ``p1`` and return the combined image with its merged metadata."""
    combined_metadata = merge_metadata_horizontally(p1.metadata, p2.metadata)
    return ImageMetadataPair(
        concat_images_horizontally(p1.image, p2.image, combined_metadata),
        combined_metadata,
    )
|
|
|
|
|
|
def merge_metadata_vertically(m1: dict, m2: dict):
    """Merge the metadata of two vertically adjacent images, viewed through ``VerticalSplitMapper``."""
    return merge_metadata(VerticalSplitMapper(m1), VerticalSplitMapper(m2))
|
|
|
|
|
|
def merge_metadata_horizontally(m1: dict, m2: dict):
    """Merge the metadata of two horizontally adjacent images, viewed through ``HorizontalSplitMapper``."""
    return merge_metadata(HorizontalSplitMapper(m1), HorizontalSplitMapper(m2))
|
|
|
|
|
|
def merge_metadata(m1: dict, m2: dict):
    """Combine two split-mapped metadata records into the metadata of their concatenation.

    The result starts as a deep copy of ``m1``:

    - ``dim`` becomes the sum of both dims.
    - ``c1`` / ``c2`` are widened to span both inputs.

    The copy's ``wrapped`` value is returned.

    Despite the ``dict`` annotations, both arguments must expose ``c1``, ``c2``, ``dim`` and
    ``wrapped`` attributes. In this module they are ``HorizontalSplitMapper`` /
    ``VerticalSplitMapper`` instances.
    """
    # Read everything from the inputs before touching the copy.
    span_start = min(m1.c1, m2.c1)
    span_end = max(m1.c2, m2.c2)
    total_dim = m1.dim + m2.dim

    merged = deepcopy(m1)
    merged.dim, merged.c1, merged.c2 = total_dim, span_start, span_end

    return merged.wrapped
|
|
|
|
|
|
def concat_images_vertically(im1: Image.Image, im2: Image.Image, metadata: dict) -> Image.Image:
    """Stack ``im2`` below ``im1`` on a canvas sized by ``metadata``.

    ``Image`` is the ``PIL.Image`` module, so the image type is annotated as ``Image.Image``.
    """
    return concat_images(im1, im2, metadata, 1)
|
|
|
|
|
|
def concat_images_horizontally(im1: Image.Image, im2: Image.Image, metadata: dict) -> Image.Image:
    """Place ``im2`` to the right of ``im1`` on a canvas sized by ``metadata``.

    ``Image`` is the ``PIL.Image`` module, so the image type is annotated as ``Image.Image``.
    """
    return concat_images(im1, im2, metadata, 0)
|
|
|
|
|
|
def concat_images(im1: Image.Image, im2: Image.Image, metadata: dict, axis) -> Image.Image:
    """Paste ``im1`` at the origin and ``im2`` directly after it on a new canvas.

    Args:
        im1: First image. Pasted at offset 0; its mode is used for the canvas.
        im2: Second image. Pasted right after ``im1`` along ``axis``.
        metadata: Merged metadata. ``metadata[Info.WIDTH]`` and ``metadata[Info.HEIGHT]`` give the canvas size.
        axis: 0 concatenates along the width (horizontally); any truthy value concatenates along the height
            (vertically).

    Returns:
        The newly created canvas with both images pasted onto it.
    """
    canvas = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))

    # The second image starts where the first one ends along ``axis``. Only im1's extent matters,
    # so the offsets are stated explicitly rather than as a list that merely looks cumulative.
    offsets = (0, im1.size[axis])

    for image, offset in zip((im1, im2), offsets):
        # PIL boxes are (left, upper): axis 0 shifts along x, otherwise along y.
        box = (0, offset) if axis else (offset, 0)
        canvas.paste(image, box=box)

    return canvas
|