Source code for torch_geometric.io.sdf

import torch
import torch.nn.functional as F
from torch_sparse import coalesce
from torch_geometric.io import parse_txt_array
from torch_geometric.data import Data

elems = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}


[docs]def parse_sdf(src): src = src.split('\n')[3:] num_atoms, num_bonds = [int(item) for item in src[0].split()[:2]] atom_block = src[1:num_atoms + 1] pos = parse_txt_array(atom_block, end=3) x = torch.tensor([elems[item.split()[3]] for item in atom_block]) x = F.one_hot(x, num_classes=len(elems)) bond_block = src[1 + num_atoms:1 + num_atoms + num_bonds] row, col = parse_txt_array(bond_block, end=2, dtype=torch.long).t() - 1 row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0) edge_index = torch.stack([row, col], dim=0) edge_attr = parse_txt_array(bond_block, start=2, end=3) - 1 edge_attr = torch.cat([edge_attr, edge_attr], dim=0) edge_index, edge_attr = coalesce(edge_index, edge_attr, num_atoms, num_atoms) data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos) return data
[docs]def read_sdf(path): with open(path, 'r') as f: return parse_sdf(f.read())