from typing import Any, Dict, List
import torch
import torch_geometric
# Categorical vocabularies for the atom (node) features. ``from_smiles``
# stores, for every atom, the position of each RDKit value inside the
# matching list below -- one column per key, in this key order.
x_map: Dict[str, List[Any]] = dict(
    # Atomic numbers 0-118; because the list starts at 0, the stored index
    # equals the atomic number itself.
    atomic_num=list(range(119)),
    # ``str(Atom.GetChiralTag())``. ``to_smiles`` decodes this column with
    # ``ChiralType.values[index]``, i.e. the list position is used as the
    # RDKit enum value, so the order here matters.
    chirality=[
        'CHI_UNSPECIFIED',
        'CHI_TETRAHEDRAL_CW',
        'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER',
        'CHI_TETRAHEDRAL',
        'CHI_ALLENE',
        'CHI_SQUAREPLANAR',
        'CHI_TRIGONALBIPYRAMIDAL',
        'CHI_OCTAHEDRAL',
    ],
    # ``Atom.GetTotalDegree()``: 0 through 10.
    degree=list(range(11)),
    # ``Atom.GetFormalCharge()``: -5 through +6.
    formal_charge=list(range(-5, 7)),
    # ``Atom.GetTotalNumHs()``: 0 through 8.
    num_hs=list(range(9)),
    # ``Atom.GetNumRadicalElectrons()``: 0 through 4.
    num_radical_electrons=list(range(5)),
    # ``str(Atom.GetHybridization())``. Decoded in ``to_smiles`` through
    # ``HybridizationType.values[index]``, so the order matters here too.
    hybridization=[
        'UNSPECIFIED',
        'S',
        'SP',
        'SP2',
        'SP3',
        'SP3D',
        'SP3D2',
        'OTHER',
    ],
    # ``Atom.GetIsAromatic()`` / ``Atom.IsInRing()``: False -> 0, True -> 1.
    is_aromatic=[False, True],
    is_in_ring=[False, True],
)
# Categorical vocabularies for the bond (edge) features. ``from_smiles``
# stores, for every bond, the position of each RDKit value inside the
# matching list below -- one column per key, in this key order.
e_map: Dict[str, List[Any]] = dict(
    # ``str(Bond.GetBondType())``. ``to_smiles`` decodes this column with
    # ``BondType.values[index]``, i.e. the list position is used as the
    # RDKit enum value, so the order here matters.
    bond_type=[
        'UNSPECIFIED',
        'SINGLE',
        'DOUBLE',
        'TRIPLE',
        'QUADRUPLE',
        'QUINTUPLE',
        'HEXTUPLE',
        'ONEANDAHALF',
        'TWOANDAHALF',
        'THREEANDAHALF',
        'FOURANDAHALF',
        'FIVEANDAHALF',
        'AROMATIC',
        'IONIC',
        'HYDROGEN',
        'THREECENTER',
        'DATIVEONE',
        'DATIVE',
        'DATIVEL',
        'DATIVER',
        'OTHER',
        'ZERO',
    ],
    # ``str(Bond.GetStereo())``; decoded in ``to_smiles`` through
    # ``BondStereo.values[index]``, so the order matters here too.
    stereo=[
        'STEREONONE',
        'STEREOANY',
        'STEREOZ',
        'STEREOE',
        'STEREOCIS',
        'STEREOTRANS',
    ],
    # ``Bond.GetIsConjugated()``: False -> 0, True -> 1.
    is_conjugated=[False, True],
)
def _atom_feature_row(atom: Any) -> List[int]:
    """Encodes one RDKit atom as its per-key indices into ``x_map``."""
    observed = (
        ('atomic_num', atom.GetAtomicNum()),
        ('chirality', str(atom.GetChiralTag())),
        ('degree', atom.GetTotalDegree()),
        ('formal_charge', atom.GetFormalCharge()),
        ('num_hs', atom.GetTotalNumHs()),
        ('num_radical_electrons', atom.GetNumRadicalElectrons()),
        ('hybridization', str(atom.GetHybridization())),
        ('is_aromatic', atom.GetIsAromatic()),
        ('is_in_ring', atom.IsInRing()),
    )
    # ``list.index`` raises ``ValueError`` for a value outside the vocabulary.
    return [x_map[key].index(value) for key, value in observed]


def _bond_feature_row(bond: Any) -> List[int]:
    """Encodes one RDKit bond as its per-key indices into ``e_map``."""
    observed = (
        ('bond_type', str(bond.GetBondType())),
        ('stereo', str(bond.GetStereo())),
        ('is_conjugated', bond.GetIsConjugated()),
    )
    return [e_map[key].index(value) for key, value in observed]


