diff --git a/test/fixtures/page_generation/page.py b/test/fixtures/page_generation/page.py index 4d41d14..f8b568a 100644 --- a/test/fixtures/page_generation/page.py +++ b/test/fixtures/page_generation/page.py @@ -103,6 +103,7 @@ from funcy import ( keep, repeatedly, mapcat, + lmapcat, ) from cv_analysis.locations import TEST_PAGE_TEXTURES_DIR @@ -259,9 +260,7 @@ def blur(image: np.ndarray): def normalize_image_format_to_array(image: Image_t): - if isinstance(image, Image.Image): - return np.array(image) - return image + return np.array(image) if isinstance(image, Image.Image) else image def normalize_image_format_to_pil(image: Image_t): @@ -410,10 +409,12 @@ class ContentGenerator: text_boxes = lmap(generate_random_text_block, every_nth(2, text_boxes)) tables_1 = lmap(generate_recursive_random_table, every_nth(2, text_boxes[1:])) - plots = lmap(generate_random_plot, every_nth(2, figure_boxes)) + + # TODO: Refactor: Figures should be their own class + plots, captions = map(list, zip(*map(generate_random_figure, every_nth(2, figure_boxes)))) tables_2 = lmap(generate_recursive_random_table, every_nth(2, figure_boxes[1:])) - boxes = text_boxes + plots + tables_1 + tables_2 + boxes = text_boxes + plots + captions + tables_1 + tables_2 boxes = remove_included(boxes) boxes = remove_overlapping(boxes) @@ -428,6 +429,21 @@ def every_nth(n, iterable): return itertools.islice(iterable, 0, None, n) +def generate_random_figure(rectangle: Rectangle): + # assert rectangle.height / rectangle.width < 0.7, "Figure is too wide to add a caption." + figure_box, caption_box = split_into_figure_and_caption(rectangle) + figure_box = generate_random_plot(figure_box) + caption_box = generate_random_text_block(caption_box) + return figure_box, caption_box + + +def split_into_figure_and_caption(rectangle: Rectangle): + split_point = random.uniform(0.5, 0.9) + figure_box = Rectangle(rectangle.x1, rectangle.y1, rectangle.x2, rectangle.y1 + rectangle.height * split_point) + caption_box = Rectangle(rectangle.x1, rectangle.y1 + rectangle.height * split_point, rectangle.x2, rectangle.y2) + return figure_box, caption_box + + def generate_random_plot(rectangle: Rectangle) -> ContentRectangle: block = RandomPlot(*rectangle.coords) block.content = rectangle.content if isinstance(rectangle, ContentRectangle) else None # TODO: Refactor