refactoring

This commit is contained in:
Matthias Bisping 2022-04-07 21:39:01 +02:00
parent e276a5ec27
commit 51793d19e9
9 changed files with 231 additions and 209 deletions

View File

@ -7,7 +7,8 @@ from funcy import rcompose
from image_prediction.classifier.classifier import Classifier
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor
from image_prediction.utils import chunk_iterable, get_logger
from image_prediction.utils import get_logger
from image_prediction.utils.generic import chunk_iterable
# Module-level logger obtained from the project's get_logger helper.
logger = get_logger()

View File

@ -3,7 +3,7 @@ from typing import Iterable
from image_prediction.classifier.image_classifier import ImageClassifier
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
from image_prediction.utils import chunk_iterable
from image_prediction.utils.generic import chunk_iterable
class ExtractorClassifier:

View File

View File

@ -0,0 +1,53 @@
from itertools import groupby, chain
from typing import Iterable, List
from funcy import compose, second
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.stitcher.utils import make_coord_getter, make_group_merger
from image_prediction.utils.generic import until_convergence
class Stitcher:
    """Merges image-metadata pairs into larger pairs, one axis at a time.

    Pairs are grouped by their start and end coordinate on one axis and each
    group is handed to the group merger of the other axis. Both axes are
    processed repeatedly until the number of pairs stops changing.
    """

    @staticmethod
    def groupby(pairs, coord_getter):
        """Lazily yield lists of pairs that share the same ``coord_getter`` value.

        The pairs are sorted by the key first because ``itertools.groupby``
        only groups consecutive items.
        """
        ordered = sorted(pairs, key=coord_getter)
        return (list(members) for _, members in groupby(ordered, coord_getter))

    @staticmethod
    def other_axis(axis):
        """Return ``"y"`` for ``"x"`` and ``"x"`` for any other value."""
        return "y" if axis == "x" else "x"

    def merge_along_axis(self, pairs, axis):
        """Merge pairs along ``axis`` ("x" or "y") and return an iterator over the result.

        Only pairs with identical first and second coordinate on the other
        axis are handed to the group merger together.
        """
        orthogonal = self.other_axis(axis)
        c1_getter = make_coord_getter(f"{orthogonal}1")
        c2_getter = make_coord_getter(f"{orthogonal}2")
        group_merger = make_group_merger(axis)

        groups_with_same_c1 = self.groupby(pairs, c1_getter)
        # Each c1-group is split further by c2, then the nesting is flattened again.
        groups_with_same_c1_and_c2 = chain.from_iterable(
            self.groupby(group, c2_getter) for group in groups_with_same_c1
        )
        return chain.from_iterable(map(group_merger, groups_with_same_c1_and_c2))

    def merge_along_both_axes(self, pairs):
        """Run one merge pass along x, then one along y, and return the pairs as a list."""
        merged_along_x = self.merge_along_axis(pairs, "x")
        return list(self.merge_along_axis(merged_along_x, "y"))

    def stitch(self, pairs: Iterable["ImageMetadataPair"]) -> List["ImageMetadataPair"]:
        """Repeat the two-axis merge pass until it converges and return the resulting pairs."""
        # Convergence is detected by comparing lengths, so a one-shot iterable
        # has to be materialised before the first pass.
        return until_convergence(self.merge_along_both_axes, list(pairs))

View File

@ -0,0 +1,142 @@
from copy import deepcopy
from functools import reduce
from typing import Iterable
from PIL import Image
from funcy import juxt, first, rest
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.utils.generic import until_convergence
from test.utils.stitching import HorizontalKeyMapper, VerticalKeyMapper
def make_getter(key):
    """Build a function that looks up ``key`` in the ``metadata`` mapping of a pair."""
    return lambda pair: pair.metadata[key]
def make_coord_getter(c):
    """Return a getter for the coordinate entry named ``c``.

    ``c`` must be one of "x1", "x2", "y1" or "y2"; any other value raises
    ``KeyError``.
    """
    info_keys = {"x1": Info.X1, "x2": Info.X2, "y1": Info.Y1, "y2": Info.Y2}
    return make_getter(info_keys[c])
