refactoring
This commit is contained in:
parent
e276a5ec27
commit
51793d19e9
@ -7,7 +7,8 @@ from funcy import rcompose
|
||||
from image_prediction.classifier.classifier import Classifier
|
||||
from image_prediction.estimator.preprocessor.preprocessor import Preprocessor
|
||||
from image_prediction.estimator.preprocessor.preprocessors.identity import IdentityPreprocessor
|
||||
from image_prediction.utils import chunk_iterable, get_logger
|
||||
from image_prediction.utils import get_logger
|
||||
from image_prediction.utils.generic import chunk_iterable
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Iterable
|
||||
|
||||
from image_prediction.classifier.image_classifier import ImageClassifier
|
||||
from image_prediction.image_extractor.extractor import ImageExtractor, ImageMetadataPair
|
||||
from image_prediction.utils import chunk_iterable
|
||||
from image_prediction.utils.generic import chunk_iterable
|
||||
|
||||
|
||||
class ExtractorClassifier:
|
||||
|
||||
0
image_prediction/stitcher/__init__.py
Normal file
0
image_prediction/stitcher/__init__.py
Normal file
53
image_prediction/stitcher/stitcher.py
Normal file
53
image_prediction/stitcher/stitcher.py
Normal file
@ -0,0 +1,53 @@
|
||||
from itertools import groupby, chain
|
||||
from typing import Iterable, List
|
||||
|
||||
from funcy import compose, second
|
||||
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.stitcher.utils import make_coord_getter, make_group_merger
|
||||
from image_prediction.utils.generic import until_convergence
|
||||
|
||||
|
||||
class Stitcher:
|
||||
@staticmethod
|
||||
def groupby(pairs, coord_getter):
|
||||
pairs = sorted(pairs, key=coord_getter)
|
||||
return map(compose(list, second), groupby(pairs, coord_getter))
|
||||
|
||||
@staticmethod
|
||||
def other_axis(axis):
|
||||
return "y" if axis == "x" else "x"
|
||||
|
||||
def merge_along_axis(self, pairs, axis):
|
||||
def group_pairs_by_c1(pairs):
|
||||
return self.groupby(pairs, c1_getter)
|
||||
|
||||
def group_by_c2(pairs):
|
||||
return self.groupby(pairs, c2_getter)
|
||||
|
||||
def group_pairs_within_groups_by_c2(groups):
|
||||
return map(group_by_c2, groups)
|
||||
|
||||
def merge_groups_along_orthogonal_axis(groups):
|
||||
return map(group_merger, groups)
|
||||
|
||||
c1_getter = make_coord_getter(f"{self.other_axis(axis)}1")
|
||||
c2_getter = make_coord_getter(f"{self.other_axis(axis)}2")
|
||||
group_merger = make_group_merger(axis)
|
||||
|
||||
groups_of_pairs_with_same_c1 = group_pairs_by_c1(pairs)
|
||||
groups_of_groups_of_pairs_with_same_c1_and_c2 = group_pairs_within_groups_by_c2(groups_of_pairs_with_same_c1)
|
||||
groups_of_pairs_with_matching_c1_and_c2 = chain(*groups_of_groups_of_pairs_with_same_c1_and_c2)
|
||||
groups_of_merged_pairs = merge_groups_along_orthogonal_axis(groups_of_pairs_with_matching_c1_and_c2)
|
||||
pairs = chain(*groups_of_merged_pairs)
|
||||
|
||||
return pairs
|
||||
|
||||
def merge_along_both_axes(self, pairs):
|
||||
pairs = self.merge_along_axis(pairs, "x")
|
||||
pairs = list(self.merge_along_axis(pairs, "y"))
|
||||
|
||||
return pairs
|
||||
|
||||
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
|
||||
return until_convergence(self.merge_along_both_axes, pairs)
|
||||
142
image_prediction/stitcher/utils.py
Normal file
142
image_prediction/stitcher/utils.py
Normal file
@ -0,0 +1,142 @@
|
||||
from copy import deepcopy
|
||||
from functools import reduce
|
||||
from typing import Iterable
|
||||
|
||||
from PIL import Image
|
||||
from funcy import juxt, first, rest
|
||||
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.utils.generic import until_convergence
|
||||
from test.utils.stitching import HorizontalKeyMapper, VerticalKeyMapper
|
||||
|
||||
|
||||
def make_getter(key):
|
||||
def getter(pair):
|
||||
return pair.metadata[key]
|
||||
|
||||
return getter
|
||||
|
||||
|
||||
def make_coord_getter(c):
|
||||
return {
|
||||
"x1": make_getter(Info.X1),
|
||||
"x2": make_getter(Info.X2),
|
||||
"y1": make_getter(Info.Y1),
|
||||
"y2": make_getter(Info.Y2),
|
||||
}[c]
|
||||
|
||||
|
||||
def make_pair_merger(axis):
|
||||
return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis]
|
||||
|
||||
|
||||
def make_group_merger(axis):
|
||||
return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]
|
||||
|
||||
|
||||
def make_length_getter(dim):
|
||||
return {
|
||||
"width": make_getter(Info.WIDTH),
|
||||
"height": make_getter(Info.HEIGHT),
|
||||
}[dim]
|
||||
|
||||
|
||||
def merge_metadata_horizontally(m1, m2):
|
||||
m1, m2 = map(HorizontalKeyMapper, [m1, m2])
|
||||
return merge_metadata(m1, m2)
|
||||
|
||||
|
||||
def merge_metadata_vertically(m1, m2):
|
||||
m1, m2 = map(VerticalKeyMapper, [m1, m2])
|
||||
return merge_metadata(m1, m2)
|
||||
|
||||
|
||||
def merge_metadata(m1, m2):
|
||||
|
||||
c1 = min(m1.c1, m2.c1)
|
||||
c2 = max(m1.c2, m2.c2)
|
||||
dim = m1.dim + m2.dim
|
||||
|
||||
merged = deepcopy(m1)
|
||||
merged.dim = dim
|
||||
merged.c1 = c1
|
||||
merged.c2 = c2
|
||||
|
||||
return merged.wrapped
|
||||
|
||||
|
||||
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
|
||||
metadata_merged = merge_metadata_horizontally(p1.metadata, p2.metadata)
|
||||
image_concatenated = concat_images_horizontally(p1.image, p2.image, metadata_merged)
|
||||
return ImageMetadataPair(image_concatenated, metadata_merged)
|
||||
|
||||
|
||||
def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
|
||||
metadata_merged = merge_metadata_vertically(p1.metadata, p2.metadata)
|
||||
image_concatenated = concat_images_vertically(p1.image, p2.image, metadata_merged)
|
||||
return ImageMetadataPair(image_concatenated, metadata_merged)
|
||||
|
||||
|
||||
def concat_images_horizontally(im1: Image, im2: Image, metadata: dict):
|
||||
return concat_images(im1, im2, metadata, 0)
|
||||
|
||||
|
||||
def concat_images_vertically(im1: Image, im2: Image, metadata: dict):
|
||||
return concat_images(im1, im2, metadata, 1)
|
||||
|
||||
|
||||
def concat_images(im1: Image, im2: Image, metadata: dict, axis):
|
||||
|
||||
im_aggr = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))
|
||||
|
||||
images = [im1, im2]
|
||||
|
||||
offsets = [0, *[im.size[axis] for im in images]]
|
||||
|
||||
for im, offset in zip(images, offsets):
|
||||
box = (offset, 0) if not axis else (0, offset)
|
||||
im_aggr.paste(im, box=box)
|
||||
|
||||
return im_aggr
|
||||
|
||||
|
||||
def make_merger_aggregator(direction):
|
||||
"""Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
|
||||
head H and aggregates non-adjacent in the tail T.
|
||||
"""
|
||||
|
||||
def merger_aggregator(pairs):
|
||||
def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
|
||||
"""Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
|
||||
aggr, non_aggr = juxt(first, rest)(pairs_aggr)
|
||||
if c2_getter(aggr) == c1_getter(pair):
|
||||
aggr = pair_merger(aggr, pair)
|
||||
return aggr, *non_aggr
|
||||
else:
|
||||
return aggr, pair, *non_aggr
|
||||
|
||||
# Requires H to be the least element in image-concatenation direction by c1, since the concatenation happens
|
||||
# only in c1 -> c2 direction.
|
||||
pairs = sorted(pairs, key=c1_getter)
|
||||
aggregation_pair, pairs = juxt(first, rest)(pairs)
|
||||
return list(reduce(merge_on_head_and_aggregate_in_tail, pairs, [aggregation_pair]))
|
||||
|
||||
c1_getter = make_coord_getter(f"{direction}1")
|
||||
c2_getter = make_coord_getter(f"{direction}2")
|
||||
pair_merger = make_pair_merger(direction)
|
||||
|
||||
return merger_aggregator
|
||||
|
||||
|
||||
def merge_group(group, direction):
|
||||
reduce_group = make_merger_aggregator(direction)
|
||||
return until_convergence(reduce_group, group)
|
||||
|
||||
|
||||
def merge_group_horizontally(group):
|
||||
return merge_group(group, "x")
|
||||
|
||||
|
||||
def merge_group_vertically(group):
|
||||
return merge_group(group, "y")
|
||||
@ -1,8 +1,3 @@
|
||||
from itertools import takewhile, starmap, islice, repeat
|
||||
from operator import truth
|
||||
|
||||
from .logger import get_logger
|
||||
|
||||
|
||||
def chunk_iterable(iterable, chunk_size):
|
||||
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
|
||||
|
||||
14
image_prediction/utils/generic.py
Normal file
14
image_prediction/utils/generic.py
Normal file
@ -0,0 +1,14 @@
|
||||
from itertools import takewhile, starmap, islice, repeat
|
||||
from operator import truth
|
||||
|
||||
from funcy import iterate
|
||||
|
||||
|
||||
def chunk_iterable(iterable, chunk_size):
|
||||
return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), chunk_size)))))
|
||||
|
||||
|
||||
def until_convergence(func, *args, **kwargs):
|
||||
for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2):
|
||||
if len(a) == len(b):
|
||||
return a
|
||||
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from image_prediction.utils import chunk_iterable
|
||||
from image_prediction.utils.generic import chunk_iterable
|
||||
|
||||
|
||||
@pytest.mark.parametrize("estimator_type", ["mock", "keras"])
|
||||
|
||||
@ -1,229 +1,46 @@
|
||||
from copy import deepcopy
|
||||
from functools import partial, reduce
|
||||
from itertools import groupby
|
||||
from itertools import starmap, chain, repeat
|
||||
from typing import Iterable, List
|
||||
from functools import partial
|
||||
from itertools import starmap, repeat
|
||||
from typing import List
|
||||
|
||||
import fpdf
|
||||
import numpy as np
|
||||
import pdf2image
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from funcy import merge, second, compose, rpartial, juxt, rest, first, one, iterate
|
||||
from funcy import merge, juxt, one
|
||||
|
||||
from image_prediction.estimator.preprocessor.utils import image_to_normalized_tensor
|
||||
from image_prediction.image_extractor.extractor import ImageMetadataPair
|
||||
from image_prediction.info import Info
|
||||
from image_prediction.utils import chunk_iterable
|
||||
from image_prediction.stitcher.stitcher import Stitcher
|
||||
from image_prediction.stitcher.utils import (
|
||||
make_coord_getter,
|
||||
make_length_getter,
|
||||
merge_metadata_horizontally,
|
||||
merge_metadata_vertically,
|
||||
merge_pair_horizontally,
|
||||
merge_pair_vertically,
|
||||
concat_images_horizontally,
|
||||
concat_images_vertically,
|
||||
merge_group_horizontally,
|
||||
merge_group_vertically,
|
||||
)
|
||||
from test.conftest import (
|
||||
get_base_position_metadata,
|
||||
add_image,
|
||||
random_single_color_image_from_metadata,
|
||||
random_size_gray_image_from_metadata,
|
||||
)
|
||||
from test.utils.stitching import BoxSplitter, VerticalKeyMapper, HorizontalKeyMapper
|
||||
|
||||
|
||||
def make_getter(key):
|
||||
def getter(pair):
|
||||
return pair.metadata[key]
|
||||
|
||||
return getter
|
||||
|
||||
|
||||
def make_coord_getter(c):
|
||||
return {
|
||||
"x1": make_getter(Info.X1),
|
||||
"x2": make_getter(Info.X2),
|
||||
"y1": make_getter(Info.Y1),
|
||||
"y2": make_getter(Info.Y2),
|
||||
}[c]
|
||||
|
||||
|
||||
def make_pair_merger(axis):
|
||||
return {"y": merge_pair_vertically, "x": merge_pair_horizontally}[axis]
|
||||
|
||||
|
||||
def make_group_merger(axis):
|
||||
return {"y": merge_group_vertically, "x": merge_group_horizontally}[axis]
|
||||
|
||||
|
||||
def make_length_getter(dim):
|
||||
return {
|
||||
"width": make_getter(Info.WIDTH),
|
||||
"height": make_getter(Info.HEIGHT),
|
||||
}[dim]
|
||||
|
||||
from test.utils.stitching import BoxSplitter
|
||||
|
||||
x1_getter, y1_getter, x2_getter, y2_getter = map(make_coord_getter, ("x1", "y1", "x2", "y2"))
|
||||
width_getter, height_getter = map(make_length_getter, ("width", "height"))
|
||||
|
||||
|
||||
def merge_metadata_horizontally(m1, m2):
|
||||
m1, m2 = map(HorizontalKeyMapper, [m1, m2])
|
||||
return merge_metadata(m1, m2)
|
||||
|
||||
|
||||
def merge_metadata_vertically(m1, m2):
|
||||
m1, m2 = map(VerticalKeyMapper, [m1, m2])
|
||||
return merge_metadata(m1, m2)
|
||||
|
||||
|
||||
def merge_metadata(m1, m2):
|
||||
|
||||
c1 = min(m1.c1, m2.c1)
|
||||
c2 = max(m1.c2, m2.c2)
|
||||
dim = m1.dim + m2.dim
|
||||
|
||||
merged = deepcopy(m1)
|
||||
merged.dim = dim
|
||||
merged.c1 = c1
|
||||
merged.c2 = c2
|
||||
|
||||
return merged.wrapped
|
||||
|
||||
|
||||
def merge_pair_horizontally(p1: ImageMetadataPair, p2: ImageMetadataPair):
|
||||
metadata_merged = merge_metadata_horizontally(p1.metadata, p2.metadata)
|
||||
image_concatenated = concat_images_horizontally(p1.image, p2.image, metadata_merged)
|
||||
return ImageMetadataPair(image_concatenated, metadata_merged)
|
||||
|
||||
|
||||
def merge_pair_vertically(p1: ImageMetadataPair, p2: ImageMetadataPair):
|
||||
metadata_merged = merge_metadata_vertically(p1.metadata, p2.metadata)
|
||||
image_concatenated = concat_images_vertically(p1.image, p2.image, metadata_merged)
|
||||
return ImageMetadataPair(image_concatenated, metadata_merged)
|
||||
|
||||
|
||||
# def merge_pair(p1, p2):
|
||||
# assert p1.metadata[Info.PAGE_IDX] == p2.metadta[Info.PAGE_IDX]
|
||||
|
||||
|
||||
def concat_images_horizontally(im1: Image, im2: Image, metadata: dict):
|
||||
return concat_images(im1, im2, metadata, 0)
|
||||
|
||||
|
||||
def concat_images_vertically(im1: Image, im2: Image, metadata: dict):
|
||||
return concat_images(im1, im2, metadata, 1)
|
||||
|
||||
|
||||
def concat_images(im1: Image, im2: Image, metadata: dict, axis):
|
||||
|
||||
im_aggr = Image.new(im1.mode, (metadata[Info.WIDTH], metadata[Info.HEIGHT]))
|
||||
|
||||
images = [im1, im2]
|
||||
|
||||
offsets = [0, *[im.size[axis] for im in images]]
|
||||
|
||||
for im, offset in zip(images, offsets):
|
||||
box = (offset, 0) if not axis else (0, offset)
|
||||
im_aggr.paste(im, box=box)
|
||||
|
||||
return im_aggr
|
||||
|
||||
|
||||
def make_merger_aggregator(direction):
|
||||
"""Produces a function f : [H, T1, ... Tn] -> [HTi...Tj, Tk ... Tl] that merges adjacent image-metadata pairs on the
|
||||
head H and aggregates non-adjacent in the tail T.
|
||||
"""
|
||||
|
||||
def merger_aggregator(pairs):
|
||||
def merge_on_head_and_aggregate_in_tail(pairs_aggr: Iterable[ImageMetadataPair], pair: ImageMetadataPair):
|
||||
"""Keeps the image that is being merged with as the head and aggregates non-mergables in the tail."""
|
||||
aggr, non_aggr = juxt(first, rest)(pairs_aggr)
|
||||
if c2_getter(aggr) == c1_getter(pair):
|
||||
aggr = pair_merger(aggr, pair)
|
||||
return aggr, *non_aggr
|
||||
else:
|
||||
return aggr, pair, *non_aggr
|
||||
|
||||
# Requires H to be the least element in image-concatenation direction by c1, since the concatenation happens
|
||||
# only in c1 -> c2 direction.
|
||||
pairs = sorted(pairs, key=c1_getter)
|
||||
aggregation_pair, pairs = juxt(first, rest)(pairs)
|
||||
return list(reduce(merge_on_head_and_aggregate_in_tail, pairs, [aggregation_pair]))
|
||||
|
||||
c1_getter = make_coord_getter(f"{direction}1")
|
||||
c2_getter = make_coord_getter(f"{direction}2")
|
||||
pair_merger = make_pair_merger(direction)
|
||||
|
||||
return merger_aggregator
|
||||
|
||||
|
||||
def merge_group(group, direction):
|
||||
reduce_group = make_merger_aggregator(direction)
|
||||
return until_convergence(reduce_group, group)
|
||||
|
||||
|
||||
def merge_group_horizontally(group):
|
||||
return merge_group(group, "x")
|
||||
|
||||
|
||||
def merge_group_vertically(group):
|
||||
return merge_group(group, "y")
|
||||
|
||||
|
||||
class Stitcher:
|
||||
|
||||
@staticmethod
|
||||
def groupby(pairs, coord_getter):
|
||||
pairs = sorted(pairs, key=coord_getter)
|
||||
return map(compose(list, second), groupby(pairs, coord_getter))
|
||||
|
||||
@staticmethod
|
||||
def other_axis(axis):
|
||||
return "y" if axis == "x" else "x"
|
||||
|
||||
def merge_along_axis(self, pairs, axis):
|
||||
|
||||
def group_pairs_by_c1(pairs):
|
||||
return self.groupby(pairs, c1_getter)
|
||||
|
||||
def group_by_c2(pairs):
|
||||
return self.groupby(pairs, c2_getter)
|
||||
|
||||
def group_pairs_within_groups_by_c2(groups):
|
||||
return map(group_by_c2, groups)
|
||||
|
||||
def merge_groups_along_orthogonal_axis(groups):
|
||||
return map(group_merger, groups)
|
||||
|
||||
c1_getter = make_coord_getter(f"{self.other_axis(axis)}1")
|
||||
c2_getter = make_coord_getter(f"{self.other_axis(axis)}2")
|
||||
group_merger = make_group_merger(axis)
|
||||
|
||||
groups_of_pairs_with_same_c1 = group_pairs_by_c1(pairs)
|
||||
groups_of_groups_of_pairs_with_same_c1_and_c2 = group_pairs_within_groups_by_c2(groups_of_pairs_with_same_c1)
|
||||
groups_of_pairs_with_matching_c1_and_c2 = chain(*groups_of_groups_of_pairs_with_same_c1_and_c2)
|
||||
groups_of_merged_pairs = merge_groups_along_orthogonal_axis(groups_of_pairs_with_matching_c1_and_c2)
|
||||
pairs = chain(*groups_of_merged_pairs)
|
||||
|
||||
return pairs
|
||||
|
||||
def merge_along_both_axes(self, pairs):
|
||||
|
||||
pairs = self.merge_along_axis(pairs, "x")
|
||||
pairs = list(self.merge_along_axis(pairs, "y"))
|
||||
|
||||
return pairs
|
||||
|
||||
def stitch(self, pairs: Iterable[ImageMetadataPair]) -> List[ImageMetadataPair]:
|
||||
return until_convergence(self.merge_along_both_axes, pairs)
|
||||
|
||||
|
||||
def until_convergence(func, *args, **kwargs):
|
||||
for a, b in chunk_iterable(iterate(func, *args, **kwargs), chunk_size=2):
|
||||
if len(a) == len(b):
|
||||
return a
|
||||
|
||||
|
||||
#####################################
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("width", [160])
|
||||
# @pytest.mark.parametrize("height", [90])
|
||||
# @pytest.mark.parametrize("page_width", [int(160 * 1.1)])
|
||||
# @pytest.mark.parametrize("page_height", [int(90 * 1.1)])
|
||||
def test_image_stitcher(patch_image_metadata_pairs, base_patch_metadata, base_patch_image):
|
||||
pair_stitched = Stitcher().stitch(patch_image_metadata_pairs)[0]
|
||||
assert pair_stitched.metadata == base_patch_metadata
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user