2023-02-15 20:16:16 +01:00

135 lines
4.5 KiB
Python

# Draw molecular structures from SMILES strings. Adapted from https://github.com/neeraj-j/molecules
from itertools import islice
from typing import List, Iterable
import numpy as np
import pandas as pd
from PIL.Image import Image
from funcy import first, retry, keep
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem import FunctionalGroups
from cv_analysis.locations import TEST_SMILES_FILE
from cv_analysis.logging import debug_log, logger
class StructuralFormulaImageGenerator:
    """Render molecule images whose depiction is aligned to functional-group templates.

    The templates are the label -> pattern mapping returned by ``collect_templates()``.
    For every functional group, SMILES strings are drawn from the input and rendered
    with RDKit until either enough images were produced or a SMILES string fails
    (unparsable, or not matching the group's template).

    Args:
        width: Fixed image width in pixels; if falsy, a random side length is used.
        height: Fixed image height in pixels; if falsy, a random side length is used.
    """

    def __init__(self, width=None, height=None):
        self.width = width
        self.height = height
        self.templates = collect_templates()
        self.functional_groups = self.templates.keys()

    @debug_log()
    def generate_images_from_smiles(self, smiles: List[str], max_images_per_functional_group=1) -> Iterable[Image]:
        """Yield images for the given SMILES strings, at most N per functional group.

        Args:
            smiles: SMILES encoded formulas.
            max_images_per_functional_group: Upper bound of images yielded per group.

        Yields:
            PIL.Image.Image: Rendered structural formula.
        """
        yield from self.generate_images_for_functional_groups(
            smiles,
            max_images_per_functional_group=max_images_per_functional_group,
        )

    @debug_log()
    def generate_images_for_functional_groups(self, smiles: List[str], max_images_per_functional_group):
        """Yield up to ``max_images_per_functional_group`` images for each functional group."""
        # A single iterator is shared by all functional groups, so a SMILES string that
        # was consumed for one group is not offered to the following groups.
        remaining_smiles = iter(smiles)
        for functional_group in self.functional_groups:
            images = self.generate_images_for_functional_group(remaining_smiles, functional_group)
            # `keep` drops falsy items; `islice` caps the number of images per group.
            yield from islice(keep(images), max_images_per_functional_group)

    def generate_images_for_functional_group(self, smiles: Iterable[str], functional_group: str):
        """Yield images for one functional group, ending quietly on the first ``ValueError``.

        A ``ValueError`` signals a SMILES string that could not be parsed or does not
        match the functional group; it ends the stream for this group instead of
        aborting the whole generation.
        """
        try:
            yield from self.__generate_images_for_functional_group(smiles, functional_group)
        except ValueError:
            pass

    # NOTE(review): this is a generator function, so calling it does not run its body.
    # `retry` therefore most likely never observes the ValueError raised while iterating
    # -- confirm whether per-SMILES retrying was intended.
    @debug_log()
    @retry(100, errors=ValueError)
    def __generate_images_for_functional_group(self, smiles: Iterable[str], functional_group: str):
        """Yield one image per SMILES string; a ``ValueError`` from ``make_image`` propagates."""
        AllChem.Compute2DCoords(self.templates[functional_group])
        for smile in smiles:
            # Raises ValueError if the SMILES string is invalid or does not match the group.
            yield self.make_image(smile, functional_group)

    @debug_log()
    def make_image(self, smile: str, functional_group: str):
        """Render a single SMILES string aligned to the template of ``functional_group``.

        Args:
            smile: SMILES encoded formula.
            functional_group: Key into ``self.templates``.

        Returns:
            PIL.Image.Image: The rendered molecule with a fully opaque alpha channel.

        Raises:
            ValueError: If the SMILES string cannot be parsed, or if the depiction
                cannot be matched to the template.
        """
        mol = Chem.MolFromSmiles(smile)
        if mol is None:
            # RDKit returns None for unparsable SMILES. Raising ValueError routes this
            # case through the same handling as a template mismatch instead of letting
            # an unrelated error type escape from the depiction call below.
            raise ValueError(f"Could not parse SMILES string: {smile!r}")

        AllChem.GenerateDepictionMatching2DStructure(mol, self.templates[functional_group])

        # Random square size (70 <= side < 400) unless explicit dimensions were configured.
        side_length = np.random.randint(70, 400)
        width = self.width or side_length
        height = self.height or side_length

        image: Image = Draw.MolToImage(
            mol,
            size=(width, height),
            kekulize=flip_a_coin(),
            wedgeBonds=flip_a_coin(),
        )
        image.putalpha(255)
        return image