def make_pair_merger(axis):
    """Pick the function that merges two pairs along ``axis`` ("x" or "y").

    Any other axis raises ``KeyError``.
    """
    pair_mergers = {
        "x": merge_pair_horizontally,
        "y": merge_pair_vertically,
    }
    return pair_mergers[axis]
def make_group_merger(axis):
    """Pick the function that merges a group of pairs along ``axis`` ("x" or "y").

    Any other axis raises ``KeyError``.
    """
    group_mergers = {
        "x": merge_group_horizontally,
        "y": merge_group_vertically,
    }
    return group_mergers[axis]
def make_length_getter(dim):
    """Return a getter for the length entry named ``dim``.

    ``dim`` must be "width" or "height"; any other value raises ``KeyError``.
    """
    info_keys = {"width": Info.WIDTH, "height": Info.HEIGHT}
    return make_getter(info_keys[dim])
def merge_metadata_horizontally(m1, m2):
    """Merge two metadata objects after wrapping each in a ``HorizontalKeyMapper``."""
    mapped_1 = HorizontalKeyMapper(m1)
    mapped_2 = HorizontalKeyMapper(m2)
    return merge_metadata(mapped_1, mapped_2)
def merge_metadata_vertically(m1, m2):
    """Merge two metadata objects after wrapping each in a ``VerticalKeyMapper``."""
    mapped_1 = VerticalKeyMapper(m1)
    mapped_2 = VerticalKeyMapper(m2)
    return merge_metadata(mapped_1, mapped_2)
def merge_metadata(m1, m2):
    """Combine two key-mapped metadata objects and return the wrapped result.

    The result is a deep copy of ``m1`` whose ``c1`` is the smaller ``c1``,
    whose ``c2`` is the larger ``c2`` and whose ``dim`` is the sum of both
    ``dim`` values. Neither argument is modified.
    """
    lower, upper, extent = min(m1.c1, m2.c1), max(m1.c2, m2.c2), m1.dim + m2.dim

    result = deepcopy(m1)
    result.dim, result.c1, result.c2 = extent, lower, upper
    return result.wrapped
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge two image-metadata pairs horizontally into a single pair.

    The metadata is merged first because the concatenated image is created
    from the merged metadata.
    """
    metadata = merge_metadata_horizontally(p1.metadata, p2.metadata)
    return ImageMetadataPair(concat_images_horizontally(p1.image, p2.image, metadata), metadata)
def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge two image-metadata pairs vertically into a single pair.

    The metadata is merged first because the concatenated image is created
    from the merged metadata.
    """
    metadata = merge_metadata_vertically(p1.metadata, p2.metadata)
    return ImageMetadataPair(concat_images_vertically(p1.image, p2.image, metadata), metadata)
def concat_images_horizontally(im1: Image, im2: Image, metadata: dict):
    """Concatenate ``im1`` and ``im2`` along axis 0 via ``concat_images``."""
    return concat_images(im1, im2, metadata, axis=0)
def concat_images_vertically(im1: Image, im2: Image, metadata: dict):
    """Concatenate ``im1`` and ``im2`` along axis 1 via ``concat_images``."""
    return concat_images(im1, im2, metadata, axis=1)
def concat_images(im1: Image, im2: Image, metadata: dict, axis):
    """Paste ``im1`` and ``im2`` one after the other onto a new image.

    The new image takes its mode from ``im1`` and its size from the width and
    height entries of ``metadata``. ``im1`` is pasted at the origin and ``im2``
    is shifted by the size of ``im1`` along ``axis`` (0 shifts the first box
    coordinate, 1 the second).
    """
    canvas = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))
    extents = [im.size[axis] for im in (im1, im2)]

    canvas.paste(im1, box=(0, 0))
    shift = extents[0]
    canvas.paste(im2, box=(0, shift) if axis else (shift, 0))
    return canvas
