130 lines
4.3 KiB
Python
130 lines
4.3 KiB
Python
# Draw molecular structures from SMILES strings. Adapted from https://github.com/neeraj-j/molecules
|
|
from itertools import islice
|
|
from typing import List, Iterable
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from PIL.Image import Image
|
|
from funcy import first, retry, keep
|
|
from rdkit import Chem
|
|
from rdkit.Chem import AllChem
|
|
from rdkit.Chem import Draw
|
|
from rdkit.Chem import FunctionalGroups
|
|
|
|
from cv_analysis.locations import TEST_SMILES_FILE
|
|
from cv_analysis.logging import debug_log, logger
|
|
|
|
|
|
class StructuralFormulaImageGenerator:
    """Generate images of structural formulas from SMILES strings.

    Molecules are depicted per functional group: for each template returned by
    ``collect_templates`` the generator tries to draw a molecule whose 2D layout
    is matched against that template.
    """

    def __init__(self, width=None, height=None):
        """Store the requested image size and collect the functional group templates.

        Args:
            width: Image width. When falsy, a random side length is drawn per image.
            height: Image height. When falsy, a random side length is drawn per image.
        """
        self.width = width
        self.height = height

        # Mapping of functional group label -> pattern, as built by collect_templates().
        self.templates = collect_templates()
        self.functional_groups = self.templates.keys()

    @debug_log()
    def generate_images_from_smiles(self, smiles: List[str], max_images_per_functional_group=1) -> Iterable[Image]:
        """Yield images for the given SMILES strings, grouped by functional group.

        Args:
            smiles: SMILES encoded formulas.
            max_images_per_functional_group: Upper bound of images yielded per functional group.

        Yields:
            PIL.Image.Image: Image of a formula.
        """
        yield from self.generate_images_for_functional_groups(
            smiles,
            max_images_per_functional_group=max_images_per_functional_group,
        )

    @debug_log()
    def generate_images_for_functional_groups(self, smiles: List[str], max_images_per_functional_group):
        """Yield at most `max_images_per_functional_group` images for every known functional group.

        Args:
            smiles: SMILES encoded formulas.
            max_images_per_functional_group: Upper bound of images yielded per functional group.

        Yields:
            Images produced by `generate_images_for_functional_group`, with falsy items filtered out.
        """
        # A single iterator is shared by all functional groups, so every SMILES string is
        # consumed at most once: each group continues where the previous group stopped.
        smiles_stream = iter(smiles)

        for functional_group in self.functional_groups:
            candidates = self.generate_images_for_functional_group(smiles_stream, functional_group)
            yield from islice(keep(candidates), max_images_per_functional_group)

    def generate_images_for_functional_group(self, smiles: Iterable[str], functional_group: str):
        """Yield images for one functional group, ending quietly on a ValueError.

        Args:
            smiles: Iterable of SMILES encoded formulas.
            functional_group: Key into `self.templates`.

        Yields:
            Images of the molecules drawn so far; the stream simply ends when a ValueError occurs.
        """
        try:
            yield from self.__generate_images_for_functional_group(smiles, functional_group)
        except ValueError:
            # Deliberate best effort: a molecule that cannot be drawn for this group ends the
            # group's stream instead of aborting the whole generation.
            pass

    # NOTE(review): this is a generator function, so calling it only creates the generator and
    # the body runs lazily during iteration. The retry decorator wraps that call, which means
    # ValueErrors raised while iterating are presumably never retried - confirm the intent.
    @debug_log()
    @retry(10, errors=ValueError)
    def __generate_images_for_functional_group(self, smiles: Iterable[str], functional_group: str):
        """Yield one image per SMILES string, depicted against the functional group template.

        Args:
            smiles: Iterable of SMILES encoded formulas.
            functional_group: Key into `self.templates`.

        Raises:
            ValueError: Propagated from `make_image` (SMILES does not match the functional group
                or cannot be parsed).
        """
        AllChem.Compute2DCoords(self.templates[functional_group])

        for smile in smiles:
            yield self.make_image(smile, functional_group)

    @debug_log()
    def make_image(self, smile: str, functional_group: str):
        """Draw the molecule described by `smile`, laid out against the functional group template.

        Args:
            smile: SMILES encoded formula.
            functional_group: Key into `self.templates`.

        Returns:
            PIL.Image.Image: Image of the molecule with an alpha channel set to fully opaque.

        Raises:
            ValueError: If `smile` cannot be parsed, or if the depiction against the template
                fails (treated in this class as "SMILES does not match functional group").
        """
        mol = Chem.MolFromSmiles(smile)
        if mol is None:
            # RDKit returns None for unparseable SMILES; raise the error type the callers
            # already handle so that the input is skipped instead of crashing further down.
            raise ValueError(f"Could not parse SMILES string: {smile!r}")
        AllChem.GenerateDepictionMatching2DStructure(mol, self.templates[functional_group])

        # Square fallback size; only used for the dimensions that were not fixed in __init__.
        side_length = np.random.randint(70, 400)
        width = self.width or side_length
        height = self.height or side_length

        image: Image = Draw.MolToImage(
            mol,
            size=(width, height),
            kekulize=flip_a_coin(),
            wedgeBonds=flip_a_coin(),
        )
        image.putalpha(255)
        return image
