Matthias Bisping 5967149c49 refactoring
2022-04-07 21:49:55 +02:00

146 lines
4.3 KiB
Python

from copy import deepcopy
from functools import reduce
from typing import Iterable
from PIL import Image
from funcy import juxt, first, rest
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.utils.generic import until
from test.utils.stitching import HorizontalKeyMapper, VerticalKeyMapper
def make_getter(key):
    """Return a function that reads ``key`` from a pair's ``metadata`` mapping.

    The returned callable raises ``KeyError`` if ``key`` is missing.
    """
    def read_metadata_value(pair):
        metadata = pair.metadata
        return metadata[key]

    return read_metadata_value
def make_coord_getter(c):
    """Return a metadata getter for the coordinate named ``c``.

    ``c`` must be one of ``"x1"``, ``"x2"``, ``"y1"``, ``"y2"``; any other value
    raises ``KeyError``.
    """
    coordinate_keys = {
        "x1": Info.X1,
        "x2": Info.X2,
        "y1": Info.Y1,
        "y2": Info.Y2,
    }
    return make_getter(coordinate_keys[c])
def make_pair_merger(axis):
    """Select the two-pair merge function for ``axis`` (``"x"`` or ``"y"``).

    Any other value raises ``KeyError``.
    """
    mergers_by_axis = {
        "x": merge_pair_horizontally,
        "y": merge_pair_vertically,
    }
    return mergers_by_axis[axis]
def make_group_merger(axis):
    """Select the group merge function for ``axis`` (``"x"`` or ``"y"``).

    Any other value raises ``KeyError``.
    """
    mergers_by_axis = {
        "x": merge_group_horizontally,
        "y": merge_group_vertically,
    }
    return mergers_by_axis[axis]
def make_length_getter(dim):
    """Return a metadata getter for the extent named ``dim``.

    ``dim`` must be ``"width"`` or ``"height"``; any other value raises ``KeyError``.
    """
    length_keys = {"width": Info.WIDTH, "height": Info.HEIGHT}
    return make_getter(length_keys[dim])
def merge_metadata_horizontally(m1, m2):
    """Merge two metadata records viewed through ``HorizontalKeyMapper``.

    NOTE(review): the mapper presumably exposes the x-axis keys as ``c1``/``c2``/``dim``
    for ``merge_metadata``; confirm in ``test.utils.stitching``.
    """
    left = HorizontalKeyMapper(m1)
    right = HorizontalKeyMapper(m2)
    return merge_metadata(left, right)
def merge_metadata_vertically(m1, m2):
    """Merge two metadata records viewed through ``VerticalKeyMapper``.

    NOTE(review): the mapper presumably exposes the y-axis keys as ``c1``/``c2``/``dim``
    for ``merge_metadata``; confirm in ``test.utils.stitching``.
    """
    upper = VerticalKeyMapper(m1)
    lower = VerticalKeyMapper(m2)
    return merge_metadata(upper, lower)
def merge_metadata(m1, m2):
    """Combine two key-mapped metadata records into one spanning both.

    The merged record starts at the smaller ``c1``, ends at the larger ``c2``, and
    its ``dim`` is the sum of both ``dim`` values. All other state comes from a deep
    copy of ``m1``, so neither argument is mutated.

    Returns:
        The ``wrapped`` attribute of the merged copy.
    """
    start = min(m1.c1, m2.c1)
    end = max(m1.c2, m2.c2)
    extent = m1.dim + m2.dim

    merged = deepcopy(m1)
    # Tuple assignment stores left to right: dim, then c1, then c2.
    merged.dim, merged.c1, merged.c2 = extent, start, end
    return merged.wrapped
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Join two pairs along the x axis: merge the metadata, then concatenate the images."""
    merged_metadata = merge_metadata_horizontally(p1.metadata, p2.metadata)
    return ImageMetadataPair(
        concat_images_horizontally(p1.image, p2.image, merged_metadata),
        merged_metadata,
    )
def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Join two pairs along the y axis: merge the metadata, then concatenate the images."""
    merged_metadata = merge_metadata_vertically(p1.metadata, p2.metadata)
    return ImageMetadataPair(
        concat_images_vertically(p1.image, p2.image, merged_metadata),
        merged_metadata,
    )