def make_merger_aggregator(direction):
    """Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
    head H and aggregates non-adjacent in the tail T.

    ``direction`` is "x" or "y" and selects the coordinate getters and the pair
    merger. The produced function returns a list; an empty input yields an
    empty list.
    """
    c1_getter = make_coord_getter(f"{direction}1")
    c2_getter = make_coord_getter(f"{direction}2")
    pair_merger = make_pair_merger(direction)

    def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
        """Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
        aggr, *non_aggr = pairs_aggr
        # A pair is adjacent when it starts exactly where the head ends.
        if c2_getter(aggr) == c1_getter(pair):
            return (pair_merger(aggr, pair), *non_aggr)
        return (aggr, pair, *non_aggr)

    def merger_aggregator(pairs):
        # Requires H to be the least element in image-concatenation direction by c1, since the concatenation happens
        # only in c1 -> c2 direction.
        pairs = sorted(pairs, key=c1_getter)
        # Without this guard an empty group would produce [None] and fail on the next pass.
        if not pairs:
            return []
        aggregation_pair, *remaining = pairs
        return list(reduce(merge_on_head_and_aggregate_in_tail, remaining, [aggregation_pair]))

    return merger_aggregator
def merge_group(group, direction):
    """Apply the merger-aggregator for ``direction`` to ``group`` until ``until_convergence`` returns."""
    return until_convergence(make_merger_aggregator(direction), group)
def merge_group_horizontally(group):
    """Merge ``group`` along the "x" direction."""
    return merge_group(group, direction="x")
def merge_group_vertically(group):
    """Merge ``group`` along the "y" direction."""
    return merge_group(group, direction="y")

View File

@ -1,8 +1,3 @@
from itertools import takewhile, starmap, islice, repeat
from operator import truth
from .logger import get_logger
def chunk_iterable(iterable, chunk_size):
    """Lazily split ``iterable`` into consecutive tuples of at most ``chunk_size`` items.

    The last tuple may be shorter. Iteration ends with the first empty chunk.
    """
    source = iter(iterable)

    def next_chunk():
        return tuple(islice(source, chunk_size))

    # The two-argument form of iter() calls next_chunk until it returns ().
    return iter(next_chunk, ())

View File

@ -0,0 +1,14 @@
from itertools import takewhile, starmap, islice, repeat
from operator import truth
from funcy import iterate
def chunk_iterable(iterable, chunk_size):
    """Lazily split ``iterable`` into consecutive tuples of at most ``chunk_size`` items.

    The last tuple may be shorter. Iteration ends with the first empty chunk.
    """
    source = iter(iterable)

    def next_chunk():
        return tuple(islice(source, chunk_size))

    # The two-argument form of iter() calls next_chunk until it returns ().
    return iter(next_chunk, ())
def until_convergence(func, *args, **kwargs):
    """Apply ``func`` repeatedly to a value until the length of the value stops changing.

    Starting from the initial value ``x``, the sequence
    ``x, func(x), func(func(x)), ...`` is walked and every value is compared
    with its direct successor.

    Args:
        func: Single-argument function that maps a sized collection to a sized
            collection.
        *args, **kwargs: Must bind exactly one initial value, given
            positionally or as keyword ``x``. An initial value without a
            length (e.g. a generator) is turned into a list first.

    Returns:
        The first value in the sequence whose successor has the same length.

    Raises:
        TypeError: If no initial value or more than one is supplied.
    """

    def bind_initial_value(x):
        return x

    # Binding through a helper keeps the call shape of the former
    # ``iterate(func, *args, **kwargs)`` call: one value, positional or as ``x``.
    current = bind_initial_value(*args, **kwargs)

    # A one-shot iterable has no length and would be exhausted by the first call of func.
    if not hasattr(current, "__len__"):
        current = list(current)

    while True:
        successor = func(current)
        # Comparing each value with its direct successor means func is not
        # called again once two consecutive results have the same length.
        if len(current) == len(successor):
            return current
        current = successor

View File

@ -1,6 +1,6 @@
import pytest
from image_prediction.utils import chunk_iterable
from image_prediction.utils.generic import chunk_iterable
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])

View File

@ -1,229 +1,46 @@
from copy import deepcopy
from functools import partial, reduce
from itertools import groupby
from itertools import starmap, chain, repeat
from typing import Iterable, List
from functools import partial
from itertools import starmap, repeat
from typing import List
import fpdf
import numpy as np
import pdf2image
import pytest
from PIL import Image
from funcy import merge, second, compose, rpartial, juxt, rest, first, one, iterate
from funcy import merge, juxt, one
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
from image_prediction.image_extractor.extractor import ImageMetadataPair
from image_prediction.info import Info
from image_prediction.utils import chunk_iterable
from image_prediction.stitcher.stitcher import Stitcher
from image_prediction.stitcher.utils import (
make_coord_getter,
make_length_getter,
merge_metadata_horizontally,
merge_metadata_vertically,
merge_pair_horizontally,
merge_pair_vertically,
concat_images_horizontally,
concat_images_vertically,
merge_group_horizontally,
merge_group_vertically,
)
from test.conftest import (
get_base_position_metadata,
add_image,
random_single_color_image_from_metadata,
random_size_gray_image_from_metadata,
)
from test.utils.stitching import BoxSplitter, VerticalKeyMapper, HorizontalKeyMapper
def make_getter(key):
    """Build a function that looks up ``key`` in the ``metadata`` mapping of a pair."""
    return lambda pair: pair.metadata[key]
def make_coord_getter(c):
    """Return a getter for the coordinate entry named ``c``.

    ``c`` must be one of "x1", "x2", "y1" or "y2"; any other value raises
    ``KeyError``.
    """
    info_keys = {"x1": Info.X1, "x2": Info.X2, "y1": Info.Y1, "y2": Info.Y2}
    return make_getter(info_keys[c])
def make_pair_merger(axis):
    """Pick the function that merges two pairs along ``axis`` ("x" or "y").

    Any other axis raises ``KeyError``.
    """
    pair_mergers = {
        "x": merge_pair_horizontally,
        "y": merge_pair_vertically,
    }
    return pair_mergers[axis]
def make_group_merger(axis):
    """Pick the function that merges a group of pairs along ``axis`` ("x" or "y").

    Any other axis raises ``KeyError``.
    """
    group_mergers = {
        "x": merge_group_horizontally,
        "y": merge_group_vertically,
    }
    return group_mergers[axis]
def make_length_getter(dim):
    """Return a getter for the length entry named ``dim``.

    ``dim`` must be "width" or "height"; any other value raises ``KeyError``.
    """
    info_keys = {"width": Info.WIDTH, "height": Info.HEIGHT}
    return make_getter(info_keys[dim])
from test.utils.stitching import BoxSplitter
# Module-level coordinate getters, unpacked in the order x1, y1, x2, y2.
x1_getter, y1_getter, x2_getter, y2_getter = (make_coord_getter(name) for name in ("x1", "y1", "x2", "y2"))
# Module-level getters for the width and height metadata entries.
width_getter, height_getter = (make_length_getter(name) for name in ("width", "height"))
def merge_metadata_horizontally(m1, m2):
    """Merge two metadata objects after wrapping each in a ``HorizontalKeyMapper``."""
    mapped_1 = HorizontalKeyMapper(m1)
    mapped_2 = HorizontalKeyMapper(m2)
    return merge_metadata(mapped_1, mapped_2)
def merge_metadata_vertically(m1, m2):
    """Merge two metadata objects after wrapping each in a ``VerticalKeyMapper``."""
    mapped_1 = VerticalKeyMapper(m1)
    mapped_2 = VerticalKeyMapper(m2)
    return merge_metadata(mapped_1, mapped_2)
def merge_metadata(m1, m2):
    """Combine two key-mapped metadata objects and return the wrapped result.

    The result is a deep copy of ``m1`` whose ``c1`` is the smaller ``c1``,
    whose ``c2`` is the larger ``c2`` and whose ``dim`` is the sum of both
    ``dim`` values. Neither argument is modified.
    """
    lower, upper, extent = min(m1.c1, m2.c1), max(m1.c2, m2.c2), m1.dim + m2.dim

    result = deepcopy(m1)
    result.dim, result.c1, result.c2 = extent, lower, upper
    return result.wrapped
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge two image-metadata pairs horizontally into a single pair.

    The metadata is merged first because the concatenated image is created
    from the merged metadata.
    """
    metadata = merge_metadata_horizontally(p1.metadata, p2.metadata)
    return ImageMetadataPair(concat_images_horizontally(p1.image, p2.image, metadata), metadata)