|
|
|
|
|
|
@debug_log()
def flip_a_coin():
    """Return a randomly chosen boolean."""
    outcome = np.random.randint(0, 2)  # upper bound is exclusive, so this is 0 or 1
    return bool(outcome)
|
|
|
|
|
|
@debug_log()
def collect_templates():
    """Build a mapping from functional group label to its pattern.

    Returns:
        dict: Label -> pattern for every node of the RDKit functional group hierarchy,
            as streamed by `stream_label_pattern_pairs`.
    """
    hierarchy = FunctionalGroups.BuildFuncGroupHierarchy()
    return {label: pattern for label, pattern in stream_label_pattern_pairs(hierarchy)}
|
|
|
|
|
|
@debug_log()
def stream_label_pattern_pairs(functional_groups):
    """Walk a functional group hierarchy depth-first and stream its (label, pattern) pairs.

    Args:
        functional_groups: Iterable of nodes exposing `label`, `pattern` and `children`.

    Yields:
        tuple: (label, pattern) of each node, followed by the pairs of its descendants.
    """
    for node in functional_groups:
        pair = (node.label, node.pattern)
        yield pair
        # Descend into the children before moving on to the next sibling.
        yield from stream_label_pattern_pairs(node.children)
|
|
|
|
|
|
@debug_log()
def generate_image_of_structural_formula(smiles_file=None, size=None):
    """Generate a single image of a formula from SMILES encoded formulas.

    Args:
        smiles_file: CSV file with column "smiles". Each row contains a SMILES encoded formula.
            Falls back to TEST_SMILES_FILE when not given.
        size: width, height

    Returns:
        PIL.Image.Image: Image of a formula, i.e. the first image produced by
            `generate_images_of_structural_formulas` (funcy's `first` gives None if there is none).
    """
    # Resolve the default up front so the log line names the file that is actually read
    # instead of printing "None".
    smiles_file = smiles_file or TEST_SMILES_FILE
    logger.info(f"Generating structural formula images from {smiles_file}")
    return first(generate_images_of_structural_formulas(smiles_file, size=size))
|
|
|
|
|
|
@debug_log()
def generate_images_of_structural_formulas(smiles_file=None, size=None):
    """Generate images of formulas from SMILES encoded formulas.

    Args:
        smiles_file: CSV file with column "smiles". Each row contains a SMILES encoded formula.
            Falls back to TEST_SMILES_FILE when not given.
        size: width, height. Falls back to (None, None) when not given.

    Yields:
        PIL.Image.Image: Image of a formula.
    """
    dimensions = size if size else (None, None)
    source = smiles_file if smiles_file else TEST_SMILES_FILE

    # Shuffle all rows, then drop duplicated rows before picking the "smiles" column.
    shuffled = pd.read_csv(source).sample(frac=1)
    smiles = shuffled.drop_duplicates().smiles

    image_generator = StructuralFormulaImageGenerator(*dimensions)
    yield from image_generator.generate_images_from_smiles(smiles)
|
|
|
|
|
|
# generate_image_of_structural_formula().show()
|