# Source code for torch_geometric.utils.smiles

import torch

# Vocabularies for the atom features built by `from_smiles`: every atom
# property is encoded as the position of its value inside the matching list.
x_map = dict(
    # Atomic numbers 0..118.
    atomic_num=list(range(119)),
    # Chiral tag names, matched against `str(atom.GetChiralTag())`.
    chirality=[
        'CHI_UNSPECIFIED',
        'CHI_TETRAHEDRAL_CW',
        'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER',
    ],
    # Total degree 0..10.
    degree=list(range(11)),
    # Formal charge -5..+6.
    formal_charge=list(range(-5, 7)),
    # Total number of attached hydrogens 0..8.
    num_hs=list(range(9)),
    # Number of radical electrons 0..4.
    num_radical_electrons=list(range(5)),
    # Hybridization names, matched against `str(atom.GetHybridization())`.
    hybridization=[
        'UNSPECIFIED',
        'S',
        'SP',
        'SP2',
        'SP3',
        'SP3D',
        'SP3D2',
        'OTHER',
    ],
    # Boolean flags: index 0 is False, index 1 is True.
    is_aromatic=[False, True],
    is_in_ring=[False, True],
)

# Vocabularies for the bond features built by `from_smiles`: every bond
# property is encoded as the position of its value inside the matching list.
# A value missing from a list makes `from_smiles` raise a ValueError.
e_map = dict(
    # Bond type names, matched against `str(bond.GetBondType())`.
    # NOTE(review): 'misc' does not look like an RDKit bond-type name;
    # presumably it is a reserved placeholder at index 0 — confirm.
    bond_type=[
        'misc',
        'SINGLE',
        'DOUBLE',
        'TRIPLE',
        'AROMATIC',
    ],
    # Bond stereo names, matched against `str(bond.GetStereo())`.
    stereo=[
        'STEREONONE',
        'STEREOZ',
        'STEREOE',
        'STEREOCIS',
        'STEREOTRANS',
        'STEREOANY',
    ],
    # Boolean flag: index 0 is False, index 1 is True.
    is_conjugated=[False, True],
)


def _feature_index(feature_map: dict, key: str, value) -> int:
    r"""Returns the position of :obj:`value` in :obj:`feature_map[key]`.

    Raises:
        ValueError: If :obj:`value` is not listed under :obj:`key`.
    """
    try:
        return feature_map[key].index(value)
    except ValueError:
        # The bare `list.index` message ("x is not in list") does not say
        # which feature failed, so re-raise with the feature name attached.
        raise ValueError(f"Unsupported value {value!r} for feature "
                         f"'{key}' in 'from_smiles'") from None


def _atom_features(atom) -> list:
    r"""Encodes an RDKit atom as a list of 9 indices into :obj:`x_map`."""
    return [
        _feature_index(x_map, 'atomic_num', atom.GetAtomicNum()),
        _feature_index(x_map, 'chirality', str(atom.GetChiralTag())),
        _feature_index(x_map, 'degree', atom.GetTotalDegree()),
        _feature_index(x_map, 'formal_charge', atom.GetFormalCharge()),
        _feature_index(x_map, 'num_hs', atom.GetTotalNumHs()),
        _feature_index(x_map, 'num_radical_electrons',
                       atom.GetNumRadicalElectrons()),
        _feature_index(x_map, 'hybridization',
                       str(atom.GetHybridization())),
        _feature_index(x_map, 'is_aromatic', atom.GetIsAromatic()),
        _feature_index(x_map, 'is_in_ring', atom.IsInRing()),
    ]


def _bond_features(bond) -> list:
    r"""Encodes an RDKit bond as a list of 3 indices into :obj:`e_map`."""
    return [
        _feature_index(e_map, 'bond_type', str(bond.GetBondType())),
        _feature_index(e_map, 'stereo', str(bond.GetStereo())),
        _feature_index(e_map, 'is_conjugated', bond.GetIsConjugated()),
    ]


def from_smiles(smiles: str, with_hydrogen: bool = False,
                kekulize: bool = False) -> 'torch_geometric.data.Data':
    r"""Converts a SMILES string to a :class:`torch_geometric.data.Data`
    instance.

    Atom features are stored in :obj:`x` as a :obj:`[num_atoms, 9]` long
    tensor, and bond features in :obj:`edge_attr` as a
    :obj:`[num_edges, 3]` long tensor. Every entry is the index of the
    corresponding RDKit property value in :obj:`x_map` / :obj:`e_map`.
    Each bond is stored as two directed edges (one per direction), and
    edges are sorted by source atom, then target atom.

    If :obj:`smiles` cannot be parsed, the molecule parsed from an empty
    SMILES string is used instead, i.e. a graph without atoms is returned
    rather than an error being raised.

    Args:
        smiles (str): The SMILES string.
        with_hydrogen (bool, optional): If set to :obj:`True`, will store
            hydrogens in the molecule graph. (default: :obj:`False`)
        kekulize (bool, optional): If set to :obj:`True`, converts aromatic
            bonds to single/double bonds. (default: :obj:`False`)

    Raises:
        ValueError: If an atom or bond property value is not listed in
            :obj:`x_map` / :obj:`e_map`.
    """
    from rdkit import Chem, RDLogger

    from torch_geometric.data import Data

    # NOTE: this silences RDKit logging process-wide; it is not restored.
    RDLogger.DisableLog('rdApp.*')

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Best-effort fallback for unparsable input (see docstring).
        mol = Chem.MolFromSmiles('')
    if with_hydrogen:
        mol = Chem.AddHs(mol)
    if kekulize:
        # `Chem.Kekulize` modifies `mol` in place and returns `None`; its
        # result must not be assigned back to `mol`.
        Chem.Kekulize(mol)

    xs = [_atom_features(atom) for atom in mol.GetAtoms()]
    # `view` keeps a `[0, 9]` shape when the molecule has no atoms.
    x = torch.tensor(xs, dtype=torch.long).view(-1, 9)

    edge_indices, edge_attrs = [], []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        e = _bond_features(bond)
        # One undirected bond becomes two directed edges sharing features.
        edge_indices += [[i, j], [j, i]]
        edge_attrs += [e, e]

    # `view` keeps `[2, 0]` / `[0, 3]` shapes when there are no bonds.
    edge_index = torch.tensor(edge_indices, dtype=torch.long)
    edge_index = edge_index.t().view(2, -1)
    edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)

    if edge_index.numel() > 0:
        # Sort by (source, target): since target < num_atoms, the key
        # `source * num_atoms + target` orders edges lexicographically.
        perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
        edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                smiles=smiles)