def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
    """Merge two image-metadata pairs vertically into a single pair.

    The metadata is merged first because the concatenated image is created
    from the merged metadata.
    """
    metadata = merge_metadata_vertically(p1.metadata, p2.metadata)
    return ImageMetadataPair(concat_images_vertically(p1.image, p2.image, metadata), metadata)
# def merge_pair(p1, p2):
# assert p1.metadata[Info.PAGE_IDX] == p2.metadata[Info.PAGE_IDX]
def concat_images_horizontally(im1: Image, im2: Image, metadata: dict):
    """Concatenate ``im1`` and ``im2`` along axis 0 via ``concat_images``."""
    return concat_images(im1, im2, metadata, axis=0)
def concat_images_vertically(im1: Image, im2: Image, metadata: dict):
    """Concatenate ``im1`` and ``im2`` along axis 1 via ``concat_images``."""
    return concat_images(im1, im2, metadata, axis=1)
def concat_images(im1: Image, im2: Image, metadata: dict, axis):
    """Paste ``im1`` and ``im2`` one after the other onto a new image.

    The new image takes its mode from ``im1`` and its size from the width and
    height entries of ``metadata``. ``im1`` is pasted at the origin and ``im2``
    is shifted by the size of ``im1`` along ``axis`` (0 shifts the first box
    coordinate, 1 the second).
    """
    canvas = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))
    extents = [im.size[axis] for im in (im1, im2)]

    canvas.paste(im1, box=(0, 0))
    shift = extents[0]
    canvas.paste(im2, box=(0, shift) if axis else (shift, 0))
    return canvas
