diff --git a/cv_analysis/locations.py b/cv_analysis/locations.py index 34d36c2..6d12787 100644 --- a/cv_analysis/locations.py +++ b/cv_analysis/locations.py @@ -11,3 +11,4 @@ TEST_DATA_DIR = TEST_DIR_PATH / "data" TEST_DATA_DIR_DVC = TEST_DIR_PATH / "data.dvc" TEST_DATA_SYNTHESIS_DIR = TEST_DATA_DIR / "synthesis" TEST_PAGE_TEXTURES_DIR = TEST_DATA_SYNTHESIS_DIR / "paper" +TEST_SMILES_FILE = TEST_DATA_SYNTHESIS_DIR / "smiles.csv" diff --git a/cv_analysis/logging.py b/cv_analysis/logging.py index 7bcba7b..8e03143 100644 --- a/cv_analysis/logging.py +++ b/cv_analysis/logging.py @@ -13,7 +13,7 @@ debug_logger = loguru.logger debug_logger.add( sink=sys.stderr, format="{time:YYYY-MM-DD at HH:mm:ss} | {level: <8} | {name}: {message}", - level="TRACE", + level="DEBUG", ) dev_logger = loguru.logger @@ -31,6 +31,9 @@ prod_logger.add( enqueue=True, ) +# logger.remove() +# logger.add(sink=sys.stderr, level="DEBUG", enqueue=True) + def __log(logger, level: str, enters=True, exits=True) -> Callable: print_func = get_print_func(logger, level) diff --git a/cv_analysis/utils/rectangle.py b/cv_analysis/utils/rectangle.py index 1086a36..1f09bc1 100644 --- a/cv_analysis/utils/rectangle.py +++ b/cv_analysis/utils/rectangle.py @@ -51,6 +51,10 @@ class Rectangle: def coords(self): return [self.x1, self.y1, self.x2, self.y2] + @property + def size(self): + return self.width, self.height + def __hash__(self): return hash((self.x1, self.y1, self.x2, self.y2)) diff --git a/poetry.lock b/poetry.lock index 0159106..5f7c42f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2229,6 +2229,18 @@ opencv-python-headless = ">=4.0.1" scikit-learn = ">=0.19.1" typing-extensions = "*" +[[package]] +name = "rdkit" +version = "2022.9.4" +description = "A collection of chemoinformatics and machine-learning software written in C++ and Python" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +numpy = "*" +Pillow = "*" + [[package]] name = "requests" version = "2.28.1" @@ 
-2862,7 +2874,7 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools" [metadata] lock-version = "1.1" python-versions = "~3.8" -content-hash = "117d9fceef40b37d126a7c2e47125c74373307e1f9a0f0ae82fc9e5e21295f25" +content-hash = "45539080e4964adfec7aad366e5ae67e25659afa188f4da882f93fe4f313fe36" [metadata.files] aiohttp = [ @@ -4500,6 +4512,32 @@ qudida = [ {file = "qudida-0.0.4-py3-none-any.whl", hash = "sha256:4519714c40cd0f2e6c51e1735edae8f8b19f4efe1f33be13e9d644ca5f736dd6"}, {file = "qudida-0.0.4.tar.gz", hash = "sha256:db198e2887ab0c9aa0023e565afbff41dfb76b361f85fd5e13f780d75ba18cc8"}, ] +rdkit = [ + {file = "rdkit-2022.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0ef8f08dc0dad7fa6b87616b412ed7a044e98469714ba269e3a4cb46e9d903a4"}, + {file = "rdkit-2022.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bc2a73cb07197870cfd9e6aac8b58468375f6c6458aced3a12a232d1cd52c81e"}, + {file = "rdkit-2022.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63f3d7c86fc6263a0aae189d37500f66ccb9f1a5814e4eb29bcd9c76204d4de5"}, + {file = "rdkit-2022.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fe5c858d844d31824d8974385678446d88ef3c5d62dd442af04344ea52853be"}, + {file = "rdkit-2022.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:572bef53768616fb35e62e9bcdd19c1692676c0a5f736845cce9b337c7b91a71"}, + {file = "rdkit-2022.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a98c178029a2ead970ae61bb657f8a38592af68bc74fb5bbfbd0752d40311ea5"}, + {file = "rdkit-2022.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7ccaa50cab089bd27895f1347d9dddb4e5c04548cdbd0bd4b85f01d38281bff6"}, + {file = "rdkit-2022.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d1a4a95b543ec627b7d312bb5cf806e4cbce6c07d00ca0ab180f6b91c858111"}, + {file = "rdkit-2022.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:b55b2ec43b663b360e5787db1301336f6c48cdd361d72c7019cfa3989a8b6638"}, + {file = "rdkit-2022.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:75ebacd4a8f2abf634cf3727cc9da20c194bc8466a3c7f7a15a6c2b90d222850"}, + {file = "rdkit-2022.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:15641dbe80c82d95b04b7f80015bc3b08c634e8036bf3fe17f6fe84bbdbd3e3d"}, + {file = "rdkit-2022.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bc0872757e8e841ee6da41d781b4036de4c8b1f731fdbfc754f3ceb99469ca9"}, + {file = "rdkit-2022.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebfdcb0459953909c0c9ab37fa99546e60c6813b1ad0056ded3b9eb0761d5def"}, + {file = "rdkit-2022.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:c563516332c97927652d2aad866afb51b5d3b35bc0a18ae715666f7c757f41b4"}, + {file = "rdkit-2022.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:07edd250e909205d618da40f100ca9a500bb986bbf8db155159dd365178f1756"}, + {file = "rdkit-2022.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8f3e0a81dad936c442cb6cebe9690943176848f1250187aa3676f3392303c1d6"}, + {file = "rdkit-2022.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5d9a7259d5eb9d0e78f2d7188969548cb7c8da61161eb39db9a8624e1ac1157"}, + {file = "rdkit-2022.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8331812ff56f84f24689ab1ba8307859a3761b629d67c8ca6a56f3034999801"}, + {file = "rdkit-2022.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:0cf329bafca28ddd56699e91116abfa0cf5baa56a47c53070a7ef1cb51c11f6a"}, + {file = "rdkit-2022.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d6678b5b7ffa7a0ad1c57791badddb89778a52cf5abc537e7e1446795ac2830a"}, + {file = "rdkit-2022.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7376b7c936abaea89cfc02e33464d173c9b8065f0f93a07a5c625af09d53f85f"}, + {file = "rdkit-2022.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:d176b7d7fbe68c0c55658001f4ac9bb4095f4e56ff83326cd0d692778196a99d"}, + {file = "rdkit-2022.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ea471f5469b32f27854a8ee604a1983b0d866cad511226004986dac6c4d9a13"}, + {file = "rdkit-2022.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:4934e62869c7ba0b65618df3cbad2b35291357bad1d0e005b558cad528f37f86"}, +] requests = [ {file = "requests-2.28.1-py3-none-any.whl", hash = "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349"}, {file = "requests-2.28.1.tar.gz", hash = "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983"}, diff --git a/pyproject.toml b/pyproject.toml index 0d1bb6b..4db4d92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ PyMuPDF = "^1.19.6" pdf2img = {git = "ssh://git@git.iqser.com:2222/rr/pdf2image.git", branch = "master"} pyinfra = {git = "ssh://git@git.iqser.com:2222/rr/pyinfra.git", branch = "master"} loguru = "^0.6.0" +rdkit = "^2022.9.4" [tool.poetry.group.build.dependencies] pytest = "^7.0.1" diff --git a/synthesis/formula.py b/synthesis/formula.py index fdaa144..6aa7cdb 100644 --- a/synthesis/formula.py +++ b/synthesis/formula.py @@ -1,16 +1,129 @@ -import argparse +# Draw molecular structures from smiles. 
Adapted from https://github.com/neeraj-j/molecules +from itertools import islice +from typing import List, Iterable -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument() - args = parser.parse_args() - - return args - - -def main(args): - pass - +import numpy as np +import pandas as pd +from PIL.Image import Image +from funcy import first, retry, keep +from rdkit import Chem +from rdkit.Chem import AllChem +from rdkit.Chem import Draw +from rdkit.Chem import FunctionalGroups -if __name__ == "__main__": - main(parse_args()) \ No newline at end of file +from cv_analysis.locations import TEST_SMILES_FILE +from cv_analysis.logging import debug_log, logger + + +class StructuralFormulaImageGenerator: + def __init__(self, width=None, height=None): + self.width = width + self.height = height + + self.templates = collect_templates() + self.functional_groups = self.templates.keys() + + @debug_log() + def generate_images_from_smiles(self, smiles: List[str], max_images_per_functional_group=1) -> Iterable[Image]: + yield from self.generate_images_for_functional_groups( + smiles, + max_images_per_functional_group=max_images_per_functional_group, + ) + + @debug_log() + def generate_images_for_functional_groups(self, smiles: List[str], max_images_per_functional_group): + for functional_group in self.functional_groups: + smiles = iter(smiles) + g = self.generate_images_for_functional_group(smiles, functional_group) + yield from islice(keep(g), max_images_per_functional_group) + + def generate_images_for_functional_group(self, smiles: Iterable[str], functional_group: str): + try: + yield from self.__generate_images_for_functional_group(smiles, functional_group) + except ValueError: + pass + + @debug_log() + @retry(10, errors=ValueError) + def __generate_images_for_functional_group(self, smiles: Iterable[str], functional_group: str): + + AllChem.Compute2DCoords(self.templates[functional_group]) + + for smile in smiles: + try: + image = self.make_image(smile, 
functional_group) + yield image + except ValueError: # SMILES string does not match functional group + raise + + @debug_log() + def make_image(self, smile: str, functional_group: str): + mol = Chem.MolFromSmiles(smile) + AllChem.GenerateDepictionMatching2DStructure(mol, self.templates[functional_group]) + + side_length = np.random.randint(70, 400) + width = self.width or side_length + height = self.height or side_length + + image: Image = Draw.MolToImage( + mol, + size=(width, height), + kekulize=flip_a_coin(), + wedgeBonds=flip_a_coin(), + ) + image.putalpha(255) + return image + + +@debug_log() +def flip_a_coin(): + return bool(np.random.randint(0, 2)) + + +@debug_log() +def collect_templates(): + functional_groups = FunctionalGroups.BuildFuncGroupHierarchy() + group_name_2_pattern = dict(stream_label_pattern_pairs(functional_groups)) + return group_name_2_pattern + + +@debug_log() +def stream_label_pattern_pairs(functional_groups): + for functional_group in functional_groups: + yield functional_group.label, functional_group.pattern + yield from stream_label_pattern_pairs(functional_group.children) + + +@debug_log() +def generate_image_of_structural_formula(smiles_file=None, size=None): + """Generate an image of a formula from SMILES-encoded formulas. + + Args: + smiles_file: CSV file with column "smiles". Each row contains a SMILES-encoded formula. + size: width, height + + Returns: + PIL.Image.Image: Image of a formula. + """ + logger.info(f"Generating structural formula images from {smiles_file}") + return first(generate_images_of_structural_formulas(smiles_file, size=size)) + + +@debug_log() +def generate_images_of_structural_formulas(smiles_file=None, size=None): + """Generate images of formulas from SMILES-encoded formulas. + + Args: + smiles_file: CSV file with column "smiles". Each row contains a SMILES-encoded formula. + size: width, height + + Yields: + PIL.Image.Image: Image of a formula.
+ """ + size = size or (None, None) + smiles_file = smiles_file or TEST_SMILES_FILE + smiles = pd.read_csv(smiles_file).sample(frac=1).drop_duplicates().smiles + yield from StructuralFormulaImageGenerator(*size).generate_images_from_smiles(smiles) + + +# generate_image_of_structural_formula().show() diff --git a/synthesis/segment/plot.py b/synthesis/segment/plot.py index f21c7f6..82698e3 100644 --- a/synthesis/segment/plot.py +++ b/synthesis/segment/plot.py @@ -11,6 +11,7 @@ from matplotlib.colors import ListedColormap from cv_analysis.utils.geometric import is_square_like, is_wide, is_tall from cv_analysis.utils.image_operations import superimpose from cv_analysis.utils.rectangle import Rectangle +from synthesis.formula import generate_image_of_structural_formula from synthesis.randomization import rnd, probably, maybe from synthesis.segment.random_content_rectangle import RandomContentRectangle from synthesis.text.text import generate_random_words @@ -39,6 +40,7 @@ class RandomPlot(RandomContentRectangle): self.generate_random_histogram, self.generate_random_pie_chart, self.generate_random_heat_map, + self.generate_random_structural_formula ] ) elif is_wide(rectangle): @@ -141,6 +143,10 @@ class RandomPlot(RandomContentRectangle): plot_kwargs=self.generate_plot_kwargs(keywords=["a"]), ) + def generate_random_structural_formula(self, rectangle: Rectangle): + image = generate_image_of_structural_formula(size=rectangle.size) + self.content = image if not self.content else superimpose(self.content, image) + def generate_plot_kwargs(self, keywords=None): kwargs = { diff --git a/test/conftest.py b/test/conftest.py index a068821..d0bc1e3 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -6,6 +6,7 @@ pytest_plugins = [ "test.fixtures.table_parsing", "test.fixtures.figure_detection", "test.fixtures.data", + "test.fixtures.formula", "test.fixtures.page_generation.page", ] diff --git a/test/page_generation_test.py b/test/page_generation_test.py index 9d56011..6d8024d 
100644 --- a/test/page_generation_test.py +++ b/test/page_generation_test.py @@ -15,5 +15,5 @@ def test_blank_page(page_with_content): def draw_boxes(page: Image, boxes: Iterable[Rectangle]): from cv_analysis.utils.drawing import draw_rectangles - page = draw_rectangles(page, boxes, filled=False, annotate=True) + # page = draw_rectangles(page, boxes, filled=False, annotate=True) show_image(page, backend="pil")