def from_smiles(smiles: str, with_hydrogen: bool = False,
                kekulize: bool = False) -> 'torch_geometric.data.Data':
    r"""Converts a SMILES string to a :class:`torch_geometric.data.Data`
    instance.

    Node features :obj:`x` (shape :obj:`[num_atoms, 9]`) and edge features
    :obj:`edge_attr` (shape :obj:`[num_edges, 3]`) are :obj:`torch.long`
    tensors whose columns hold the position of the corresponding RDKit
    value inside :obj:`x_map` / :obj:`e_map`. Every bond is stored as two
    directed edges, and edges are ordered by (source, target). The input
    string is attached as :obj:`data.smiles`.

    Args:
        smiles (str): The SMILES string. If RDKit cannot parse it, an empty
            molecule is used instead.
        with_hydrogen (bool, optional): If set to :obj:`True`, will store
            hydrogens in the molecule graph. (default: :obj:`False`)
        kekulize (bool, optional): If set to :obj:`True`, converts aromatic
            bonds to single/double bonds. (default: :obj:`False`)
    """
    from rdkit import Chem, RDLogger

    from torch_geometric.data import Data

    # Mute RDKit logging (e.g. parse errors for invalid SMILES). RDKit's
    # logger is global, and this call is never undone here.
    RDLogger.DisableLog('rdApp.*')

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Unparsable input degrades to an empty molecule instead of raising.
        mol = Chem.MolFromSmiles('')
    if with_hydrogen:
        mol = Chem.AddHs(mol)
    if kekulize:
        Chem.Kekulize(mol)

    node_rows = [_atom_feature_row(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(node_rows, dtype=torch.long).view(-1, 9)

    pairs: List[List[int]] = []
    attrs: List[List[int]] = []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        feats = _bond_feature_row(bond)
        # Each undirected bond becomes two directed edges with equal features.
        pairs.extend([[begin, end], [end, begin]])
        attrs.extend([feats, feats])

    # ``view`` also gives well-formed (2, 0) / (0, 3) shapes for bond-free input.
    edge_index = torch.tensor(pairs, dtype=torch.long).t().view(2, -1)
    edge_attr = torch.tensor(attrs, dtype=torch.long).view(-1, 3)

    if pairs:
        # Sort by (source, target): since target < num_atoms,
        # ``source * num_atoms + target`` is a lexicographic sort key.
        num_atoms = x.size(0)
        perm = (edge_index[0] * num_atoms + edge_index[1]).argsort()
        edge_index = edge_index[:, perm]
        edge_attr = edge_attr[perm]

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)
def to_smiles(data: 'torch_geometric.data.Data',
              kekulize: bool = False) -> Any:
    """Converts a :class:`torch_geometric.data.Data` instance to a SMILES
    string.

    The graph is expected to use the categorical encoding produced by
    :meth:`from_smiles`: the columns of :obj:`data.x` and
    :obj:`data.edge_attr` hold positions inside the lists of :obj:`x_map`
    and :obj:`e_map`. A bond may appear in both directions in
    :obj:`data.edge_index`; only its first occurrence is used.

    Args:
        data (torch_geometric.data.Data): The molecular graph.
        kekulize (bool, optional): If set to :obj:`True`, writes the SMILES
            in Kekulé form, i.e. aromatic bonds are converted to
            single/double bonds. (default: :obj:`False`)

    Returns:
        str: The isomeric SMILES string written by RDKit.
    """
    from rdkit import Chem

    mol = Chem.RWMol()

    assert data.x is not None
    assert data.num_nodes is not None
    assert data.edge_index is not None
    assert data.edge_attr is not None

    # Convert the feature tensors to nested Python lists once instead of
    # indexing a tensor element-by-element for every single feature.
    node_feats = data.x.tolist()
    edge_feats = data.edge_attr.tolist()

    for i in range(data.num_nodes):
        feat = node_feats[i]
        # Column 0 is used directly as the atomic number; this relies on
        # ``x_map['atomic_num']`` starting at 0 (index == atomic number).
        atom = Chem.Atom(int(feat[0]))
        # Enum-valued columns: the stored list position is used directly as
        # the RDKit enum value.
        atom.SetChiralTag(Chem.rdchem.ChiralType.values[int(feat[1])])
        atom.SetFormalCharge(x_map['formal_charge'][int(feat[3])])
        atom.SetNumExplicitHs(x_map['num_hs'][int(feat[4])])
        atom.SetNumRadicalElectrons(
            x_map['num_radical_electrons'][int(feat[5])])
        atom.SetHybridization(
            Chem.rdchem.HybridizationType.values[int(feat[6])])
        atom.SetIsAromatic(bool(int(feat[7])))
        mol.AddAtom(atom)

    visited = set()
    for i, (src, dst) in enumerate(data.edge_index.t().tolist()):
        key = tuple(sorted((src, dst)))
        if key in visited:  # Reverse direction of an already-added bond.
            continue
        attr = edge_feats[i]

        mol.AddBond(src, dst, Chem.BondType.values[int(attr[0])])
        bond = mol.GetBondBetweenAtoms(src, dst)

        # Set stereochemistry:
        stereo = Chem.rdchem.BondStereo.values[int(attr[1])]
        if stereo != Chem.rdchem.BondStereo.STEREONONE:
            # NOTE(review): the stereo reference atoms are the bond's own end
            # points rather than neighbours of them; kept as in the original
            # implementation -- confirm against RDKit's stereo semantics.
            bond.SetStereoAtoms(dst, src)
            bond.SetStereo(stereo)

        # Set conjugation:
        bond.SetIsConjugated(bool(attr[2]))
        visited.add(key)

    mol = mol.GetMol()
    Chem.SanitizeMol(mol)
    Chem.AssignStereochemistry(mol)

    if kekulize:
        # Must run *after* sanitization: SanitizeMol re-perceives aromaticity
        # and would undo an earlier kekulization. Clearing the aromatic flags
        # makes the SMILES writer emit explicit single/double bonds.
        Chem.Kekulize(mol, clearAromaticFlags=True)

    return Chem.MolToSmiles(mol, isomericSmiles=True, kekuleSmiles=kekulize)