def make_merger_aggregator(direction):
    """Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
    head H and aggregates non-adjacent in the tail T.

    ``direction`` is "x" or "y" and selects the coordinate getters and the pair
    merger. The produced function returns a list; an empty input yields an
    empty list.
    """
    c1_getter = make_coord_getter(f"{direction}1")
    c2_getter = make_coord_getter(f"{direction}2")
    pair_merger = make_pair_merger(direction)

    def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
        """Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
        aggr, *non_aggr = pairs_aggr
        # A pair is adjacent when it starts exactly where the head ends.
        if c2_getter(aggr) == c1_getter(pair):
            return (pair_merger(aggr, pair), *non_aggr)
        return (aggr, pair, *non_aggr)

    def merger_aggregator(pairs):
        # Requires H to be the least element in image-concatenation direction by c1, since the concatenation happens
        # only in c1 -> c2 direction.
        pairs = sorted(pairs, key=c1_getter)
        # Without this guard an empty group would produce [None] and fail on the next pass.
        if not pairs:
            return []
        aggregation_pair, *remaining = pairs
        return list(reduce(merge_on_head_and_aggregate_in_tail, remaining, [aggregation_pair]))

    return merger_aggregator
def merge_group(group, direction):
    """Apply the merger-aggregator for ``direction`` to ``group`` until ``until_convergence`` returns."""
    return until_convergence(make_merger_aggregator(direction), group)
def merge_group_horizontally(group):
    """Merge ``group`` along the "x" direction."""
    return merge_group(group, direction="x")
def merge_group_vertically(group):
    """Merge ``group`` along the "y" direction."""
    return merge_group(group, direction="y")
class Stitcher:
    """Merges image-metadata pairs into larger pairs, one axis at a time.

    Pairs are grouped by their start and end coordinate on one axis and each
    group is handed to the group merger of the other axis. Both axes are
    processed repeatedly until the number of pairs stops changing.
    """

    @staticmethod
    def groupby(pairs, coord_getter):
        """Lazily yield lists of pairs that share the same ``coord_getter`` value.

        The pairs are sorted by the key first because ``itertools.groupby``
        only groups consecutive items.
        """
        ordered = sorted(pairs, key=coord_getter)
        return (list(members) for _, members in groupby(ordered, coord_getter))

    @staticmethod
    def other_axis(axis):
        """Return ``"y"`` for ``"x"`` and ``"x"`` for any other value."""
        return "y" if axis == "x" else "x"

    def merge_along_axis(self, pairs, axis):
        """Merge pairs along ``axis`` ("x" or "y") and return an iterator over the result.

        Only pairs with identical first and second coordinate on the other
        axis are handed to the group merger together.
        """
        orthogonal = self.other_axis(axis)
        c1_getter = make_coord_getter(f"{orthogonal}1")
        c2_getter = make_coord_getter(f"{orthogonal}2")
        group_merger = make_group_merger(axis)

        groups_with_same_c1 = self.groupby(pairs, c1_getter)
        # Each c1-group is split further by c2, then the nesting is flattened again.
        groups_with_same_c1_and_c2 = chain.from_iterable(
            self.groupby(group, c2_getter) for group in groups_with_same_c1
        )
        return chain.from_iterable(map(group_merger, groups_with_same_c1_and_c2))

    def merge_along_both_axes(self, pairs):
        """Run one merge pass along x, then one along y, and return the pairs as a list."""
        merged_along_x = self.merge_along_axis(pairs, "x")
        return list(self.merge_along_axis(merged_along_x, "y"))

    def stitch(self, pairs: Iterable["ImageMetadataPair"]) -> List["ImageMetadataPair"]:
        """Repeat the two-axis merge pass until it converges and return the resulting pairs."""
        # Convergence is detected by comparing lengths, so a one-shot iterable
        # has to be materialised before the first pass.
        return until_convergence(self.merge_along_both_axes, list(pairs))
def until_convergence(func, *args, **kwargs):
    """Apply ``func`` repeatedly to a value until the length of the value stops changing.

    Starting from the initial value ``x``, the sequence
    ``x, func(x), func(func(x)), ...`` is walked and every value is compared
    with its direct successor.

    Args:
        func: Single-argument function that maps a sized collection to a sized
            collection.
        *args, **kwargs: Must bind exactly one initial value, given
            positionally or as keyword ``x``. An initial value without a
            length (e.g. a generator) is turned into a list first.

    Returns:
        The first value in the sequence whose successor has the same length.

    Raises:
        TypeError: If no initial value or more than one is supplied.
    """

    def bind_initial_value(x):
        return x

    # Binding through a helper keeps the call shape of the former
    # ``iterate(func, *args, **kwargs)`` call: one value, positional or as ``x``.
    current = bind_initial_value(*args, **kwargs)

    # A one-shot iterable has no length and would be exhausted by the first call of func.
    if not hasattr(current, "__len__"):
        current = list(current)

    while True:
        successor = func(current)
        # Comparing each value with its direct successor means func is not
        # called again once two consecutive results have the same length.
        if len(current) == len(successor):
            return current
        current = successor
#####################################
# @pytest.mark.parametrize("width", [160])
# @pytest.mark.parametrize("height", [90])
# @pytest.mark.parametrize("page_width", [int(160 * 1.1)])
# @pytest.mark.parametrize("page_height", [int(90 * 1.1)])
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
    """The first stitched pair must carry the same metadata as the base patch."""
    stitched_pairs = Stitcher().stitch(patch_image_metadata_pairs)
    first_stitched = stitched_pairs[0]
    assert first_stitched.metadata == base_patch_metadata