def concat_images_horizontally(im1: Image, im2: Image, metadata: dict):
    """Place ``im2`` to the right of ``im1`` on a canvas sized by ``metadata``."""
    horizontal_axis = 0
    return concat_images(im1, im2, metadata, horizontal_axis)
def concat_images_vertically(im1: Image, im2: Image, metadata: dict):
    """Place ``im2`` below ``im1`` on a canvas sized by ``metadata``."""
    vertical_axis = 1
    return concat_images(im1, im2, metadata, vertical_axis)
def concat_images(im1: Image.Image, im2: Image.Image, metadata: dict, axis):
    """Paste ``im1`` and ``im2`` one after the other on a fresh canvas.

    The canvas uses ``im1``'s mode. Its size is ``(metadata[Info.WIDTH],
    metadata[Info.HEIGHT])``, i.e. it comes from the merged metadata, not from the
    pixel sizes of the inputs.

    Args:
        im1: Image pasted at the origin.
        im2: Image pasted directly after ``im1`` along ``axis``.
        metadata: Merged metadata providing the canvas width and height.
        axis: Falsy (0) to offset ``im2`` along x (horizontal concatenation),
            otherwise (1) to offset it along y (vertical concatenation).

    Returns:
        The newly created concatenated image.
    """
    im_aggr = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))
    # Start position of each image along ``axis``: ``im2`` begins where ``im1`` ends.
    placements = [(im1, 0), (im2, im1.size[axis])]
    for im, offset in placements:
        box = (offset, 0) if not axis else (0, offset)
        im_aggr.paste(im, box=box)
    return im_aggr
def make_merger_aggregator(direction):
    """Build a single merge pass over image-metadata pairs along ``direction``.

    The returned function f : [P1, ..., Pn] -> [H, R1, ..., Rk] first sorts the pairs
    by their start coordinate (``{direction}1``). It then folds them into runs:
    - A pair whose start coordinate equals the end coordinate (``{direction}2``) of
      an existing run is concatenated onto that run. The head run is checked first.
    - Any other pair starts a new run, inserted directly after the head.

    Args:
        direction: ``"x"`` or ``"y"``; selects the coordinate getters and the pair
            merger. Other values raise ``KeyError``.

    Returns:
        A function mapping an iterable of pairs to a list of (possibly merged) pairs.
        An empty input yields an empty list.
    """
    c1_getter = make_coord_getter(f"{direction}1")
    c2_getter = make_coord_getter(f"{direction}2")
    pair_merger = make_pair_merger(direction)

    def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
        """Concatenate ``pair`` onto the first run it continues, else insert it after the head."""
        runs = list(pairs_aggr)
        # Check the head first so it keeps priority. Also check the tail runs, so that
        # separate chains within one group (e.g. two images with a gap between them)
        # get merged too. Otherwise only the chain starting at the head could ever merge.
        for index, run in enumerate(runs):
            if c2_getter(run) == c1_getter(pair):
                runs[index] = pair_merger(run, pair)
                return runs
        head, *tail = runs
        return [head, pair, *tail]

    def merger_aggregator(pairs):
        # Concatenation only happens in c1 -> c2 direction, so pairs are processed in
        # ascending c1 order: a run is always collected before the pair continuing it.
        pairs = sorted(pairs, key=c1_getter)
        if not pairs:
            # Without this guard the fold is seeded with first([]) (None) and the
            # result would be [None] instead of an empty list.
            return []
        aggregation_pair, remaining = juxt(first, rest)(pairs)
        return list(reduce(merge_on_head_and_aggregate_in_tail, remaining, [aggregation_pair]))

    return merger_aggregator
def merge_group(group, direction):
    """Repeatedly apply a merge pass along ``direction`` to ``group``.

    The pass is built by ``make_merger_aggregator``. The stop condition passed to
    ``until`` holds when a pass leaves the number of pairs unchanged (nothing was merged).
    NOTE(review): assumes ``until(f, cond, x)`` iterates ``f`` until
    ``cond(previous, current)`` is true; confirm in ``image_prediction.utils.generic``.
    """
    def nothing_merged(before, after):
        return len(before) == len(after)

    merge_pass = make_merger_aggregator(direction)
    return until(merge_pass, nothing_merged, group)
def merge_group_horizontally(group):
    """Merge the x-adjacent pairs of ``group``."""
    return merge_group(group, direction="x")
def merge_group_vertically(group):
    """Merge the y-adjacent pairs of ``group``."""
    return merge_group(group, direction="y")