torch_geometric.nn.models.DimeNet
- class DimeNet(hidden_channels: int, out_channels: int, num_blocks: int, num_bilinear: int, num_spherical: int, num_radial: int, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: Union[str, Callable] = 'swish', output_initializer: str = 'zeros')
Bases: torch.nn.Module
The directional message passing neural network (DimeNet) from the “Directional Message Passing for Molecular Graphs” paper. DimeNet transforms messages based on the angle between them in a rotation-equivariant fashion.
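DimeNet's directional messages depend on the angle at atom j formed by a triplet of atoms k, j, i. The plain-Python sketch below (no PyTorch needed) computes that angle from coordinates and checks that rotating the whole molecule leaves it unchanged. The helpers `triplet_angle` and `rotate_z` are hypothetical illustrations, not PyTorch Geometric functions, and using `atan2` of the cross-product norm and the dot product is a common numerically stable choice assumed here, not necessarily the library's exact formula.

```python
import math
from typing import Tuple

Vec = Tuple[float, float, float]

def sub(a: Vec, b: Vec) -> Vec:
    """Component-wise difference a - b."""
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])

def dot(a: Vec, b: Vec) -> float:
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]

def cross(a: Vec, b: Vec) -> Vec:
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])

def triplet_angle(pos_k: Vec, pos_j: Vec, pos_i: Vec) -> float:
    """Angle (radians) at atom j between the bonds j-k and j-i (hypothetical helper)."""
    u = sub(pos_k, pos_j)
    v = sub(pos_i, pos_j)
    w = cross(u, v)
    # atan2(|u x v|, u . v) stays accurate near 0 and pi, unlike acos.
    return math.atan2(math.sqrt(dot(w, w)), dot(u, v))

def rotate_z(p: Vec, theta: float) -> Vec:
    """Rotate a point about the z-axis by theta radians."""
    c, s = math.cos(theta), math.sin(theta)
    return (c * p[0] - s * p[1], s * p[0] + c * p[1], p[2])

# Toy triplet: j at the origin, k on the x-axis, i on the y-axis.
k, j, i = (1.0, 0.0, 0.0), (0.0, 0.0, 0.0), (0.0, 1.0, 0.0)
angle = triplet_angle(k, j, i)

# Rotating every atom by the same amount does not change the angle.
rotated = triplet_angle(*[rotate_z(p, 0.7) for p in (k, j, i)])
print(angle, rotated)
```

Because the angle is built only from relative positions, any rigid rotation or translation of the molecule leaves it unchanged, which is why features computed from it do not depend on how the molecule is oriented in space.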
Note
For an example of using a pretrained DimeNet variant, see examples/qm9_pretrained_dimenet.py.
- Parameters:
hidden_channels (int) – Hidden embedding size.
out_channels (int) – Size of each output sample.
num_blocks (int) – Number of building blocks.
num_bilinear (int) – Size of the bilinear layer tensor.
num_spherical (int) – Number of spherical harmonics.
num_radial (int) – Number of radial basis functions.
cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 5.0)
max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)
envelope_exponent (int, optional) – Exponent controlling the shape of the smooth envelope that drives interactions to zero at the cutoff. (default: 5)
num_before_skip (int, optional) – Number of residual layers in the interaction blocks before the skip connection. (default: 1)
num_after_skip (int, optional) – Number of residual layers in the interaction blocks after the skip connection. (default: 2)
num_output_layers (int, optional) – Number of linear layers for the output blocks. (default: 3)
act (str or Callable, optional) – The activation function. (default: "swish")
output_initializer (str, optional) – The initialization method for the output layer ("zeros" or "glorot_orthogonal"). (default: "zeros")
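The cutoff, num_radial, and envelope_exponent parameters jointly shape how DimeNet encodes interatomic distances. The DimeNet paper expands a distance d into radial Bessel functions sqrt(2/c) * sin(n*pi*d/c) / d for n = 1..num_radial, multiplied by a polynomial envelope u(d/c) that reaches zero smoothly at the cutoff c. The plain-Python sketch below follows those formulas under one labeled assumption: the envelope degree is p = envelope_exponent + 1, so the default of 5 gives p = 6. The function names and the num_radial value of 6 are hypothetical choices for illustration; this is not PyTorch Geometric's implementation.

```python
import math
from typing import List

def envelope(x: float, exponent: int = 5) -> float:
    """Polynomial envelope u(x) for a scaled distance x = d / cutoff.

    Assumption: degree p = exponent + 1. With this polynomial, u(0) = 1,
    u(1) = 0, and the first two derivatives also vanish at x = 1.
    """
    if x >= 1.0:
        return 0.0  # no interaction at or beyond the cutoff
    p = exponent + 1
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    return 1.0 + a * x**p + b * x**(p + 1) + c * x**(p + 2)

def bessel_basis(dist: float, num_radial: int = 6, cutoff: float = 5.0,
                 exponent: int = 5) -> List[float]:
    """Enveloped radial Bessel features of one distance (hypothetical helper)."""
    if dist <= 0.0:
        raise ValueError("distances between distinct atoms must be positive")
    env = envelope(dist / cutoff, exponent)
    norm = math.sqrt(2.0 / cutoff)
    return [env * norm * math.sin(n * math.pi * dist / cutoff) / dist
            for n in range(1, num_radial + 1)]

features = bessel_basis(1.5)
print(len(features))  # one feature per radial basis function
```

A larger num_radial adds higher-frequency basis functions, while the envelope guarantees that every feature fades to zero as the distance approaches the cutoff, so atoms crossing the cutoff do not cause a discontinuity in the model's output.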
- forward(z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) → Tensor
Forward pass.
- Parameters:
z (torch.Tensor) – Atomic number of each atom with shape [num_atoms].
pos (torch.Tensor) – Coordinates of each atom with shape [num_atoms, 3].
batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. (default: None)
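forward expects all molecules of a mini-batch concatenated into one z and one pos, with batch recording which molecule each atom belongs to; in PyTorch Geometric this vector is normally produced by collating Data objects with torch_geometric.loader.DataLoader. Interactions are only considered between atoms closer than cutoff, with at most max_num_neighbors neighbors per atom. The plain-Python sketch below builds the three inputs as lists (before calling the model they would be converted with torch.tensor, z and batch as torch.long and pos as a float tensor of shape [num_atoms, 3]) and computes such a neighbor list by brute force without linking atoms from different molecules. The helpers collate and radius_neighbors and the toy coordinates are hypothetical. Keeping the nearest neighbors when the cap is reached is a simplifying choice of this sketch; the library's neighbor search may select neighbors in a different order.

```python
import math
from typing import List, Sequence, Tuple

Molecule = Tuple[List[int], List[Tuple[float, float, float]]]

def collate(molecules: Sequence[Molecule]):
    """Concatenate molecules into flat z, pos, and a per-atom batch index."""
    z, pos, batch = [], [], []
    for mol_idx, (atomic_numbers, coords) in enumerate(molecules):
        z.extend(atomic_numbers)
        pos.extend(coords)
        batch.extend([mol_idx] * len(atomic_numbers))
    return z, pos, batch

def radius_neighbors(pos, batch, cutoff=5.0, max_num_neighbors=32):
    """Directed edges (j, i) with |pos[j] - pos[i]| < cutoff within one molecule."""
    edges = []
    for i in range(len(pos)):
        candidates = []
        for j in range(len(pos)):
            if j == i or batch[j] != batch[i]:
                continue  # no self-loops, no edges between molecules
            d = math.dist(pos[i], pos[j])
            if d < cutoff:
                candidates.append((d, j))
        candidates.sort()  # closest atoms first, so the cap keeps them
        edges.extend((j, i) for _, j in candidates[:max_num_neighbors])
    return edges

# Two toy molecules in their own coordinate frames (hypothetical, in angstroms).
# The frames overlap on purpose: batch, not distance, separates molecules.
water = ([8, 1, 1], [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)])
hydrogen = ([1, 1], [(0.0, 0.0, 0.0), (0.74, 0.0, 0.0)])

z, pos, batch = collate([water, hydrogen])
edges = radius_neighbors(pos, batch, cutoff=5.0, max_num_neighbors=32)
print(z)           # [8, 1, 1, 1, 1]
print(batch)       # [0, 0, 0, 1, 1]
print(len(edges))  # 8: six directed edges in water, two in H2
```

Even though the hydrogen atom at the origin sits on top of the oxygen atom, no edge connects them, because the batch vector places them in different molecules.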