torch_geometric.nn.models.DimeNet

class DimeNet(hidden_channels: int, out_channels: int, num_blocks: int, num_bilinear: int, num_spherical: int, num_radial: int, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: Union[str, Callable] = 'swish', output_initializer: str = 'zeros')

Bases: Module

The directional message passing neural network (DimeNet) from the “Directional Message Passing for Molecular Graphs” paper. DimeNet transforms messages based on the angle between them in a rotation-equivariant fashion.
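The directional idea can be illustrated without tensors. This sketch assumes, following the paper's description, that messages live on directed edges j→i and that each one is updated from the incoming messages k→j with k ≠ i, using the angle formed at atom j. The helpers `triplets` and `angle_at_j` are hypothetical names for illustration, not part of the PyG API. Here the angle is measured between the bonds j–k and j–i; the exact vector orientation used inside PyG may differ.

```python
import math
from typing import Dict, List, Sequence, Tuple

Vec = Tuple[float, float, float]


def sub(a: Vec, b: Vec) -> Vec:
    """Component-wise difference a - b."""
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])


def dot(a: Vec, b: Vec) -> float:
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def cross(a: Vec, b: Vec) -> Vec:
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def triplets(edges: Sequence[Tuple[int, int]]) -> List[Tuple[int, int, int]]:
    """All (k, j, i) where k->j and j->i are edges and k != i."""
    incoming: Dict[int, List[int]] = {}
    for src, dst in edges:
        incoming.setdefault(dst, []).append(src)
    out = []
    for j, i in edges:                 # message j -> i being updated
        for k in incoming.get(j, []):  # messages k -> j that feed into it
            if k != i:                 # skip the message coming straight back
                out.append((k, j, i))
    return out


def angle_at_j(pos: Sequence[Vec], k: int, j: int, i: int) -> float:
    """Angle at atom j between the bonds j-k and j-i, in radians."""
    a = sub(pos[k], pos[j])
    b = sub(pos[i], pos[j])
    c = cross(a, b)
    # atan2(|a x b|, a . b) stays accurate near 0 and pi, unlike acos.
    return math.atan2(math.sqrt(dot(c, c)), dot(a, b))


# Three atoms forming a right angle at atom 1, with edges in both directions.
pos = [(1.0, 0.0, 0.0), (0.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
for k, j, i in triplets(edges):
    print((k, j, i), round(math.degrees(angle_at_j(pos, k, j, i)), 1))
```

Because only distances and angles enter the computation, rigidly rotating `pos` leaves these quantities unchanged.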

Note

For an example of using a pretrained DimeNet variant, see examples/qm9_pretrained_dimenet.py.

Parameters:
  • hidden_channels (int) – Hidden embedding size.

  • out_channels (int) – Size of each output sample.

  • num_blocks (int) – Number of building blocks.

  • num_bilinear (int) – Size of the bilinear layer tensor.

  • num_spherical (int) – Number of spherical harmonics.

  • num_radial (int) – Number of radial basis functions.

  • cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 5.0)

  • max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)

  • envelope_exponent (int, optional) – Exponent controlling the shape of the smooth cutoff envelope. (default: 5)

  • num_before_skip (int, optional) – Number of residual layers in the interaction blocks before the skip connection. (default: 1)

  • num_after_skip (int, optional) – Number of residual layers in the interaction blocks after the skip connection. (default: 2)

  • num_output_layers (int, optional) – Number of linear layers for the output blocks. (default: 3)

  • act (str or Callable, optional) – The activation function. (default: "swish")

  • output_initializer (str, optional) – The initialization method for the output layer ("zeros", "glorot_orthogonal"). (default: "zeros")
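The cutoff, envelope_exponent, num_radial, and num_spherical arguments parameterize the distance and angle features. The following is a minimal sketch based on our reading of the DimeNet paper, not PyG's code: it assumes a polynomial envelope with p = envelope_exponent + 1, an order-0 spherical Bessel radial basis with fixed frequencies nπ (the library may learn them and may drop normalization constants), and only the angular factor Y_l^0 of the 2D spherical basis (the full basis also multiplies by order-l spherical Bessel functions of the distance, omitted here). All function names are hypothetical.

```python
import math
from typing import List


def envelope(d: float, exponent: int = 5) -> float:
    """Smooth cutoff u(d) on the scaled distance d = r / cutoff.

    Assumes p = exponent + 1, so u(0) = 1 and u, u', u'' all vanish at d = 1.
    """
    if d >= 1.0:
        return 0.0
    p = exponent + 1
    return (1.0
            - (p + 1) * (p + 2) / 2 * d ** p
            + p * (p + 2) * d ** (p + 1)
            - p * (p + 1) / 2 * d ** (p + 2))


def radial_basis(r: float, num_radial: int, cutoff: float = 5.0,
                 exponent: int = 5) -> List[float]:
    """Order-0 spherical Bessel basis sqrt(2/c) sin(n pi r/c) / r, times u(r/c).

    Requires r > 0. Frequencies n * pi are fixed in this sketch.
    """
    u = envelope(r / cutoff, exponent)
    return [math.sqrt(2.0 / cutoff) * math.sin(n * math.pi * r / cutoff) / r * u
            for n in range(1, num_radial + 1)]


def legendre(num: int, x: float) -> List[float]:
    """P_0(x), ..., P_{num-1}(x) via Bonnet's recurrence."""
    values = [1.0, x][:num]
    for l in range(1, num - 1):
        values.append(((2 * l + 1) * x * values[l] - l * values[l - 1]) / (l + 1))
    return values


def angular_basis(angle: float, num_spherical: int) -> List[float]:
    """Y_l^0(angle) = sqrt((2l+1)/(4 pi)) P_l(cos angle), for l < num_spherical."""
    p = legendre(num_spherical, math.cos(angle))
    return [math.sqrt((2 * l + 1) / (4 * math.pi)) * p[l]
            for l in range(num_spherical)]


print(envelope(0.0), envelope(0.5), envelope(1.0))
print(len(radial_basis(1.5, num_radial=6)), len(angular_basis(math.pi / 3, num_spherical=7)))
```

The envelope makes every radial feature decay smoothly to zero at the cutoff, so atoms crossing the cutoff distance do not cause jumps in the features; a larger envelope_exponent keeps u close to 1 for longer before it falls off.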

forward(z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) → Tensor

Forward pass.

Parameters:
  • z (torch.Tensor) – Atomic number of each atom with shape [num_atoms].

  • pos (torch.Tensor) – Coordinates of each atom with shape [num_atoms, 3].

  • batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. (default: None)
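To our understanding of the PyG source, forward builds the neighbor graph itself from pos (via radius_graph with cutoff and max_num_neighbors, restricted to atoms sharing the same batch index) and sums per-atom output contributions within each molecule. The plain-Python sketch below mimics both steps; `radius_edges` and `sum_readout` are hypothetical helpers, the output is shown for out_channels = 1, and keeping the *nearest* neighbors when the cap is hit is an assumption of this sketch (the library is not guaranteed to choose them that way).

```python
import math
from typing import List, Optional, Sequence, Tuple

Vec = Tuple[float, float, float]


def radius_edges(pos: Sequence[Vec], batch: Optional[Sequence[int]] = None,
                 cutoff: float = 5.0,
                 max_num_neighbors: int = 32) -> List[Tuple[int, int]]:
    """Directed edges (j, i) between atoms of the same molecule within `cutoff`.

    Each target atom i keeps at most `max_num_neighbors` sources; this sketch
    keeps the nearest ones.
    """
    n = len(pos)
    batch = list(batch) if batch is not None else [0] * n
    edges = []
    for i in range(n):
        candidates = []
        for j in range(n):
            if j == i or batch[j] != batch[i]:
                continue  # no self-loops, no edges across molecules
            d = math.dist(pos[i], pos[j])
            if d < cutoff:
                candidates.append((d, j))
        candidates.sort()
        edges.extend((j, i) for _, j in candidates[:max_num_neighbors])
    return edges


def sum_readout(atom_out: Sequence[float],
                batch: Optional[Sequence[int]] = None) -> List[float]:
    """Sum per-atom outputs into one value per molecule."""
    if batch is None:
        return [sum(atom_out)]  # all atoms belong to a single molecule
    totals = [0.0] * (max(batch) + 1)
    for value, b in zip(atom_out, batch):
        totals[b] += value
    return totals


# Atoms 0 and 2 share coordinates but belong to different molecules.
pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 0.0, 0.0), (2.0, 0.0, 0.0)]
batch = [0, 0, 1, 1]
print(radius_edges(pos, batch, cutoff=1.5))       # [(1, 0), (0, 1)]
print(sum_readout([1.0, 2.0, 3.0, 4.0], batch))   # [3.0, 7.0]
```

The batch vector therefore does two jobs: it prevents edges between atoms of different molecules, and it decides which atom contributions are summed into each molecule's prediction.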

reset_parameters()

Resets all learnable parameters of the module.

classmethod from_qm9_pretrained(root: str, dataset: Dataset, target: int) → Tuple[DimeNet, Dataset, Dataset, Dataset]

Returns a DimeNet model pre-trained on the QM9 dataset for the regression target with index target, together with the train, validation, and test splits of dataset.