cv-analysis-service/synthesis/content_generator.py
Matthias Bisping 363d04ce5d Formatting
2023-02-28 14:03:02 +01:00

88 lines
2.7 KiB
Python

import itertools
from typing import List, Iterable
from PIL import Image
from funcy import lsplit, lfilter, mapcat
from cv_analysis.logging import logger
from cv_analysis.utils import every_nth, zipmap
from cv_analysis.utils.geometric import is_square_like
from cv_analysis.utils.merging import merge_related_rectangles
from cv_analysis.utils.postprocessing import remove_included, remove_overlapping
from cv_analysis.utils.rectangle import Rectangle
from synthesis.randomization import rnd
from synthesis.segment.content_rectangle import ContentRectangle
from synthesis.segment.recursive_content_rectangle import RecursiveContentRectangle
from synthesis.segment.segments import (
generate_random_text_block,
generate_recursive_random_table_with_caption,
generate_random_plot_with_caption,
)
class ContentGenerator:
    """Fill layout rectangles with randomly generated page content.

    Square-like rectangles are treated as figure slots (tables or plots with
    captions); the remaining rectangles are treated as text slots (text blocks
    or tables with captions).
    """

    def __init__(self, *, constrain_layouts: bool = True):
        """
        Args:
            constrain_layouts: When true (the default), related rectangles are
                merged before content is generated, and the generated boxes are
                afterwards passed through ``remove_included`` and
                ``remove_overlapping``.
        """
        self.constrain_layouts = constrain_layouts

    def __call__(self, boxes: List[Rectangle]) -> List[ContentRectangle]:
        """Generate content boxes for the given layout rectangles.

        Args:
            boxes: Layout rectangles to fill. The caller's sequence is not
                modified; a shuffled copy is used instead.

        Returns:
            The generated content boxes, flattened with ``unpack_boxes`` so that
            nested content is returned alongside the top-level boxes.
        """
        # Shuffle a private copy so the caller's list is left untouched; the
        # RNG consumes exactly the same draws as shuffling the argument in place.
        boxes = list(boxes)
        rnd.shuffle(boxes)
        figure_boxes, text_boxes = lsplit(is_square_like, boxes)

        if self.constrain_layouts:
            figure_boxes = merge_related_rectangles(figure_boxes)
            # Re-filter after merging; presumably merged figure boxes are not
            # guaranteed to remain square-like.
            figure_boxes = lfilter(is_square_like, figure_boxes)
            text_boxes = merge_related_rectangles(text_boxes)

        # Alternate content types inside each group (assuming every_nth(2, xs)
        # picks xs[0], xs[2], ...): text slots alternate text block / table,
        # figure slots alternate table / plot. The results of the captioned
        # generators are splatted into chain(); presumably zipmap yields one
        # iterable per output (e.g. content and caption) — verify.
        boxes = list(
            itertools.chain(
                map(generate_random_text_block, every_nth(2, text_boxes)),
                *zipmap(generate_recursive_random_table_with_caption, every_nth(2, text_boxes[1:])),
                *zipmap(generate_recursive_random_table_with_caption, every_nth(2, figure_boxes)),
                *zipmap(generate_random_plot_with_caption, every_nth(2, figure_boxes[1:])),
            )
        )

        if self.constrain_layouts:
            boxes = remove_included(boxes)
            boxes = remove_overlapping(boxes)

        boxes = list(unpack_boxes(boxes))
        for box in boxes:
            logger.trace(f"Generated {box}")
        return boxes
def unpack_boxes(boxes: Iterable[ContentRectangle]) -> Iterable[ContentRectangle]:
    """Flatten generated content boxes.

    First yields every top-level box, then, for each top-level box in the same
    order, the leaf boxes of its subtree as produced by ``__unpack_box_rec``.

    Args:
        boxes: Top-level content boxes. Any iterable is accepted; it is
            materialised once because it is traversed twice.

    Yields:
        The top-level boxes, followed by the leaves of each box's subtree.
    """
    # Materialise first: the input is walked twice below, and a one-shot
    # iterator (e.g. a generator) would be exhausted after the first pass,
    # silently dropping all nested boxes.
    boxes = list(boxes)
    yield from boxes
    # NOTE(review): a top-level box without children is its own leaf, so
    # __unpack_box_rec yields it a second time here — confirm whether that
    # duplication is intended.
    yield from mapcat(__unpack_box_rec, boxes)
def __unpack_box_rec(box: ContentRectangle) -> Iterable[ContentRectangle]:
    """Yield the leaf boxes of the subtree rooted at ``box``.

    Children are obtained via ``BoxChildrenVisitor``. A box reporting no
    children is a leaf and is yielded itself; otherwise each child is unpacked
    recursively, so ``box`` and any internal nodes below it are not yielded.
    """
    children = box.accept(BoxChildrenVisitor())
    if not children:
        # Leaf: nothing to descend into.
        yield box
        return
    # Internal node: emit the leaves of every child subtree, in order.
    yield from mapcat(__unpack_box_rec, children)
class BoxChildrenVisitor:
    """Visitor that reports the direct children of a content box.

    A plain ``ContentRectangle`` has no children; a
    ``RecursiveContentRectangle`` exposes its ``children`` attribute as-is.
    """

    def visit_content_rectangle(self, _box: ContentRectangle):
        # Plain content boxes never nest other boxes; hand back a fresh empty list.
        no_children = []
        return no_children

    def visit_recursive_content_rectangle(self, box: RecursiveContentRectangle):
        # Return the box's own children container itself, not a copy.
        nested = box.children
        return nested