2023-02-01 18:56:16 +01:00

193 lines
6.0 KiB
Python

import io
import random
from functools import lru_cache, partial
import loguru
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from cv_analysis.utils.geometric import is_square_like, is_wide, is_tall
from cv_analysis.utils.image_operations import superimpose
from cv_analysis.utils.rectangle import Rectangle
from synthesis.random import rnd, probably, maybe
from synthesis.segment.random_content_rectangle import RandomContentRectangle
from synthesis.text.text import generate_random_words
class RandomPlot(RandomContentRectangle):
    """Content rectangle whose content is a randomly generated matplotlib plot.

    Each ``generate_random_*`` method draws random data, renders it with
    pyplot at the size of the given rectangle, and stores the resulting RGBA
    image in ``self.content`` (superimposed onto any content already there).

    Attributes:
        cmap: Colormap returned by ``pick_colormap()``; source of plot colors.
        content: Rendered image. Read before first assignment, so it is
            presumably initialised by ``RandomContentRectangle`` -- confirm.

    NOTE(review): scalar choices go through ``rnd`` (presumably the project's
    seeded generator), but the data arrays still come from the global
    ``np.random`` state, which this class does not seed.
    """

    def __init__(self, x1, y1, x2, y2, seed=None):
        """Forward coordinates and seed to the base class and pick the colormap."""
        super().__init__(x1, y1, x2, y2, seed=seed)
        self.cmap = pick_colormap()

    def __call__(self, *args, **kwargs):
        """Do nothing; all arguments are accepted and ignored."""

    def generate_random_plot(self, rectangle: Rectangle) -> None:
        """Render a random plot whose type suits the shape of ``rectangle``.

        Square-like rectangles may get any plot type, wide ones a line plot,
        histogram or bar plot, tall ones a bar plot or histogram. Anything
        else falls back to a scatter plot without drawing from ``rnd``.
        """
        if is_square_like(rectangle):
            candidates = [
                self.generate_random_line_plot,
                self.generate_random_bar_plot,
                self.generate_random_scatter_plot,
                self.generate_random_histogram,
                self.generate_random_pie_chart,
            ]
        elif is_wide(rectangle):
            candidates = [
                self.generate_random_line_plot,
                self.generate_random_histogram,
                self.generate_random_bar_plot,
            ]
        elif is_tall(rectangle):
            candidates = [
                self.generate_random_bar_plot,
                self.generate_random_histogram,
            ]
        else:
            candidates = None

        if candidates is None:
            plt_fn = self.generate_random_scatter_plot
        else:
            plt_fn = rnd.choice(candidates)
        plt_fn(rectangle)

    def generate_random_bar_plot(self, rectangle: Rectangle) -> None:
        """Render five bars with random integer positions and heights in [1, 10]."""
        x = sorted(np.random.randint(low=1, high=11, size=5))
        y = np.random.randint(low=1, high=11, size=5)
        bar_fn = partial(
            plt.bar,
            log=rnd.choice([True, False]),
        )
        self.__generate_random_plot(bar_fn, rectangle, x, y)

    def generate_random_line_plot(self, rectangle: Rectangle) -> None:
        """Render a randomly chosen elementary function sampled at 100 points on [0, 10]."""
        f = rnd.choice([np.sin, np.cos, np.tan, np.exp, np.log, np.sqrt, np.square])
        x = np.linspace(0, 10, 100)
        # np.log evaluates to -inf at x == 0 (NumPy emits a RuntimeWarning).
        y = f(x)
        self.__generate_random_plot(plt.plot, rectangle, x, y)

    def generate_random_scatter_plot(self, rectangle: Rectangle) -> None:
        """Render 10 to 40 standard-normal points with a random marker."""
        n = rnd.randint(10, 40)
        x = np.random.normal(size=n)
        y = np.random.normal(size=n)
        scatter_fn = partial(
            plt.scatter,
            cmap=self.cmap,
            marker=rnd.choice(["o", "*", "+", "x"]),
        )
        self.__generate_random_plot(scatter_fn, rectangle, x, y)

    def generate_random_histogram(self, rectangle: Rectangle) -> None:
        """Render a histogram of 100 standard-normal samples with randomised options."""
        x = np.random.normal(size=100)
        hist_fn = partial(
            plt.hist,
            orientation=rnd.choice(["horizontal", "vertical"]),
            histtype=rnd.choice(["bar", "barstacked", "step", "stepfilled"]),
            log=rnd.choice([True, False]),
            stacked=rnd.choice([True, False]),
            density=rnd.choice([True, False]),
            cumulative=rnd.choice([True, False]),
        )
        # The second positional argument of plt.hist is the bin count.
        self.__generate_random_plot(hist_fn, rectangle, x, rnd.randint(5, 20))

    def generate_random_pie_chart(self, rectangle: Rectangle) -> None:
        """Render a pie chart with 3 to 7 wedges of uniformly random size."""
        n = rnd.randint(3, 7)
        x = np.random.uniform(size=n)
        pie_fn = partial(
            plt.pie,
            shadow=True,
            startangle=90,
            pctdistance=0.85,
            labeldistance=1.1,
            colors=self.cmap(np.linspace(0, 1, 10)),
        )
        self.__generate_random_plot(
            pie_fn,
            rectangle,
            x,
            # Second positional argument of plt.pie (``explode`` in matplotlib's signature).
            np.random.uniform(0, 0.1, size=n),
            # No line-style kwargs are forwarded to the pie chart.
            plot_kwargs={},
        )

    def generate_plot_kwargs(self, keywords=None):
        """Return random ``color``/``linestyle``/``linewidth`` keyword arguments.

        Args:
            keywords: Optional collection of keys to keep. When empty or
                ``None`` all three entries are returned.

        Returns:
            Dict of keyword arguments, filtered to ``keywords`` if given.
        """
        kwargs = {
            "color": rnd.choice(self.cmap.colors),
            "linestyle": rnd.choice(["-", "--", "-.", ":"]),
            "linewidth": rnd.uniform(1, 4),
        }
        if not keywords:
            return kwargs
        return {key: value for key, value in kwargs.items() if key in keywords}

    def __generate_random_plot(self, plot_fn, rectangle: Rectangle, x, y, plot_kwargs=None):
        """Draw ``plot_fn(x, y)`` on a fresh figure and store the rendered image.

        Args:
            plot_fn: pyplot-style callable invoked as ``plot_fn(x, y, **plot_kwargs)``;
                it draws on the current figure created here.
            rectangle: Supplies ``width`` and ``height`` for the figure size
                (divided by 100 to get inches) and for the final image size.
            x: First positional argument for ``plot_fn``.
            y: Second positional argument for ``plot_fn``.
            plot_kwargs: Keyword arguments for ``plot_fn``. ``None`` means
                "use ``generate_plot_kwargs()``"; an empty dict means none.
        """
        if plot_kwargs is None:
            plot_kwargs = self.generate_plot_kwargs()

        fig, ax = plt.subplots()
        try:
            fig.set_size_inches(rectangle.width / 100, rectangle.height / 100)
            fig.tight_layout(pad=0)
            plot_fn(x, y, **plot_kwargs)
            ax.set_facecolor("none")

            if probably():
                ax.set_title(generate_random_words(1, 3))

            # Disable axis decorations at random; each one is an independent draw.
            if maybe():
                ax.set_xticks([])
            if maybe():
                ax.set_yticks([])
            if maybe():
                ax.set_xticklabels([])
            if maybe():
                ax.set_yticklabels([])
            if maybe():
                ax.set_xlabel("")
            if maybe():
                ax.set_ylabel("")
            if maybe():
                ax.set_title("")
            if maybe():
                ax.set_frame_on(False)

            # Remove the top and right spines together at random.
            if maybe():
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)

            image = dump_plt_to_image(rectangle)
        finally:
            # Release the figure even when plotting or rendering fails; this is
            # harmless if dump_plt_to_image has already closed it.
            plt.close(fig)

        assert image.mode == "RGBA"
        self.content = image if not self.content else superimpose(self.content, image)