@debug_log()
def flip_a_coin():
    """Return a random boolean derived from ``np.random.randint(0, 2)``."""
    # randint's upper bound is exclusive, so the draw is either 0 or 1.
    draw = np.random.randint(0, 2)
    return bool(draw)
@debug_log()
def collect_templates():
    """Build a dict mapping functional-group labels to their patterns.

    The groups come from ``FunctionalGroups.BuildFuncGroupHierarchy()`` and are
    flattened (including children) by ``stream_label_pattern_pairs``.

    Returns:
        dict: label -> pattern; for repeated labels the last pair wins.
    """
    hierarchy = FunctionalGroups.BuildFuncGroupHierarchy()
    label_pattern_pairs = stream_label_pattern_pairs(hierarchy)
    return dict(label_pattern_pairs)
@debug_log()
def stream_label_pattern_pairs(functional_groups):
    """Walk a functional-group hierarchy depth-first and yield ``(label, pattern)`` pairs.

    Args:
        functional_groups: Iterable of nodes exposing ``label``, ``pattern`` and ``children``.

    Yields:
        tuple: ``(label, pattern)`` of each node, parent before its children.
    """
    for node in functional_groups:
        pair = (node.label, node.pattern)
        yield pair
        # Descend into the sub-groups before moving on to the next sibling.
        yield from stream_label_pattern_pairs(node.children)
@debug_log()
def generate_image_of_structural_formula(smiles_file=None, size=None):
    """Generate a single image of a formula from SMILES encoded formulas.

    Args:
        smiles_file: CSV file with column "smiles". Each row contains a SMILES encoded formula.
        size: width, height

    Returns:
        PIL.Image.Image: Image of a formula.

    Raises:
        ValueError: If no image could be generated.
    """
    logger.info(f"Generating structural formula images from {smiles_file}")

    candidates = generate_images_of_structural_formulas(smiles_file, size=size)
    first_image = first(candidates)  # None when the stream yields nothing

    if not first_image:
        # NOTE(review): `filter` is passed as a keyword argument to the log call itself;
        # whether the logger uses it to filter messages is not visible here -- confirm.
        logger.warning(
            f"No structural formula images generated from {smiles_file}",
            filter=lambda m: not m.startswith("Depict error"),
        )
        raise ValueError(f"Could not generate structural formula image from {smiles_file}")

    return first_image
@debug_log()
def generate_images_of_structural_formulas(smiles_file=None, size=None):
    """Generate images of formulas from SMILES encoded formulas.

    Args:
        smiles_file: CSV file with column "smiles". Each row contains a SMILES encoded formula.
            Falls back to ``TEST_SMILES_FILE`` when falsy.
        size: width, height. Falls back to ``(None, None)`` when falsy.

    Yields:
        PIL.Image.Image: Image of a formula.
    """
    width_height = size or (None, None)
    csv_path = smiles_file or TEST_SMILES_FILE

    # Shuffle all rows, drop duplicate rows, then take the "smiles" column.
    table = pd.read_csv(csv_path)
    shuffled_unique_rows = table.sample(frac=1).drop_duplicates()
    smiles = shuffled_unique_rows.smiles

    image_generator = StructuralFormulaImageGenerator(*width_height)
    yield from image_generator.generate_images_from_smiles(smiles)