Source code for torch_geometric.nn.models.schnet

import os
import os.path as osp
import warnings
from math import pi as PI
from typing import Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Embedding, Linear, ModuleList, Sequential

from torch_geometric.data import Dataset, download_url, extract_zip
from torch_geometric.data.makedirs import makedirs
from torch_geometric.nn import MessagePassing, SumAggregation, radius_graph
from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver
from torch_geometric.typing import OptTensor

qm9_target_dict = {
    0: 'dipole_moment',
    1: 'isotropic_polarizability',
    2: 'homo',
    3: 'lumo',
    4: 'gap',
    5: 'electronic_spatial_extent',
    6: 'zpve',
    7: 'energy_U0',
    8: 'energy_U',
    9: 'enthalpy_H',
    10: 'free_energy',
    11: 'heat_capacity',
}


class SchNet(torch.nn.Module):
    r"""The continuous-filter convolutional neural network SchNet from the
    `"SchNet: A Continuous-filter Convolutional Neural Network for Modeling
    Quantum Interactions" <https://arxiv.org/abs/1706.08566>`_ paper that uses
    the interaction blocks of the form

    .. math::
        \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot
        h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))),

    here :math:`h_{\mathbf{\Theta}}` denotes an MLP and
    :math:`\mathbf{e}_{j,i}` denotes the interatomic distances between atoms.

    .. note::

        For an example of using a pretrained SchNet variant, see
        `examples/qm9_pretrained_schnet.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        qm9_pretrained_schnet.py>`_.

    Args:
        hidden_channels (int, optional): Hidden embedding size.
            (default: :obj:`128`)
        num_filters (int, optional): The number of filters to use.
            (default: :obj:`128`)
        num_interactions (int, optional): The number of interaction blocks.
            (default: :obj:`6`)
        num_gaussians (int, optional): The number of gaussians :math:`\mu`.
            (default: :obj:`50`)
        interaction_graph (Callable, optional): The function used to compute
            the pairwise interaction graph and interatomic distances. If set
            to :obj:`None`, will construct a graph based on :obj:`cutoff` and
            :obj:`max_num_neighbors` properties. If provided, this method
            takes in :obj:`pos` and :obj:`batch` tensors and should return
            :obj:`(edge_index, edge_weight)` tensors.
            (default: :obj:`None`)
        cutoff (float, optional): Cutoff distance for interatomic
            interactions. (default: :obj:`10.0`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            collect for each node within the :attr:`cutoff` distance.
            (default: :obj:`32`)
        readout (str, optional): Whether to apply :obj:`"add"` or
            :obj:`"mean"` global aggregation. (default: :obj:`"add"`)
        dipole (bool, optional): If set to :obj:`True`, will use the magnitude
            of the dipole moment to make the final prediction, *e.g.*, for
            target 0 of :class:`torch_geometric.datasets.QM9`.
            (default: :obj:`False`)
        mean (float, optional): The mean of the property to predict.
            (default: :obj:`None`)
        std (float, optional): The standard deviation of the property to
            predict. (default: :obj:`None`)
        atomref (torch.Tensor, optional): The reference of single-atom
            properties.
            Expects a vector of shape :obj:`(max_atomic_number, )`.
    """

    url = 'http://www.quantum-machine.org/datasets/trained_schnet_models.zip'

    def __init__(
        self,
        hidden_channels: int = 128,
        num_filters: int = 128,
        num_interactions: int = 6,
        num_gaussians: int = 50,
        cutoff: float = 10.0,
        interaction_graph: Optional[Callable] = None,
        max_num_neighbors: int = 32,
        readout: str = 'add',
        dipole: bool = False,
        mean: Optional[float] = None,
        std: Optional[float] = None,
        atomref: OptTensor = None,
    ):
        super().__init__()

        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_interactions = num_interactions
        self.num_gaussians = num_gaussians
        self.cutoff = cutoff
        self.dipole = dipole
        self.sum_aggr = SumAggregation()
        self.readout = aggr_resolver('sum' if self.dipole else readout)
        self.mean = mean
        self.std = std
        self.scale = None

        if self.dipole:
            import ase
            atomic_mass = torch.from_numpy(ase.data.atomic_masses)
            self.register_buffer('atomic_mass', atomic_mass)

        # Support z == 0 for padding atoms so that their embedding vectors
        # are zeroed and do not receive any gradients.
        self.embedding = Embedding(100, hidden_channels, padding_idx=0)

        if interaction_graph is not None:
            self.interaction_graph = interaction_graph
        else:
            self.interaction_graph = RadiusInteractionGraph(
                cutoff, max_num_neighbors)

        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)

        self.interactions = ModuleList()
        for _ in range(num_interactions):
            block = InteractionBlock(hidden_channels, num_gaussians,
                                     num_filters, cutoff)
            self.interactions.append(block)

        self.lin1 = Linear(hidden_channels, hidden_channels // 2)
        self.act = ShiftedSoftplus()
        self.lin2 = Linear(hidden_channels // 2, 1)

        self.register_buffer('initial_atomref', atomref)
        self.atomref = None
        if atomref is not None:
            self.atomref = Embedding(100, 1)
            self.atomref.weight.data.copy_(atomref)

        self.reset_parameters()
    def reset_parameters(self):
        self.embedding.reset_parameters()
        for interaction in self.interactions:
            interaction.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)
        if self.atomref is not None:
            self.atomref.weight.data.copy_(self.initial_atomref)
    @staticmethod
    def from_qm9_pretrained(
        root: str,
        dataset: Dataset,
        target: int,
    ) -> Tuple['SchNet', Tuple[Dataset, Dataset, Dataset]]:
        import ase
        import schnetpack as spk  # noqa

        # `qm9_target_dict` and `units` only cover targets 0 through 11.
        assert 0 <= target <= 11

        is_dipole = target == 0

        units = [1] * 12
        units[0] = ase.units.Debye
        units[1] = ase.units.Bohr**3
        units[5] = ase.units.Bohr**2

        root = osp.expanduser(osp.normpath(root))
        makedirs(root)
        folder = 'trained_schnet_models'
        if not osp.exists(osp.join(root, folder)):
            path = download_url(SchNet.url, root)
            extract_zip(path, root)
            os.unlink(path)

        name = f'qm9_{qm9_target_dict[target]}'
        path = osp.join(root, 'trained_schnet_models', name, 'split.npz')

        split = np.load(path)
        train_idx = split['train_idx']
        val_idx = split['val_idx']
        test_idx = split['test_idx']

        # Filter the splits to only contain characterized molecules.
        idx = dataset.data.idx
        assoc = idx.new_empty(idx.max().item() + 1)
        assoc[idx] = torch.arange(idx.size(0))

        train_idx = assoc[train_idx[np.isin(train_idx, idx)]]
        val_idx = assoc[val_idx[np.isin(val_idx, idx)]]
        test_idx = assoc[test_idx[np.isin(test_idx, idx)]]

        path = osp.join(root, 'trained_schnet_models', name, 'best_model')

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            state = torch.load(path, map_location='cpu')

        net = SchNet(
            hidden_channels=128,
            num_filters=128,
            num_interactions=6,
            num_gaussians=50,
            cutoff=10.0,
            dipole=is_dipole,
            atomref=dataset.atomref(target),
        )

        net.embedding.weight = state.representation.embedding.weight

        for int1, int2 in zip(state.representation.interactions,
                              net.interactions):
            int2.mlp[0].weight = int1.filter_network[0].weight
            int2.mlp[0].bias = int1.filter_network[0].bias
            int2.mlp[2].weight = int1.filter_network[1].weight
            int2.mlp[2].bias = int1.filter_network[1].bias
            int2.lin.weight = int1.dense.weight
            int2.lin.bias = int1.dense.bias

            int2.conv.lin1.weight = int1.cfconv.in2f.weight
            int2.conv.lin2.weight = int1.cfconv.f2out.weight
            int2.conv.lin2.bias = int1.cfconv.f2out.bias

        net.lin1.weight = state.output_modules[0].out_net[1].out_net[0].weight
        net.lin1.bias = state.output_modules[0].out_net[1].out_net[0].bias
        net.lin2.weight = state.output_modules[0].out_net[1].out_net[1].weight
        net.lin2.bias = state.output_modules[0].out_net[1].out_net[1].bias

        mean = state.output_modules[0].atom_pool.average
        net.readout = aggr_resolver('mean' if mean is True else 'add')

        dipole = state.output_modules[0].__class__.__name__ == 'DipoleMoment'
        net.dipole = dipole

        net.mean = state.output_modules[0].standardize.mean.item()
        net.std = state.output_modules[0].standardize.stddev.item()

        if state.output_modules[0].atomref is not None:
            net.atomref.weight = state.output_modules[0].atomref.weight
        else:
            net.atomref = None

        net.scale = 1. / units[target]

        return net, (dataset[train_idx], dataset[val_idx], dataset[test_idx])
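    # A hedged usage sketch for the pretrained loader above (illustrative
    # only, not part of the library source); `path` is a placeholder and the
    # QM9 dataset plus the `schnetpack` package are assumed to be available:
    #
    #     from torch_geometric.datasets import QM9
    #     dataset = QM9(path)
    #     model, (train_ds, val_ds, test_ds) = SchNet.from_qm9_pretrained(
    #         path, dataset, target=7)  # target 7 == 'energy_U0'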
    def forward(self, z: Tensor, pos: Tensor,
                batch: OptTensor = None) -> Tensor:
        r"""
        Args:
            z (LongTensor): Atomic number of each atom with shape
                :obj:`[num_atoms]`.
            pos (Tensor): Coordinates of each atom with shape
                :obj:`[num_atoms, 3]`.
            batch (LongTensor, optional): Batch indices assigning each atom to
                a separate molecule with shape :obj:`[num_atoms]`.
                (default: :obj:`None`)
        """
        batch = torch.zeros_like(z) if batch is None else batch

        h = self.embedding(z)
        edge_index, edge_weight = self.interaction_graph(pos, batch)
        edge_attr = self.distance_expansion(edge_weight)

        for interaction in self.interactions:
            h = h + interaction(h, edge_index, edge_weight, edge_attr)

        h = self.lin1(h)
        h = self.act(h)
        h = self.lin2(h)

        if self.dipole:
            # Get center of mass.
            mass = self.atomic_mass[z].view(-1, 1)
            M = self.sum_aggr(mass, batch, dim=0)
            c = self.sum_aggr(mass * pos, batch, dim=0) / M
            h = h * (pos - c.index_select(0, batch))

        if not self.dipole and self.mean is not None and self.std is not None:
            h = h * self.std + self.mean

        if not self.dipole and self.atomref is not None:
            h = h + self.atomref(z)

        out = self.readout(h, batch, dim=0)

        if self.dipole:
            out = torch.norm(out, dim=-1, keepdim=True)

        if self.scale is not None:
            out = self.scale * out

        return out
    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'hidden_channels={self.hidden_channels}, '
                f'num_filters={self.num_filters}, '
                f'num_interactions={self.num_interactions}, '
                f'num_gaussians={self.num_gaussians}, '
                f'cutoff={self.cutoff})')
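# A minimal forward-pass sketch (illustrative only, not part of the library
# source); it uses random atomic numbers and coordinates for a single
# molecule:
#
#     model = SchNet(hidden_channels=128, num_interactions=6, cutoff=10.0)
#     z = torch.randint(1, 10, (20, ))           # 20 atoms, atomic numbers 1-9
#     pos = torch.randn(20, 3)                   # random 3D coordinates
#     batch = torch.zeros(20, dtype=torch.long)  # all atoms belong to molecule 0
#     out = model(z, pos, batch)                 # one prediction per molecule: [1, 1]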
class RadiusInteractionGraph(torch.nn.Module):
    r"""Creates edges based on atom positions :obj:`pos` to all points within
    the cutoff distance.

    Args:
        cutoff (float, optional): Cutoff distance for interatomic
            interactions. (default: :obj:`10.0`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            collect for each node within the :attr:`cutoff` distance with the
            default interaction graph method.
            (default: :obj:`32`)
    """
    def __init__(self, cutoff: float = 10.0, max_num_neighbors: int = 32):
        super().__init__()
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors

    def forward(self, pos: Tensor, batch: Tensor) -> Tuple[Tensor, Tensor]:
        r"""
        Args:
            pos (Tensor): Coordinates of each atom.
            batch (LongTensor, optional): Batch indices assigning each atom to
                a separate molecule.

        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch,
                                  max_num_neighbors=self.max_num_neighbors)
        row, col = edge_index
        edge_weight = (pos[row] - pos[col]).norm(dim=-1)
        return edge_index, edge_weight


class InteractionBlock(torch.nn.Module):
    def __init__(self, hidden_channels: int, num_gaussians: int,
                 num_filters: int, cutoff: float):
        super().__init__()
        self.mlp = Sequential(
            Linear(num_gaussians, num_filters),
            ShiftedSoftplus(),
            Linear(num_filters, num_filters),
        )
        self.conv = CFConv(hidden_channels, hidden_channels, num_filters,
                           self.mlp, cutoff)
        self.act = ShiftedSoftplus()
        self.lin = Linear(hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.mlp[0].weight)
        self.mlp[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.mlp[2].weight)
        self.mlp[2].bias.data.fill_(0)
        self.conv.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin.weight)
        self.lin.bias.data.fill_(0)

    def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor,
                edge_attr: Tensor) -> Tensor:
        x = self.conv(x, edge_index, edge_weight, edge_attr)
        x = self.act(x)
        x = self.lin(x)
        return x


class CFConv(MessagePassing):
    def __init__(self, in_channels: int, out_channels: int, num_filters: int,
                 nn: Sequential, cutoff: float):
        super().__init__(aggr='add')
        self.lin1 = Linear(in_channels, num_filters, bias=False)
        self.lin2 = Linear(num_filters, out_channels)
        self.nn = nn
        self.cutoff = cutoff

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor,
                edge_attr: Tensor) -> Tensor:
        C = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0)
        W = self.nn(edge_attr) * C.view(-1, 1)

        x = self.lin1(x)
        x = self.propagate(edge_index, x=x, W=W)
        x = self.lin2(x)
        return x

    def message(self, x_j: Tensor, W: Tensor) -> Tensor:
        return x_j * W


class GaussianSmearing(torch.nn.Module):
    def __init__(self, start: float = 0.0, stop: float = 5.0,
                 num_gaussians: int = 50):
        super().__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
        self.register_buffer('offset', offset)

    def forward(self, dist: Tensor) -> Tensor:
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))


class ShiftedSoftplus(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.shift = torch.log(torch.tensor(2.0)).item()

    def forward(self, x: Tensor) -> Tensor:
        return F.softplus(x) - self.shift
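# A small numerical sketch of the two helper modules above (illustrative only,
# not part of the library source):
#
#     expand = GaussianSmearing(start=0.0, stop=10.0, num_gaussians=50)
#     dist = torch.tensor([0.0, 1.5, 9.9])   # three interatomic distances
#     print(expand(dist).shape)              # torch.Size([3, 50])
#
#     act = ShiftedSoftplus()
#     print(act(torch.zeros(1)))             # tensor([0.]): softplus(0) == log(2)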