2022-04-11 12:16:42 +02:00

167 lines
5.7 KiB
Python

from copy import deepcopy
from functools import reduce
from typing import Iterable, Callable, List
from PIL import Image
from funcy import juxt, first, rest, rcompose
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.stitching.grouping import CoordGrouper
from image_prediction.stitching.utils import make_coord_getter, flatten_groups_once
from image_prediction.utils.generic import until
from image_prediction.stitching.split_mapper import HorizontalSplitMapper, VerticalSplitMapper
def no_new_merges(pairs1, pairs2):
    """Return True when two successive merge results contain the same number of pairs.

    Used as the stopping predicate for repeated merge passes: an unchanged
    count means the last pass did not combine any pairs.
    """
    count_before, count_after = len(pairs1), len(pairs2)
    return count_before == count_after
def merge_along_both_axes(pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
    """Run one merge pass along "x", then one along "y", and return the pairs as a list."""
    for axis in ("x", "y"):
        pairs = merge_along_axis(pairs, axis)
    return list(pairs)
def merge_along_axis(pairs: Iterable[ImageMetadataPair], axis) -> Iterable[ImageMetadataPair]:
    """Partially merge image-metadata pairs of adjacent images along ``axis``.

    A single call does not necessarily merge everything. Calls must be repeated
    with alternating axes until no more merges happen, so that all adjacent
    images end up merged.

    Explanation:
        Legend: a dot is a pair, brackets are a group, a colon is a merged pair.
        1) Start with pairs:      (........)
        2) Align on lesser:       ([....] [....])
        3) Align on greater:      ([[..] [..]] [[....]])
        4) Flatten once:          ([..] [..] [....])
        5) Merge orthogonally:    ([:] [..] [:..])
        6) Flatten once:          (:..:..)
    """
    # Steps 1-2: bucket the pairs by their lesser coordinate.
    lesser_groups = CoordGrouper(axis).group_pairs_by_lesser_coordinate(pairs)
    # Step 3: sub-bucket each of those groups by the greater coordinate.
    nested_groups = map(CoordGrouper(axis).group_pairs_by_greater_coordinate, lesser_groups)
    # Step 4: drop one level of nesting.
    aligned_groups = flatten_groups_once(nested_groups)
    # Step 5: merge adjacent pairs inside each aligned group.
    merged_groups = map(make_group_merger(axis), aligned_groups)
    # Step 6: flatten back into a single sequence of pairs.
    return flatten_groups_once(merged_groups)
def make_group_merger(axis):
    """Return the group-merging function for ``axis`` ("x" or "y"); any other value raises KeyError."""
    group_merger_by_axis = {
        "x": merge_group_horizontally,
        "y": merge_group_vertically,
    }
    return group_merger_by_axis[axis]
def merge_group_vertically(group: Iterable[ImageMetadataPair]):
    """Merge the pairs of ``group`` along the "y" axis."""
    return merge_group(group, direction="y")
def merge_group_horizontally(group: Iterable[ImageMetadataPair]):
    """Merge the pairs of ``group`` along the "x" axis."""
    return merge_group(group, direction="x")
def merge_group(group: Iterable[ImageMetadataPair], direction):
    """Apply the merger aggregator for ``direction`` to ``group`` repeatedly via ``until``.

    The stopping predicate is ``no_new_merges``. Presumably the loop stops once
    a pass no longer changes the number of pairs; confirm against
    ``image_prediction.utils.generic.until``.
    """
    return until(no_new_merges, make_merger_aggregator(direction), group)
def make_merger_aggregator(axis) -> Callable[[Iterable[ImageMetadataPair]], Iterable[ImageMetadataPair]]:
    """Build a function that merges every chain of adjacent image-metadata pairs along ``axis``.

    Pairs P and Q are adjacent when P's end coordinate (``{axis}2``) equals Q's
    start coordinate (``{axis}1``). Adjacent pairs are combined with the pair
    merger for ``axis``, always in P -> Q order.

    The returned function:
    - sorts its input by start coordinate and sweeps it once;
    - appends each pair to the earliest open run whose end coordinate equals the
      pair's start coordinate, or opens a new run if there is none.

    The run that begins with the least element is always checked first. Chains
    that do not touch that run are merged as well. For example, spans 0-10,
    20-30 and 30-40 yield [0-10, 20-40].

    Returns:
        A list of merged pairs, ordered by each run's start coordinate.
        An empty input yields an empty list.
    """
    c1_getter = make_coord_getter(f"{axis}1")
    c2_getter = make_coord_getter(f"{axis}2")
    pair_merger = make_pair_merger(axis)

    def merger_aggregator(pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
        runs: List[ImageMetadataPair] = []
        # Concatenation only happens in the c1 -> c2 direction. Visiting pairs by
        # ascending c1 therefore means a pair only ever needs to be appended to the
        # end of an already-open run.
        for pair in sorted(pairs, key=c1_getter):
            start = c1_getter(pair)
            target = next((i for i, run in enumerate(runs) if c2_getter(run) == start), None)
            if target is None:
                runs.append(pair)
            else:
                runs[target] = pair_merger(runs[target], pair)
        return runs

    return merger_aggregator
def make_pair_merger(axis):
    """Return the two-pair merging function for ``axis`` ("x" or "y"); any other value raises KeyError."""
    pair_merger_by_axis = {
        "x": merge_pair_horizontally,
        "y": merge_pair_vertically,
    }
    return pair_merger_by_axis[axis]
def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Combine two pairs vertically: merge their metadata, then stack p2's image below p1's."""
    combined_metadata = merge_metadata_vertically(p1.metadata, p2.metadata)
    stacked_image = concat_images_vertically(p1.image, p2.image, combined_metadata)
    return ImageMetadataPair(stacked_image, combined_metadata)
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Combine two pairs horizontally: merge their metadata, then place p2's image right of p1's."""
    combined_metadata = merge_metadata_horizontally(p1.metadata, p2.metadata)
    side_by_side = concat_images_horizontally(p1.image, p2.image, combined_metadata)
    return ImageMetadataPair(side_by_side, combined_metadata)
def merge_metadata_vertically(m1: dict, m2: dict):
    """Merge two metadata dicts by viewing each through a ``VerticalSplitMapper``."""
    return merge_metadata(VerticalSplitMapper(m1), VerticalSplitMapper(m2))
def merge_metadata_horizontally(m1: dict, m2: dict):
    """Merge two metadata dicts by viewing each through a ``HorizontalSplitMapper``."""
    return merge_metadata(HorizontalSplitMapper(m1), HorizontalSplitMapper(m2))
def merge_metadata(m1: dict, m2: dict):
    """Combine two split-mapper views into a single metadata object and return its ``wrapped`` value.

    Although the parameters are annotated as ``dict``, callers pass mapper objects
    (see ``merge_metadata_vertically``/``merge_metadata_horizontally``). These
    expose ``c1``, ``c2``, ``dim`` and ``wrapped``.

    The result is built from a deep copy of ``m1``:
    - ``c1`` is the smaller of the two ``c1`` values;
    - ``c2`` is the larger of the two ``c2`` values;
    - ``dim`` is the sum of both ``dim`` values;
    - every other field comes from ``m1``.

    Only the copy is modified; neither input is changed.
    """
    span_start = min(m1.c1, m2.c1)
    span_end = max(m1.c2, m2.c2)
    total_dim = m1.dim + m2.dim

    combined = deepcopy(m1)
    # Assignment order (dim, c1, c2) is kept deliberately in case the mapper's setters interact.
    combined.dim = total_dim
    combined.c1 = span_start
    combined.c2 = span_end
    return combined.wrapped
def concat_images_vertically(im1: Image, im2: Image, metadata: dict):
    """Stack ``im2`` directly below ``im1`` on a canvas sized from ``metadata``."""
    return concat_images(im1, im2, metadata, axis=1)
def concat_images_horizontally(im1: Image, im2: Image, metadata: dict):
    """Place ``im2`` directly right of ``im1`` on a canvas sized from ``metadata``."""
    return concat_images(im1, im2, metadata, axis=0)
def concat_images(im1: Image, im2: Image, metadata: dict, axis):
    """Paste two images edge to edge onto a new canvas and return the canvas.

    The canvas uses ``im1``'s mode and is sized
    ``(metadata[Info.WIDTH], metadata[Info.HEIGHT])``.

    - ``im1`` is pasted at the origin.
    - ``im2`` is offset by ``im1``'s size along ``axis``:
      ``axis=0`` places it to the right (offset by width),
      ``axis=1`` places it below (offset by height).
    """
    canvas = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))
    placements = [(im1, 0), (im2, im1.size[axis])]
    for image, offset in placements:
        # A 2-tuple box is the upper-left corner of the paste.
        corner = (0, offset) if axis else (offset, 0)
        canvas.paste(image, box=corner)
    return canvas