torch_geometric.nn.models.SchNet
- class SchNet(hidden_channels: int = 128, num_filters: int = 128, num_interactions: int = 6, num_gaussians: int = 50, cutoff: float = 10.0, interaction_graph: Optional[Callable] = None, max_num_neighbors: int = 32, readout: str = 'add', dipole: bool = False, mean: Optional[float] = None, std: Optional[float] = None, atomref: Optional[Tensor] = None)
Bases: torch.nn.Module
The continuous-filter convolutional neural network SchNet from the “SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions” paper, which uses interaction blocks of the form
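\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot h_{\mathbf{\Theta}} \left( \exp\left(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu})^2\right)\right),\]
where \(h_{\mathbf{\Theta}}\) denotes an MLP, \(\mathbf{e}_{j,i}\) denotes the interatomic distance between atoms \(j\) and \(i\), \(\mathbf{\mu}\) denotes the vector of Gaussian centers, and \(\gamma\) controls the width of the Gaussians. The subtraction, square, and exponential are applied elementwise over the centers.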
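The interaction formula can be read as three steps: expand each distance \(\mathbf{e}_{j,i}\) with Gaussians \(\exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu})^2)\), map those features to a per-channel filter with an MLP, and sum the filtered neighbor features. The following is a minimal plain-PyTorch sketch of those steps, not PyG's internal implementation. `gaussian_expand` and `CFConvSketch` are hypothetical names; evenly spaced centers on `[0, cutoff]`, the choice of `gamma`, and the use of `Softplus` in the MLP are assumptions made for illustration.

```python
import torch
from torch import nn


def gaussian_expand(dist: torch.Tensor, num_gaussians: int = 50,
                    cutoff: float = 10.0) -> torch.Tensor:
    """Expand each distance e_ji into exp(-gamma * (e_ji - mu)^2) features."""
    # Assumption: centers mu evenly spaced on [0, cutoff], gamma set from the spacing.
    mu = torch.linspace(0.0, cutoff, num_gaussians, dtype=dist.dtype)
    gamma = 0.5 / (cutoff / (num_gaussians - 1)) ** 2
    return torch.exp(-gamma * (dist.view(-1, 1) - mu.view(1, -1)) ** 2)


class CFConvSketch(nn.Module):
    """Hypothetical continuous-filter convolution: sum_j x_j * h_Theta(rbf(e_ji))."""

    def __init__(self, channels: int, num_gaussians: int = 50, cutoff: float = 10.0):
        super().__init__()
        self.num_gaussians = num_gaussians
        self.cutoff = cutoff
        # h_Theta: MLP mapping Gaussian features to one filter value per channel.
        # Softplus is a stand-in assumption for the activation.
        self.mlp = nn.Sequential(
            nn.Linear(num_gaussians, channels),
            nn.Softplus(),
            nn.Linear(channels, channels),
        )

    def forward(self, x, edge_index, edge_weight):
        src, dst = edge_index                    # edge j -> i stored as (src=j, dst=i)
        rbf = gaussian_expand(edge_weight, self.num_gaussians, self.cutoff)
        msg = x[src] * self.mlp(rbf)             # x_j ⊙ h_Theta(...)
        out = torch.zeros_like(x)
        return out.index_add(0, dst, msg)        # sum over j in N(i)


torch.manual_seed(0)
pos = torch.randn(3, 3)                          # 3 atoms
x = torch.randn(3, 8)                            # 8 features per atom
edge_index = torch.tensor([[0, 1, 1, 2],         # source atoms j
                           [1, 0, 2, 1]])        # target atoms i
edge_weight = (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)
out = CFConvSketch(channels=8)(x, edge_index, edge_weight)
print(out.shape)  # torch.Size([3, 8])
```

Because the filter is a function of the continuous distance rather than a lookup on a discrete grid, the same weights apply to any atomic geometry.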
Note
For an example of using a pretrained SchNet variant, see examples/qm9_pretrained_schnet.py.
- Parameters:
hidden_channels (int, optional) – Hidden embedding size. (default: 128)
num_filters (int, optional) – The number of filters to use. (default: 128)
num_interactions (int, optional) – The number of interaction blocks. (default: 6)
num_gaussians (int, optional) – The number of Gaussian centers \(\mu\). (default: 50)
interaction_graph (callable, optional) – The function used to compute the pairwise interaction graph and interatomic distances. If set to None, a graph is constructed based on the cutoff and max_num_neighbors properties. If provided, this method takes in pos and batch tensors and should return (edge_index, edge_weight) tensors. (default: None)
cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 10.0)
max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)
readout (str, optional) – Whether to apply "add" or "mean" global aggregation. (default: "add")
dipole (bool, optional) – If set to True, will use the magnitude of the dipole moment to make the final prediction, e.g., for target 0 of torch_geometric.datasets.QM9. (default: False)
mean (float, optional) – The mean of the property to predict. (default: None)
std (float, optional) – The standard deviation of the property to predict. (default: None)
atomref (torch.Tensor, optional) – Reference values of single-atom properties. Expects a vector of shape (max_atomic_number, ). (default: None)
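The interaction_graph parameter accepts any callable that maps pos and batch to (edge_index, edge_weight). As a minimal sketch, assuming small molecules where an all-pairs distance matrix fits in memory, the hypothetical function `dense_radius_graph` below connects every pair of atoms in the same molecule that lies within a cutoff. It is not part of PyG and, unlike the default graph construction, it does not cap the neighbor count at max_num_neighbors.

```python
import torch


def dense_radius_graph(pos: torch.Tensor, batch: torch.Tensor, cutoff: float = 10.0):
    """Hypothetical interaction_graph callable (brute force, O(N^2) memory)."""
    dist = torch.cdist(pos, pos)                              # [N, N] pairwise distances
    same_mol = batch.view(-1, 1) == batch.view(1, -1)         # no edges across molecules
    not_self = ~torch.eye(pos.size(0), dtype=torch.bool, device=pos.device)
    row, col = (same_mol & not_self & (dist <= cutoff)).nonzero(as_tuple=True)
    edge_index = torch.stack([row, col], dim=0)               # [2, num_edges]
    edge_weight = dist[row, col]                              # [num_edges] distances
    return edge_index, edge_weight


# Two molecules: atoms 0-1 are 1.0 apart, atoms 2-3 are 9.0 apart.
pos = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0], [9.0, 0.0, 0.0]])
batch = torch.tensor([0, 0, 1, 1])
edge_index, edge_weight = dense_radius_graph(pos, batch, cutoff=5.0)
print(edge_index.tolist())  # [[0, 1], [1, 0]]
```

Atoms 0 and 2 share a position but belong to different molecules, so they are not connected. To use such a function, one would pass something like `interaction_graph=lambda pos, batch: dense_radius_graph(pos, batch, cutoff=5.0)` together with a matching `cutoff=5.0`, since cutoff also sets the range of the Gaussian expansion.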
- forward(z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) → Tensor
Forward pass.
- Parameters:
z (torch.Tensor) – Atomic number of each atom with shape [num_atoms].
pos (torch.Tensor) – Coordinates of each atom with shape [num_atoms, 3].
batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. If None, all atoms are treated as one molecule. (default: None)
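The batch vector determines which atoms the readout pools together. The sketch below lays out a hypothetical batch of two molecules in the shapes forward() expects and shows what "add" and "mean" pooling do to per-atom values. `pool_per_molecule` is a hypothetical helper, not a PyG function, and the per-atom values in `h_atom` are made up for illustration.

```python
import torch

# Hypothetical batch of two molecules, laid out as forward() expects.
z = torch.tensor([8, 1, 1, 6, 8])           # atomic numbers, shape [num_atoms]
pos = torch.randn(5, 3)                     # coordinates, shape [num_atoms, 3]
batch = torch.tensor([0, 0, 0, 1, 1])       # atoms 0-2 -> molecule 0, atoms 3-4 -> molecule 1


def pool_per_molecule(h_atom: torch.Tensor, batch: torch.Tensor,
                      readout: str = "add") -> torch.Tensor:
    """Sketch of "add"/"mean" pooling: [num_atoms, F] -> [num_molecules, F]."""
    num_mols = int(batch.max()) + 1
    out = torch.zeros(num_mols, h_atom.size(-1), dtype=h_atom.dtype)
    out = out.index_add(0, batch, h_atom)   # sum the atoms of each molecule
    if readout == "mean":
        counts = torch.bincount(batch, minlength=num_mols).clamp(min=1)
        out = out / counts.view(-1, 1).to(h_atom.dtype)
    return out


# Made-up per-atom scalar contributions, one row per atom.
h_atom = torch.tensor([[1.0], [2.0], [3.0], [4.0], [6.0]])
added = pool_per_molecule(h_atom, batch, "add")       # per-molecule sums 6.0 and 10.0
averaged = pool_per_molecule(h_atom, batch, "mean")   # per-molecule means 2.0 and 5.0
```

With a constructed model, the same z, pos, and batch tensors would be passed as `model(z, pos, batch)`. "add" makes the prediction scale with molecule size, which suits extensive properties, while "mean" does not.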