torch_geometric.nn.models.DimeNetPlusPlus
- class DimeNetPlusPlus(hidden_channels: int, out_channels: int, num_blocks: int, int_emb_size: int, basis_emb_size: int, out_emb_channels: int, num_spherical: int, num_radial: int, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: Union[str, Callable] = 'swish', output_initializer: str = 'zeros')
Bases: DimeNet
The DimeNet++ from the “Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules” paper.
DimeNetPlusPlus is an upgrade to the DimeNet model that is 8x faster and 10% more accurate than DimeNet.
- Parameters:
hidden_channels (int) – Hidden embedding size.
out_channels (int) – Size of each output sample.
num_blocks (int) – Number of building blocks.
int_emb_size (int) – Size of embedding in the interaction block.
basis_emb_size (int) – Size of basis embedding in the interaction block.
out_emb_channels (int) – Size of embedding in the output block.
num_spherical (int) – Number of spherical harmonics.
num_radial (int) – Number of radial basis functions.
cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 5.0)
max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)
envelope_exponent (int, optional) – Shape of the smooth cutoff. (default: 5)
num_before_skip (int, optional) – Number of residual layers in the interaction blocks before the skip connection. (default: 1)
num_after_skip (int, optional) – Number of residual layers in the interaction blocks after the skip connection. (default: 2)
num_output_layers (int, optional) – Number of linear layers for the output blocks. (default: 3)
act (str or Callable, optional) – The activation function. (default: "swish")
output_initializer (str, optional) – The initialization method for the output layer ("zeros" or "glorot_orthogonal"). (default: "zeros")
- forward(z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) → Tensor
Forward pass.
- Parameters:
z (torch.Tensor) – Atomic number of each atom with shape [num_atoms].
pos (torch.Tensor) – Coordinates of each atom with shape [num_atoms, 3].
batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. (default: None)
- reset_parameters()
Resets all learnable parameters of the module.