torch_geometric.nn.models.DimeNetPlusPlus

class DimeNetPlusPlus(hidden_channels: int, out_channels: int, num_blocks: int, int_emb_size: int, basis_emb_size: int, out_emb_channels: int, num_spherical: int, num_radial: int, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: Union[str, Callable] = 'swish', output_initializer: str = 'zeros')

Bases: DimeNet

The DimeNet++ from the “Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules” paper.

DimeNetPlusPlus is an upgrade to the DimeNet model that is 8x faster and 10% more accurate than DimeNet.

Parameters:
  • hidden_channels (int) – Hidden embedding size.

  • out_channels (int) – Size of each output sample.

  • num_blocks (int) – Number of building blocks.

  • int_emb_size (int) – Size of embedding in the interaction block.

  • basis_emb_size (int) – Size of basis embedding in the interaction block.

  • out_emb_channels (int) – Size of embedding in the output block.

  • num_spherical (int) – Number of spherical harmonics.

  • num_radial (int) – Number of radial basis functions.

  • cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 5.0)

  • max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)

  • envelope_exponent (int, optional) – Shape of the smooth cutoff. (default: 5)

  • num_before_skip (int, optional) – Number of residual layers in the interaction blocks before the skip connection. (default: 1)

  • num_after_skip (int, optional) – Number of residual layers in the interaction blocks after the skip connection. (default: 2)

  • num_output_layers (int, optional) – Number of linear layers for the output blocks. (default: 3)

  • act (str or Callable, optional) – The activation function. (default: "swish")

  • output_initializer (str, optional) – The initialization method for the output layer ("zeros", "glorot_orthogonal"). (default: "zeros")
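The cutoff, num_radial, and envelope_exponent parameters together control how each interatomic distance is embedded. The following is a minimal pure-Python sketch, assuming the radial Bessel basis and polynomial envelope from the original DimeNet paper, with p = envelope_exponent + 1 (so the default of 5 gives p = 6). The helper names envelope and bessel_basis are hypothetical; the normalization constant is omitted, and the frequencies are fixed at nπ here even though the library may treat them as trainable.

```python
import math
from typing import List


def envelope(x: float, exponent: int = 5) -> float:
    """Smooth cutoff u(x) / x for a scaled distance x = d / cutoff.

    Assumption: p = exponent + 1, so exponent=5 gives p=6.
    """
    if x <= 0.0:
        raise ValueError("scaled distance must be positive")
    if x >= 1.0:
        return 0.0  # at or beyond the cutoff the contribution vanishes
    p = exponent + 1
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    # u(x) = 1 + a x^p + b x^(p+1) + c x^(p+2); u and its first two
    # derivatives are zero at x = 1, giving a smooth cutoff.
    u = 1.0 + a * x**p + b * x**(p + 1) + c * x**(p + 2)
    return u / x


def bessel_basis(dist: float, num_radial: int, cutoff: float = 5.0,
                 exponent: int = 5) -> List[float]:
    """Return num_radial values sin(n * pi * d / cutoff) scaled by the envelope."""
    x = dist / cutoff
    env = envelope(x, exponent)
    return [env * math.sin(n * math.pi * x) for n in range(1, num_radial + 1)]


# One distance of 1.5 (same unit as cutoff) embedded into 6 radial features.
features = bessel_basis(1.5, num_radial=6)
```

With these assumptions, every feature is exactly zero for distances at or beyond cutoff, which is why atoms outside the cutoff do not interact.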

forward(z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) → Tensor

Forward pass.

Parameters:
  • z (torch.Tensor) – Atomic number of each atom with shape [num_atoms].

  • pos (torch.Tensor) – Coordinates of each atom with shape [num_atoms, 3].

  • batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. (default: None)
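The forward pass takes only z, pos, and batch; the model derives the interaction graph from pos itself, connecting atoms of the same molecule that lie within cutoff and keeping at most max_num_neighbors neighbors per atom. The pure-Python sketch below illustrates that layout under stated assumptions: collate and radius_graph are hypothetical helpers, the coordinates are toy values, and this version keeps the nearest neighbors, whereas the library's neighbor selection and boundary handling may differ. In practice the inputs are torch.Tensor objects, and PyG's DataLoader builds batch automatically.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]


def collate(molecules: Sequence[Tuple[List[int], List[Point]]]):
    """Concatenate molecules into flat z, pos, and batch lists."""
    z, pos, batch = [], [], []
    for k, (zk, posk) in enumerate(molecules):
        z.extend(zk)
        pos.extend(posk)
        batch.extend([k] * len(zk))  # every atom of molecule k gets index k
    return z, pos, batch


def radius_graph(pos: Sequence[Point], batch: Sequence[int], cutoff: float = 5.0,
                 max_num_neighbors: int = 32) -> List[Tuple[int, int]]:
    """Return directed edges (j, i), meaning atom j is a neighbor of atom i."""
    edges = []
    for i, (pi, bi) in enumerate(zip(pos, batch)):
        candidates = []
        for j, (pj, bj) in enumerate(zip(pos, batch)):
            if i == j or bi != bj:
                continue  # no self-loops and no edges across molecules
            d = math.dist(pi, pj)
            if d < cutoff:
                candidates.append((d, j))
        candidates.sort()  # nearest first, then truncate to the cap
        edges.extend((j, i) for _, j in candidates[:max_num_neighbors])
    return edges


# Two toy molecules whose coordinates overlap; batch keeps them separate.
molecules = [
    ([1, 1], [(0.0, 0.0, 0.0), (0.74, 0.0, 0.0)]),
    ([8, 1, 1], [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)]),
]
z, pos, batch = collate(molecules)
edges = radius_graph(pos, batch, cutoff=5.0)
```

Atoms 0 and 2 share the same coordinates but belong to different molecules, so no edge connects them.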

reset_parameters()

Resets all learnable parameters of the module.

classmethod from_qm9_pretrained(root: str, dataset: Dataset, target: int) → Tuple[DimeNetPlusPlus, Dataset, Dataset, Dataset]

Returns a DimeNetPlusPlus model pre-trained on the QM9 dataset for the specified target index target, together with the train, validation, and test splits of dataset.