@lru_cache(maxsize=None)
def pick_colormap() -> ListedColormap:
    """Pick one of five colormaps at random, log the choice and return it.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    random choice is made on the first call only; later calls return the
    same colormap object.
    """
    available_names = [
        "viridis",
        "plasma",
        "inferno",
        "magma",
        "cividis",
    ]
    cmap_name = rnd.choice(available_names)
    loguru.logger.info(f"Using colormap {cmap_name}")
    return plt.get_cmap(cmap_name)
def dump_plt_to_image(rectangle):
    """Render the current pyplot figure to an image and close the figure.

    The figure is saved as a transparent PNG into an in-memory buffer, read
    back with PIL and resized to ``(rectangle.width, rectangle.height)``.

    Args:
        rectangle: Object providing the target ``width`` and ``height``.

    Returns:
        The resized PIL image.
    """
    try:
        with io.BytesIO() as buf:
            plt.savefig(buf, format="png", transparent=True)
            buf.seek(0)
            # Image.open is lazy; resize() reads the pixel data and returns a
            # new image, so it must run while the buffer is still open.
            image = Image.open(buf).resize((rectangle.width, rectangle.height))
    finally:
        # Close the current figure even if saving or decoding failed, so that
        # pyplot does not keep accumulating open figures.
        plt.close()
    return image