diff --git a/cv_analysis/locations.py b/cv_analysis/locations.py
index 34d36c2..6d12787 100644
--- a/cv_analysis/locations.py
+++ b/cv_analysis/locations.py
@@ -11,3 +11,4 @@ TEST_DATA_DIR = TEST_DIR_PATH / "data"
TEST_DATA_DIR_DVC = TEST_DIR_PATH / "data.dvc"
TEST_DATA_SYNTHESIS_DIR = TEST_DATA_DIR / "synthesis"
TEST_PAGE_TEXTURES_DIR = TEST_DATA_SYNTHESIS_DIR / "paper"
+TEST_SMILES_FILE = TEST_DATA_SYNTHESIS_DIR / "smiles.csv"
diff --git a/cv_analysis/logging.py b/cv_analysis/logging.py
index 7bcba7b..8e03143 100644
--- a/cv_analysis/logging.py
+++ b/cv_analysis/logging.py
@@ -13,7 +13,7 @@ debug_logger = loguru.logger
debug_logger.add(
sink=sys.stderr,
format="{time:YYYY-MM-DD at HH:mm:ss} | {level: <8} | {name}: {message}",
- level="TRACE",
+ level="DEBUG",
)
dev_logger = loguru.logger
@@ -31,6 +31,9 @@ prod_logger.add(
enqueue=True,
)
+# logger.remove()
+# logger.add(sink=sys.stderr, level="DEBUG", enqueue=True)
+
def __log(logger, level: str, enters=True, exits=True) -> Callable:
print_func = get_print_func(logger, level)
diff --git a/cv_analysis/utils/rectangle.py b/cv_analysis/utils/rectangle.py
index 1086a36..1f09bc1 100644
--- a/cv_analysis/utils/rectangle.py
+++ b/cv_analysis/utils/rectangle.py
@@ -51,6 +51,10 @@ class Rectangle:
def coords(self):
return [self.x1, self.y1, self.x2, self.y2]
+ @property
+ def size(self):
+ return self.width, self.height
+
def __hash__(self):
return hash((self.x1, self.y1, self.x2, self.y2))
diff --git a/poetry.lock b/poetry.lock
index 0159106..5f7c42f 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2229,6 +2229,18 @@ opencv-python-headless = ">=4.0.1"
scikit-learn = ">=0.19.1"
typing-extensions = "*"
+[[package]]
+name = "rdkit"
+version = "2022.9.4"
+description = "A collection of chemoinformatics and machine-learning software written in C++ and Python"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+numpy = "*"
+Pillow = "*"
+
[[package]]
name = "requests"
version = "2.28.1"
@@ -2862,7 +2874,7 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools"
[metadata]
lock-version = "1.1"
python-versions = "~3.8"
-content-hash = "117d9fceef40b37d126a7c2e47125c74373307e1f9a0f0ae82fc9e5e21295f25"
+content-hash = "45539080e4964adfec7aad366e5ae67e25659afa188f4da882f93fe4f313fe36"
[metadata.files]
aiohttp = [
@@ -4500,6 +4512,32 @@ qudida = [
{file = "qudida-0.0.4-py3-none-any.whl", hash = "sha256:4519714c40cd0f2e6c51e1735edae8f8b19f4efe1f33be13e9d644ca5f736dd6"},
{file = "qudida-0.0.4.tar.gz", hash = "sha256:db198e2887ab0c9aa0023e565afbff41dfb76b361f85fd5e13f780d75ba18cc8"},
]
+rdkit = [
+ {file = "rdkit-2022.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0ef8f08dc0dad7fa6b87616b412ed7a044e98469714ba269e3a4cb46e9d903a4"},
+ {file = "rdkit-2022.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bc2a73cb07197870cfd9e6aac8b58468375f6c6458aced3a12a232d1cd52c81e"},
+ {file = "rdkit-2022.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63f3d7c86fc6263a0aae189d37500f66ccb9f1a5814e4eb29bcd9c76204d4de5"},
+ {file = "rdkit-2022.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fe5c858d844d31824d8974385678446d88ef3c5d62dd442af04344ea52853be"},
+ {file = "rdkit-2022.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:572bef53768616fb35e62e9bcdd19c1692676c0a5f736845cce9b337c7b91a71"},
+ {file = "rdkit-2022.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a98c178029a2ead970ae61bb657f8a38592af68bc74fb5bbfbd0752d40311ea5"},
+ {file = "rdkit-2022.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7ccaa50cab089bd27895f1347d9dddb4e5c04548cdbd0bd4b85f01d38281bff6"},
+ {file = "rdkit-2022.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d1a4a95b543ec627b7d312bb5cf806e4cbce6c07d00ca0ab180f6b91c858111"},
+ {file = "rdkit-2022.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b55b2ec43b663b360e5787db1301336f6c48cdd361d72c7019cfa3989a8b6638"},
+ {file = "rdkit-2022.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:75ebacd4a8f2abf634cf3727cc9da20c194bc8466a3c7f7a15a6c2b90d222850"},
+ {file = "rdkit-2022.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:15641dbe80c82d95b04b7f80015bc3b08c634e8036bf3fe17f6fe84bbdbd3e3d"},
+ {file = "rdkit-2022.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bc0872757e8e841ee6da41d781b4036de4c8b1f731fdbfc754f3ceb99469ca9"},
+ {file = "rdkit-2022.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebfdcb0459953909c0c9ab37fa99546e60c6813b1ad0056ded3b9eb0761d5def"},
+ {file = "rdkit-2022.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:c563516332c97927652d2aad866afb51b5d3b35bc0a18ae715666f7c757f41b4"},
+ {file = "rdkit-2022.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:07edd250e909205d618da40f100ca9a500bb986bbf8db155159dd365178f1756"},
+ {file = "rdkit-2022.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8f3e0a81dad936c442cb6cebe9690943176848f1250187aa3676f3392303c1d6"},
+ {file = "rdkit-2022.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5d9a7259d5eb9d0e78f2d7188969548cb7c8da61161eb39db9a8624e1ac1157"},
+ {file = "rdkit-2022.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8331812ff56f84f24689ab1ba8307859a3761b629d67c8ca6a56f3034999801"},
+ {file = "rdkit-2022.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:0cf329bafca28ddd56699e91116abfa0cf5baa56a47c53070a7ef1cb51c11f6a"},
+ {file = "rdkit-2022.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d6678b5b7ffa7a0ad1c57791badddb89778a52cf5abc537e7e1446795ac2830a"},
+ {file = "rdkit-2022.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7376b7c936abaea89cfc02e33464d173c9b8065f0f93a07a5c625af09d53f85f"},
+ {file = "rdkit-2022.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d176b7d7fbe68c0c55658001f4ac9bb4095f4e56ff83326cd0d692778196a99d"},
+ {file = "rdkit-2022.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ea471f5469b32f27854a8ee604a1983b0d866cad511226004986dac6c4d9a13"},
+ {file = "rdkit-2022.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:4934e62869c7ba0b65618df3cbad2b35291357bad1d0e005b558cad528f37f86"},
+]
requests = [
{file = "requests-2.28.1-py3-none-any.whl", hash = "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349"},
{file = "requests-2.28.1.tar.gz", hash = "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983"},
diff --git a/pyproject.toml b/pyproject.toml
index 0d1bb6b..4db4d92 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,7 @@ PyMuPDF = "^1.19.6"
pdf2img = {git = "ssh://git@git.iqser.com:2222/rr/pdf2image.git", branch = "master"}
pyinfra = {git = "ssh://git@git.iqser.com:2222/rr/pyinfra.git", branch = "master"}
loguru = "^0.6.0"
+rdkit = "^2022.9.4"
[tool.poetry.group.build.dependencies]
pytest = "^7.0.1"
diff --git a/synthesis/formula.py b/synthesis/formula.py
index fdaa144..6aa7cdb 100644
--- a/synthesis/formula.py
+++ b/synthesis/formula.py
@@ -1,16 +1,129 @@
-import argparse
+# Draw molecular structures from smiles. Adapted from https://github.com/neeraj-j/molecules
+from itertools import islice
+from typing import List, Iterable
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument()
- args = parser.parse_args()
-
- return args
-
-
-def main(args):
- pass
-
+import numpy as np
+import pandas as pd
+from PIL.Image import Image
+from funcy import first, retry, keep
+from rdkit import Chem
+from rdkit.Chem import AllChem
+from rdkit.Chem import Draw
+from rdkit.Chem import FunctionalGroups
-if __name__ == "__main__":
- main(parse_args())
\ No newline at end of file
+from cv_analysis.locations import TEST_SMILES_FILE
+from cv_analysis.logging import debug_log, logger
+
+
+class StructuralFormulaImageGenerator:
+ def __init__(self, width=None, height=None):
+ self.width = width
+ self.height = height
+
+ self.templates = collect_templates()
+ self.functional_groups = self.templates.keys()
+
+ @debug_log()
+ def generate_images_from_smiles(self, smiles: List[str], max_images_per_functional_group=1) -> Iterable[Image]:
+ yield from self.generate_images_for_functional_groups(
+ smiles,
+ max_images_per_functional_group=max_images_per_functional_group,
+ )
+
+ @debug_log()
+ def generate_images_for_functional_groups(self, smiles: List[str], max_images_per_functional_group):
+ for functional_group in self.functional_groups:
+ smiles = iter(smiles)
+ g = self.generate_images_for_functional_group(smiles, functional_group)
+ yield from islice(keep(g), max_images_per_functional_group)
+
+ def generate_images_for_functional_group(self, smiles: Iterable[str], functional_group: str):
+ try:
+ yield from self.__generate_images_for_functional_group(smiles, functional_group)
+ except ValueError:
+ pass
+
+ @debug_log()
+ @retry(10, errors=ValueError)
+ def __generate_images_for_functional_group(self, smiles: Iterable[str], functional_group: str):
+
+ AllChem.Compute2DCoords(self.templates[functional_group])
+
+ for smile in smiles:
+ try:
+ image = self.make_image(smile, functional_group)
+ yield image
+ except ValueError: # SMILE does not match functional group
+ raise
+
+ @debug_log()
+ def make_image(self, smile: str, functional_group: str):
+ mol = Chem.MolFromSmiles(smile)
+ AllChem.GenerateDepictionMatching2DStructure(mol, self.templates[functional_group])
+
+ side_length = np.random.randint(70, 400)
+ width = self.width or side_length
+ height = self.height or side_length
+
+ image: Image = Draw.MolToImage(
+ mol,
+ size=(width, height),
+ kekulize=flip_a_coin(),
+ wedgeBonds=flip_a_coin(),
+ )
+ image.putalpha(255)
+ return image
+
+
+@debug_log()
+def flip_a_coin():
+ return bool(np.random.randint(0, 2))
+
+
+@debug_log()
+def collect_templates():
+ functional_groups = FunctionalGroups.BuildFuncGroupHierarchy()
+ group_name_2_pattern = dict(stream_label_pattern_pairs(functional_groups))
+ return group_name_2_pattern
+
+
+@debug_log()
+def stream_label_pattern_pairs(functional_groups):
+ for functional_group in functional_groups:
+ yield functional_group.label, functional_group.pattern
+ yield from stream_label_pattern_pairs(functional_group.children)
+
+
+@debug_log()
+def generate_image_of_structural_formula(smiles_file=None, size=None):
+ """Generate images of formulas from SMILE encoded formulas.
+
+ Args:
+ smiles_file: CSV file with column "smiles". Each row contains a SMILE encoded formula.
+ size: width, height
+
+ Returns:
+ PIL.Image.Image: Image of a formula.
+ """
+ logger.info(f"Generating structural formula images from {smiles_file}")
+ return first(generate_images_of_structural_formulas(smiles_file, size=size))
+
+
+@debug_log()
+def generate_images_of_structural_formulas(smiles_file=None, size=None):
+ """Generate an image of a formula from SMILE encoded formulas.
+
+ Args:
+ smiles_file: CSV file with column "smiles". Each row contains a SMILE encoded formula.
+ size: width, height
+
+ Yields:
+ PIL.Image.Image: Image of a formula.
+ """
+ size = size or (None, None)
+ smiles_file = smiles_file or TEST_SMILES_FILE
+ smiles = pd.read_csv(smiles_file).sample(frac=1).drop_duplicates().smiles
+ yield from StructuralFormulaImageGenerator(*size).generate_images_from_smiles(smiles)
+
+
+# generate_image_of_structural_formula().show()
diff --git a/synthesis/segment/plot.py b/synthesis/segment/plot.py
index f21c7f6..82698e3 100644
--- a/synthesis/segment/plot.py
+++ b/synthesis/segment/plot.py
@@ -11,6 +11,7 @@ from matplotlib.colors import ListedColormap
from cv_analysis.utils.geometric import is_square_like, is_wide, is_tall
from cv_analysis.utils.image_operations import superimpose
from cv_analysis.utils.rectangle import Rectangle
+from synthesis.formula import generate_image_of_structural_formula
from synthesis.randomization import rnd, probably, maybe
from synthesis.segment.random_content_rectangle import RandomContentRectangle
from synthesis.text.text import generate_random_words
@@ -39,6 +40,7 @@ class RandomPlot(RandomContentRectangle):
self.generate_random_histogram,
self.generate_random_pie_chart,
self.generate_random_heat_map,
+ self.generate_random_structural_formula
]
)
elif is_wide(rectangle):
@@ -141,6 +143,10 @@ class RandomPlot(RandomContentRectangle):
plot_kwargs=self.generate_plot_kwargs(keywords=["a"]),
)
+ def generate_random_structural_formula(self, rectangle: Rectangle):
+ image = generate_image_of_structural_formula(size=rectangle.size)
+ self.content = image if not self.content else superimpose(self.content, image)
+
def generate_plot_kwargs(self, keywords=None):
kwargs = {
diff --git a/test/conftest.py b/test/conftest.py
index a068821..d0bc1e3 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -6,6 +6,7 @@ pytest_plugins = [
"test.fixtures.table_parsing",
"test.fixtures.figure_detection",
"test.fixtures.data",
+ "test.fixtures.formula",
"test.fixtures.page_generation.page",
]
diff --git a/test/page_generation_test.py b/test/page_generation_test.py
index 9d56011..6d8024d 100644
--- a/test/page_generation_test.py
+++ b/test/page_generation_test.py
@@ -15,5 +15,5 @@ def test_blank_page(page_with_content):
def draw_boxes(page: Image, boxes: Iterable[Rectangle]):
from cv_analysis.utils.drawing import draw_rectangles
- page = draw_rectangles(page, boxes, filled=False, annotate=True)
+ # page = draw_rectangles(page, boxes, filled=False, annotate=True)
show_image(page, backend="pil")