Source code for torch_geometric.nn.models.schnet

import os
import os.path as osp
import warnings
from math import pi as PI
from typing import Callable, Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Embedding, Linear, ModuleList, Sequential

from import Dataset, download_url, extract_zip
from torch_geometric.nn import MessagePassing, SumAggregation, radius_graph
from torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver
from torch_geometric.typing import OptTensor

qm9_target_dict: Dict[int, str] = {
    0: 'dipole_moment',
    1: 'isotropic_polarizability',
    2: 'homo',
    3: 'lumo',
    4: 'gap',
    5: 'electronic_spatial_extent',
    6: 'zpve',
    7: 'energy_U0',
    8: 'energy_U',
    9: 'enthalpy_H',
    10: 'free_energy',
    11: 'heat_capacity',

[docs]class SchNet(torch.nn.Module): r"""The continuous-filter convolutional neural network SchNet from the `"SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions" <>`_ paper that uses the interactions blocks of the form. .. math:: \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))), here :math:`h_{\mathbf{\Theta}}` denotes an MLP and :math:`\mathbf{e}_{j,i}` denotes the interatomic distances between atoms. .. note:: For an example of using a pretrained SchNet variant, see `examples/ <>`_. Args: hidden_channels (int, optional): Hidden embedding size. (default: :obj:`128`) num_filters (int, optional): The number of filters to use. (default: :obj:`128`) num_interactions (int, optional): The number of interaction blocks. (default: :obj:`6`) num_gaussians (int, optional): The number of gaussians :math:`\mu`. (default: :obj:`50`) interaction_graph (callable, optional): The function used to compute the pairwise interaction graph and interatomic distances. If set to :obj:`None`, will construct a graph based on :obj:`cutoff` and :obj:`max_num_neighbors` properties. If provided, this method takes in :obj:`pos` and :obj:`batch` tensors and should return :obj:`(edge_index, edge_weight)` tensors. (default :obj:`None`) cutoff (float, optional): Cutoff distance for interatomic interactions. (default: :obj:`10.0`) max_num_neighbors (int, optional): The maximum number of neighbors to collect for each node within the :attr:`cutoff` distance. (default: :obj:`32`) readout (str, optional): Whether to apply :obj:`"add"` or :obj:`"mean"` global aggregation. (default: :obj:`"add"`) dipole (bool, optional): If set to :obj:`True`, will use the magnitude of the dipole moment to make the final prediction, *e.g.*, for target 0 of :class:`torch_geometric.datasets.QM9`. (default: :obj:`False`) mean (float, optional): The mean of the property to predict. (default: :obj:`None`) std (float, optional): The standard deviation of the property to predict. (default: :obj:`None`) atomref (torch.Tensor, optional): The reference of single-atom properties. Expects a vector of shape :obj:`(max_atomic_number, )`. """ url = '' def __init__( self, hidden_channels: int = 128, num_filters: int = 128, num_interactions: int = 6, num_gaussians: int = 50, cutoff: float = 10.0, interaction_graph: Optional[Callable] = None, max_num_neighbors: int = 32, readout: str = 'add', dipole: bool = False, mean: Optional[float] = None, std: Optional[float] = None, atomref: OptTensor = None, ): super().__init__() self.hidden_channels = hidden_channels self.num_filters = num_filters self.num_interactions = num_interactions self.num_gaussians = num_gaussians self.cutoff = cutoff self.dipole = dipole self.sum_aggr = SumAggregation() self.readout = aggr_resolver('sum' if self.dipole else readout) self.mean = mean self.std = std self.scale = None if self.dipole: import ase atomic_mass = torch.from_numpy( self.register_buffer('atomic_mass', atomic_mass) # Support z == 0 for padding atoms so that their embedding vectors # are zeroed and do not receive any gradients. self.embedding = Embedding(100, hidden_channels, padding_idx=0) if interaction_graph is not None: self.interaction_graph = interaction_graph else: self.interaction_graph = RadiusInteractionGraph( cutoff, max_num_neighbors) self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians) self.interactions = ModuleList() for _ in range(num_interactions): block = InteractionBlock(hidden_channels, num_gaussians, num_filters, cutoff) self.interactions.append(block) self.lin1 = Linear(hidden_channels, hidden_channels // 2) self.act = ShiftedSoftplus() self.lin2 = Linear(hidden_channels // 2, 1) self.register_buffer('initial_atomref', atomref) self.atomref = None if atomref is not None: self.atomref = Embedding(100, 1) self.reset_parameters()
[docs] def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.embedding.reset_parameters() for interaction in self.interactions: interaction.reset_parameters() torch.nn.init.xavier_uniform_(self.lin1.weight) torch.nn.init.xavier_uniform_(self.lin2.weight) if self.atomref is not None:
[docs] @staticmethod def from_qm9_pretrained( root: str, dataset: Dataset, target: int, ) -> Tuple['SchNet', Dataset, Dataset, Dataset]: # pragma: no cover r"""Returns a pre-trained :class:`SchNet` model on the :class:`~torch_geometric.datasets.QM9` dataset, trained on the specified target :obj:`target`. """ import ase import schnetpack as spk # noqa assert target >= 0 and target <= 12 is_dipole = target == 0 units = [1] * 12 units[0] = ase.units.Debye units[1] = ase.units.Bohr**3 units[5] = ase.units.Bohr**2 root = osp.expanduser(osp.normpath(root)) os.makedirs(root, exist_ok=True) folder = 'trained_schnet_models' if not osp.exists(osp.join(root, folder)): path = download_url(SchNet.url, root) extract_zip(path, root) os.unlink(path) name = f'qm9_{qm9_target_dict[target]}' path = osp.join(root, 'trained_schnet_models', name, 'split.npz') split = np.load(path) train_idx = split['train_idx'] val_idx = split['val_idx'] test_idx = split['test_idx'] # Filter the splits to only contain characterized molecules. idx = assoc = idx.new_empty(idx.max().item() + 1) assoc[idx] = torch.arange(idx.size(0)) train_idx = assoc[train_idx[np.isin(train_idx, idx)]] val_idx = assoc[val_idx[np.isin(val_idx, idx)]] test_idx = assoc[test_idx[np.isin(test_idx, idx)]] path = osp.join(root, 'trained_schnet_models', name, 'best_model') with warnings.catch_warnings(): warnings.simplefilter('ignore') state = torch.load(path, map_location='cpu') net = SchNet( hidden_channels=128, num_filters=128, num_interactions=6, num_gaussians=50, cutoff=10.0, dipole=is_dipole, atomref=dataset.atomref(target), ) net.embedding.weight = state.representation.embedding.weight for int1, int2 in zip(state.representation.interactions, net.interactions): int2.mlp[0].weight = int1.filter_network[0].weight int2.mlp[0].bias = int1.filter_network[0].bias int2.mlp[2].weight = int1.filter_network[1].weight int2.mlp[2].bias = int1.filter_network[1].bias int2.lin.weight = int1.dense.weight int2.lin.bias = int1.dense.bias int2.conv.lin1.weight = int1.cfconv.in2f.weight int2.conv.lin2.weight = int1.cfconv.f2out.weight int2.conv.lin2.bias = int1.cfconv.f2out.bias net.lin1.weight = state.output_modules[0].out_net[1].out_net[0].weight net.lin1.bias = state.output_modules[0].out_net[1].out_net[0].bias net.lin2.weight = state.output_modules[0].out_net[1].out_net[1].weight net.lin2.bias = state.output_modules[0].out_net[1].out_net[1].bias mean = state.output_modules[0].atom_pool.average net.readout = aggr_resolver('mean' if mean is True else 'add') dipole = state.output_modules[0].__class__.__name__ == 'DipoleMoment' net.dipole = dipole net.mean = state.output_modules[0].standardize.mean.item() net.std = state.output_modules[0].standardize.stddev.item() if state.output_modules[0].atomref is not None: net.atomref.weight = state.output_modules[0].atomref.weight else: net.atomref = None net.scale = 1.0 / units[target] return net, (dataset[train_idx], dataset[val_idx], dataset[test_idx])
[docs] def forward(self, z: Tensor, pos: Tensor, batch: OptTensor = None) -> Tensor: r"""Forward pass. Args: z (torch.Tensor): Atomic number of each atom with shape :obj:`[num_atoms]`. pos (torch.Tensor): Coordinates of each atom with shape :obj:`[num_atoms, 3]`. batch (torch.Tensor, optional): Batch indices assigning each atom to a separate molecule with shape :obj:`[num_atoms]`. (default: :obj:`None`) """ batch = torch.zeros_like(z) if batch is None else batch h = self.embedding(z) edge_index, edge_weight = self.interaction_graph(pos, batch) edge_attr = self.distance_expansion(edge_weight) for interaction in self.interactions: h = h + interaction(h, edge_index, edge_weight, edge_attr) h = self.lin1(h) h = self.act(h) h = self.lin2(h) if self.dipole: # Get center of mass. mass = self.atomic_mass[z].view(-1, 1) M = self.sum_aggr(mass, batch, dim=0) c = self.sum_aggr(mass * pos, batch, dim=0) / M h = h * (pos - c.index_select(0, batch)) if not self.dipole and self.mean is not None and self.std is not None: h = h * self.std + self.mean if not self.dipole and self.atomref is not None: h = h + self.atomref(z) out = self.readout(h, batch, dim=0) if self.dipole: out = torch.norm(out, dim=-1, keepdim=True) if self.scale is not None: out = self.scale * out return out
def __repr__(self) -> str: return (f'{self.__class__.__name__}(' f'hidden_channels={self.hidden_channels}, ' f'num_filters={self.num_filters}, ' f'num_interactions={self.num_interactions}, ' f'num_gaussians={self.num_gaussians}, ' f'cutoff={self.cutoff})')
class RadiusInteractionGraph(torch.nn.Module): r"""Creates edges based on atom positions :obj:`pos` to all points within the cutoff distance. Args: cutoff (float, optional): Cutoff distance for interatomic interactions. (default: :obj:`10.0`) max_num_neighbors (int, optional): The maximum number of neighbors to collect for each node within the :attr:`cutoff` distance with the default interaction graph method. (default: :obj:`32`) """ def __init__(self, cutoff: float = 10.0, max_num_neighbors: int = 32): super().__init__() self.cutoff = cutoff self.max_num_neighbors = max_num_neighbors def forward(self, pos: Tensor, batch: Tensor) -> Tuple[Tensor, Tensor]: r"""Forward pass. Args: pos (Tensor): Coordinates of each atom. batch (LongTensor, optional): Batch indices assigning each atom to a separate molecule. :rtype: (:class:`LongTensor`, :class:`Tensor`) """ edge_index = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=self.max_num_neighbors) row, col = edge_index edge_weight = (pos[row] - pos[col]).norm(dim=-1) return edge_index, edge_weight class InteractionBlock(torch.nn.Module): def __init__(self, hidden_channels: int, num_gaussians: int, num_filters: int, cutoff: float): super().__init__() self.mlp = Sequential( Linear(num_gaussians, num_filters), ShiftedSoftplus(), Linear(num_filters, num_filters), ) self.conv = CFConv(hidden_channels, hidden_channels, num_filters, self.mlp, cutoff) self.act = ShiftedSoftplus() self.lin = Linear(hidden_channels, hidden_channels) self.reset_parameters() def reset_parameters(self): torch.nn.init.xavier_uniform_(self.mlp[0].weight) self.mlp[0] torch.nn.init.xavier_uniform_(self.mlp[2].weight) self.mlp[2] self.conv.reset_parameters() torch.nn.init.xavier_uniform_(self.lin.weight) def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor) -> Tensor: x = self.conv(x, edge_index, edge_weight, edge_attr) x = self.act(x) x = self.lin(x) return x class CFConv(MessagePassing): def __init__( self, in_channels: int, out_channels: int, num_filters: int, nn: Sequential, cutoff: float, ): super().__init__(aggr='add') self.lin1 = Linear(in_channels, num_filters, bias=False) self.lin2 = Linear(num_filters, out_channels) self.nn = nn self.cutoff = cutoff self.reset_parameters() def reset_parameters(self): torch.nn.init.xavier_uniform_(self.lin1.weight) torch.nn.init.xavier_uniform_(self.lin2.weight) def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor) -> Tensor: C = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0) W = self.nn(edge_attr) * C.view(-1, 1) x = self.lin1(x) x = self.propagate(edge_index, x=x, W=W) x = self.lin2(x) return x def message(self, x_j: Tensor, W: Tensor) -> Tensor: return x_j * W class GaussianSmearing(torch.nn.Module): def __init__( self, start: float = 0.0, stop: float = 5.0, num_gaussians: int = 50, ): super().__init__() offset = torch.linspace(start, stop, num_gaussians) self.coeff = -0.5 / (offset[1] - offset[0]).item()**2 self.register_buffer('offset', offset) def forward(self, dist: Tensor) -> Tensor: dist = dist.view(-1, 1) - self.offset.view(1, -1) return torch.exp(self.coeff * torch.pow(dist, 2)) class ShiftedSoftplus(torch.nn.Module): def __init__(self): super().__init__() self.shift = torch.log(torch.tensor(2.0)).item() def forward(self, x: Tensor) -> Tensor: return F.softplus(x) - self.shift