torch_geometric.nn.models.SchNet

class SchNet(hidden_channels: int = 128, num_filters: int = 128, num_interactions: int = 6, num_gaussians: int = 50, cutoff: float = 10.0, interaction_graph: Optional[Callable] = None, max_num_neighbors: int = 32, readout: str = 'add', dipole: bool = False, mean: Optional[float] = None, std: Optional[float] = None, atomref: Optional[Tensor] = None)

Bases: Module

The continuous-filter convolutional neural network SchNet from the “SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions” paper, which uses interaction blocks of the form

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu})^2)),\]

where \(h_{\mathbf{\Theta}}\) denotes an MLP and \(\mathbf{e}_{j,i}\) denotes the interatomic distance between atoms \(j\) and \(i\).
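
The update rule can be illustrated with a minimal plain-Python sketch. It is not the library's implementation: it assumes the squared-distance Gaussian form used in the SchNet paper, evenly spaced centers between 0 and the cutoff, and a width tied to the center spacing. The names gaussian_expansion, interaction_update, and toy_filter are hypothetical, and toy_filter merely stands in for the learned MLP.

```python
import math


def gaussian_expansion(distance, cutoff=10.0, num_gaussians=50):
    """Expand a scalar distance on Gaussians with evenly spaced centers mu."""
    step = cutoff / (num_gaussians - 1)
    gamma = 0.5 / step ** 2  # assumed width, tied to the center spacing
    return [math.exp(-gamma * (distance - k * step) ** 2)
            for k in range(num_gaussians)]


def interaction_update(x, edges, filter_fn):
    """Compute x'_i = sum_j x_j * h(expansion(e_ji)) over directed edges.

    x:         list of per-atom feature vectors (lists of floats)
    edges:     list of (j, i, distance) triples; messages flow from j to i
    filter_fn: hypothetical stand-in for the MLP h_Theta; maps the expanded
               distance to one weight per feature channel
    """
    out = [[0.0] * len(x[0]) for _ in x]
    for j, i, dist in edges:
        weights = filter_fn(gaussian_expansion(dist))
        for c, (feature, weight) in enumerate(zip(x[j], weights)):
            out[i][c] += feature * weight  # element-wise product, then sum
    return out


def toy_filter(rbf):
    """Toy filter: both channels get the peak activation of the expansion."""
    return [max(rbf)] * 2


x = [[1.0, 2.0], [3.0, 4.0]]
# Distance 0.0 coincides with the first center, so the toy weight is 1.0.
edges = [(0, 1, 0.0), (1, 0, 0.0)]
updated = interaction_update(x, edges, toy_filter)
print(updated)  # [[3.0, 4.0], [1.0, 2.0]]
```

With a weight of exactly 1.0 on each edge, every atom simply receives its neighbor's features, which makes the sum over \(\mathcal{N}(i)\) easy to check by hand.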

Note

For an example of using a pretrained SchNet variant, see examples/qm9_pretrained_schnet.py.

Parameters:
  • hidden_channels (int, optional) – Hidden embedding size. (default: 128)

  • num_filters (int, optional) – The number of filters to use. (default: 128)

  • num_interactions (int, optional) – The number of interaction blocks. (default: 6)

  • num_gaussians (int, optional) – The number of Gaussians \(\mu\) used to expand the interatomic distances. (default: 50)

  • interaction_graph (callable, optional) – The function used to compute the pairwise interaction graph and interatomic distances. If set to None, a graph is constructed from the cutoff and max_num_neighbors arguments. If provided, the callable receives the pos and batch tensors and should return the (edge_index, edge_weight) tensors. (default: None)

  • cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 10.0)

  • max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)

  • readout (str, optional) – Whether to apply "add" or "mean" global aggregation. (default: "add")

  • dipole (bool, optional) – If set to True, will use the magnitude of the dipole moment to make the final prediction, e.g., for target 0 of torch_geometric.datasets.QM9. (default: False)

  • mean (float, optional) – The mean of the property to predict. (default: None)

  • std (float, optional) – The standard deviation of the property to predict. (default: None)

  • atomref (torch.Tensor, optional) – The reference of single-atom properties. Expects a vector of shape (max_atomic_number, ). (default: None)
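
The contract of interaction_graph, together with the roles of cutoff and max_num_neighbors, can be sketched in plain Python as follows. This is only an illustration with lists instead of tensors: a real callable passed to SchNet must accept and return torch tensors. The function name radius_interaction_graph, the strict "less than cutoff" test, and the choice to keep the closest neighbors when the limit is exceeded are assumptions of this sketch, not documented library behavior.

```python
import math


def radius_interaction_graph(pos, batch, cutoff=10.0, max_num_neighbors=32):
    """Return (edge_index, edge_weight) for atom pairs closer than `cutoff`.

    pos:   list of (x, y, z) coordinates, one per atom
    batch: list of molecule indices, one per atom
    """
    sources, targets, weights = [], [], []
    for i in range(len(pos)):
        # Candidate neighbors: other atoms of the same molecule within cutoff.
        candidates = []
        for j in range(len(pos)):
            if i == j or batch[i] != batch[j]:
                continue
            dist = math.dist(pos[i], pos[j])
            if dist < cutoff:
                candidates.append((dist, j))
        # Assumption: keep the closest neighbors when there are too many.
        for dist, j in sorted(candidates)[:max_num_neighbors]:
            sources.append(j)
            targets.append(i)
            weights.append(dist)
    # edge_index has two rows (source, target); edge_weight holds distances.
    return [sources, targets], weights


pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
batch = [0, 0, 1, 1]
edge_index, edge_weight = radius_interaction_graph(pos, batch, cutoff=1.5)
print(edge_index)   # [[1, 0], [0, 1]]
print(edge_weight)  # [1.0, 1.0]
```

Atoms 2 and 3 belong to the second molecule and are 2.0 apart, which exceeds the cutoff of 1.5, so they contribute no edges. Atoms in different molecules are never connected, regardless of distance.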

forward(z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) → Tensor

Forward pass.

Parameters:
  • z (torch.Tensor) – Atomic number of each atom with shape [num_atoms].

  • pos (torch.Tensor) – Coordinates of each atom with shape [num_atoms, 3].

  • batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. (default: None)
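
The relationship between z, pos, batch, and the readout option can be sketched in plain Python. The example below uses lists rather than tensors, the coordinates are illustrative rather than reference geometries, and the per-atom values are made up; the helper name readout is hypothetical and only mimics "add" and "mean" aggregation per molecule.

```python
def readout(per_atom_values, batch, mode="add"):
    """Aggregate per-atom scalars into one value per molecule."""
    num_molecules = max(batch) + 1
    totals = [0.0] * num_molecules
    counts = [0] * num_molecules
    for value, molecule in zip(per_atom_values, batch):
        totals[molecule] += value
        counts[molecule] += 1
    if mode == "add":
        return totals
    if mode == "mean":
        return [total / count for total, count in zip(totals, counts)]
    raise ValueError(f"unknown readout mode: {mode!r}")


# Two molecules in one batch: water (O, H, H) followed by H2 (H, H).
z = [8, 1, 1, 1, 1]                      # shape [num_atoms]
pos = [(0.00, 0.00, 0.0),                # shape [num_atoms, 3]
       (0.96, 0.00, 0.0),
       (-0.24, 0.93, 0.0),
       (0.00, 0.00, 0.0),
       (0.74, 0.00, 0.0)]
batch = [0, 0, 0, 1, 1]                  # shape [num_atoms]

per_atom = [1.0, 2.0, 3.0, 4.0, 6.0]     # made-up per-atom contributions
print(readout(per_atom, batch, "add"))   # [6.0, 10.0]
print(readout(per_atom, batch, "mean"))  # [2.0, 5.0]
```

Each entry of batch names the molecule that the atom at the same position in z and pos belongs to, so all three sequences share the length num_atoms and the result has one entry per molecule.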

reset_parameters()

Resets all learnable parameters of the module.

static from_qm9_pretrained(root: str, dataset: Dataset, target: int) → Tuple[SchNet, Dataset, Dataset, Dataset]

Returns a SchNet model pre-trained on the QM9 dataset for the target property selected by target.