import torch
# Vocabularies for the atom features built by `from_smiles`: each atom
# property is stored as its position in the matching list, so the list order
# defines the integer encoding. A property value missing from its list makes
# `list.index` raise `ValueError` in `from_smiles`.
x_map = dict(
    # Atomic numbers 0..118.
    atomic_num=list(range(119)),
    # Matched against `str(atom.GetChiralTag())`.
    chirality=[
        'CHI_UNSPECIFIED',
        'CHI_TETRAHEDRAL_CW',
        'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER',
    ],
    # Total degree 0..10 (`atom.GetTotalDegree()`).
    degree=list(range(11)),
    # Formal charge -5..+6.
    formal_charge=list(range(-5, 7)),
    # Total number of attached hydrogens 0..8 (`atom.GetTotalNumHs()`).
    num_hs=list(range(9)),
    # Radical electrons 0..4.
    num_radical_electrons=list(range(5)),
    # Matched against `str(atom.GetHybridization())`.
    hybridization=[
        'UNSPECIFIED', 'S', 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'OTHER',
    ],
    is_aromatic=[False, True],
    is_in_ring=[False, True],
)
# Vocabularies for the bond features built by `from_smiles`: each bond
# property is stored as its position in the matching list, so the list order
# defines the integer encoding.
e_map = dict(
    # Matched against `str(bond.GetBondType())`.
    # NOTE(review): index 0 ('misc') looks like a placeholder; RDKit's name
    # for an unspecified bond type appears to be 'UNSPECIFIED', so 'misc' may
    # never match -- TODO confirm before relying on index 0.
    bond_type=['misc', 'SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC'],
    # Matched against `str(bond.GetStereo())`.
    stereo=[
        'STEREONONE',
        'STEREOZ',
        'STEREOE',
        'STEREOCIS',
        'STEREOTRANS',
        'STEREOANY',
    ],
    is_conjugated=[False, True],
)
def from_smiles(smiles: str, with_hydrogen: bool = False,
                kekulize: bool = False):
    r"""Converts a SMILES string to a :class:`torch_geometric.data.Data`
    instance.

    Every atom is encoded as 9 integer features and every bond as 3 integer
    features. Each feature is the position of the corresponding RDKit
    property value in the matching list of :obj:`x_map` / :obj:`e_map`.

    Args:
        smiles (str): The SMILES string.
        with_hydrogen (bool, optional): If set to :obj:`True`, will store
            hydrogens in the molecule graph. (default: :obj:`False`)
        kekulize (bool, optional): If set to :obj:`True`, converts aromatic
            bonds to single/double bonds. (default: :obj:`False`)

    Returns:
        torch_geometric.data.Data: A graph with :obj:`x` of shape
        :obj:`[num_atoms, 9]`, :obj:`edge_index` of shape
        :obj:`[2, 2 * num_bonds]` (each bond stored in both directions,
        sorted by source and then target), :obj:`edge_attr` of shape
        :obj:`[2 * num_bonds, 3]`, and the input :obj:`smiles` string.

    Raises:
        ValueError: If an atom or bond property value is not listed in
            :obj:`x_map` / :obj:`e_map`.

    .. note::
        If :obj:`smiles` cannot be parsed, an empty graph is returned instead
        of raising. Calling this function disables RDKit's ``rdApp.*`` log
        channels and does not re-enable them.
    """
    from rdkit import Chem, RDLogger

    from torch_geometric.data import Data

    RDLogger.DisableLog('rdApp.*')

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Unparsable input: fall back to an empty molecule rather than raise.
        mol = Chem.MolFromSmiles('')
    if with_hydrogen:
        mol = Chem.AddHs(mol)
    if kekulize:
        # `Chem.Kekulize` modifies `mol` in place and returns `None`;
        # assigning its result back to `mol` would discard the molecule.
        Chem.Kekulize(mol)

    # Atom features: index of each RDKit property value in `x_map`
    # (`list.index` raises `ValueError` for out-of-vocabulary values).
    xs = []
    for atom in mol.GetAtoms():
        xs.append([
            x_map['atomic_num'].index(atom.GetAtomicNum()),
            x_map['chirality'].index(str(atom.GetChiralTag())),
            x_map['degree'].index(atom.GetTotalDegree()),
            x_map['formal_charge'].index(atom.GetFormalCharge()),
            x_map['num_hs'].index(atom.GetTotalNumHs()),
            x_map['num_radical_electrons'].index(
                atom.GetNumRadicalElectrons()),
            x_map['hybridization'].index(str(atom.GetHybridization())),
            x_map['is_aromatic'].index(atom.GetIsAromatic()),
            x_map['is_in_ring'].index(atom.IsInRing()),
        ])
    # `view` keeps a well-defined `[0, 9]` shape for atom-less molecules.
    x = torch.tensor(xs, dtype=torch.long).view(-1, 9)

    # Bond features: index of each RDKit property value in `e_map`.
    edge_indices, edge_attrs = [], []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        e = [
            e_map['bond_type'].index(str(bond.GetBondType())),
            e_map['stereo'].index(str(bond.GetStereo())),
            e_map['is_conjugated'].index(bond.GetIsConjugated()),
        ]
        # Store every bond as two directed edges sharing the same features.
        edge_indices += [[i, j], [j, i]]
        edge_attrs += [e, e]

    # Explicit dtype + reshape so bond-free molecules yield `[2, 0]` and
    # `[0, 3]` tensors instead of an untyped empty tensor.
    edge_index = torch.tensor(edge_indices, dtype=torch.long).view(-1, 2).t()
    edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)

    if edge_index.numel() > 0:
        # Sort by (source, target): every target index is < x.size(0), so
        # `source * num_atoms + target` is a lexicographic sort key.
        perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
        edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)