class ViSNet(lmax: int = 1, vecnorm_type: Optional[str] = None, trainable_vecnorm: bool = False, num_heads: int = 8, num_layers: int = 6, hidden_channels: int = 128, num_rbf: int = 32, trainable_rbf: bool = False, max_z: int = 100, cutoff: float = 5.0, max_num_neighbors: int = 32, vertex: bool = False, atomref: Optional[Tensor] = None, reduce_op: str = 'sum', mean: float = 0.0, std: float = 1.0, derivative: bool = False)

Bases: Module

A module that implements the equivariant vector-scalar interactive graph neural network (ViSNet) from the “Enhancing Geometric Representations for Molecules with Equivariant Vector-Scalar Interactive Message Passing” paper.

Parameters:

  • lmax (int, optional) – The maximum degree of the spherical harmonics. (default: 1)

  • vecnorm_type (str, optional) – The type of normalization to apply to the vectors. (default: None)

  • trainable_vecnorm (bool, optional) – Whether the normalization weights are trainable. (default: False)

  • num_heads (int, optional) – The number of attention heads. (default: 8)

  • num_layers (int, optional) – The number of layers in the network. (default: 6)

  • hidden_channels (int, optional) – The number of hidden channels in the node embeddings. (default: 128)

  • num_rbf (int, optional) – The number of radial basis functions. (default: 32)

  • trainable_rbf (bool, optional) – Whether the radial basis function parameters are trainable. (default: False)

  • max_z (int, optional) – The maximum atomic number. (default: 100)

  • cutoff (float, optional) – The cutoff distance within which atoms are treated as neighbors. (default: 5.0)

  • max_num_neighbors (int, optional) – The maximum number of neighbors considered for each atom. (default: 32)

  • vertex (bool, optional) – Whether to use vertex geometric features. (default: False)

  • atomref (torch.Tensor, optional) – A tensor of atom reference values, or None if not provided. (default: None)

  • reduce_op (str, optional) – The reduction used to aggregate per-atom outputs into one output per molecule ("sum" or "mean"). (default: "sum")

  • mean (float, optional) – The mean of the output distribution. (default: 0.0)

  • std (float, optional) – The standard deviation of the output distribution. (default: 1.0)

  • derivative (bool, optional) – Whether to compute the derivative of the output with respect to the positions. (default: False)
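The parameters num_rbf, cutoff and trainable_rbf control how interatomic distances are turned into radial features. The stdlib-only sketch below illustrates the general idea with a generic Gaussian basis multiplied by a cosine cutoff envelope. The function names, the evenly spaced centers and the shared width are illustrative assumptions; ViSNet's actual radial basis may differ. With trainable_rbf=True, parameters comparable to these centers and widths would be learned instead of fixed.

```python
import math
from typing import List


def cosine_cutoff(d: float, cutoff: float) -> float:
    """Envelope that decays smoothly from 1 at d = 0 to 0 at d = cutoff."""
    if d >= cutoff:
        return 0.0
    return 0.5 * (math.cos(math.pi * d / cutoff) + 1.0)


def gaussian_rbf(d: float, num_rbf: int = 32, cutoff: float = 5.0) -> List[float]:
    """Expand one distance into num_rbf Gaussian features (hypothetical basis)."""
    # Evenly spaced centers over [0, cutoff]; a single RBF sits at 0.
    step = cutoff / (num_rbf - 1) if num_rbf > 1 else cutoff
    centers = [k * step for k in range(num_rbf)]
    # One shared width equal to the center spacing (an illustrative choice).
    width = step
    envelope = cosine_cutoff(d, cutoff)
    return [envelope * math.exp(-(((d - c) / width) ** 2)) for c in centers]


features = gaussian_rbf(1.5, num_rbf=32, cutoff=5.0)
print(len(features))  # 32
print(gaussian_rbf(6.0)[:3])  # [0.0, 0.0, 0.0] beyond the cutoff
```

The envelope drives every feature to zero at the cutoff, so atoms that drift across the cutoff boundary do not cause a discontinuous jump in the features.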

forward(z: Tensor, pos: Tensor, batch: Tensor) → Tuple[Tensor, Optional[Tensor]]

Computes the energies or other properties, and optionally the forces, for a batch of molecules.

  • z (torch.Tensor) – The atomic numbers.

  • pos (torch.Tensor) – The coordinates of the atoms.

  • batch (torch.Tensor) – A batch vector, which assigns each node to a specific example.
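The batch vector and the cutoff and max_num_neighbors parameters together decide which atoms exchange messages. The brute-force sketch below assumes that only atoms of the same molecule within cutoff are neighbors and that each atom keeps at most max_num_neighbors of them. radius_neighbors is a hypothetical helper, and keeping the nearest neighbors when the limit is exceeded is an assumption that may not match how the library truncates neighbor lists.

```python
import math
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def radius_neighbors(
    pos: Sequence[Vec3],
    batch: Sequence[int],
    cutoff: float = 5.0,
    max_num_neighbors: int = 32,
) -> List[Tuple[int, int]]:
    """Return directed edges (source, target) between atoms of the same
    molecule within `cutoff`, at most `max_num_neighbors` per target atom."""
    edges: List[Tuple[int, int]] = []
    for i, p_i in enumerate(pos):
        candidates = []
        for j, p_j in enumerate(pos):
            # Skip self-loops and atoms that belong to a different molecule.
            if i == j or batch[i] != batch[j]:
                continue
            d = math.dist(p_i, p_j)
            if d <= cutoff:
                candidates.append((d, j))
        # Assumed truncation rule: keep the nearest neighbors first.
        candidates.sort()
        edges.extend((j, i) for _, j in candidates[:max_num_neighbors])
    return edges


pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
batch = [0, 0, 1, 1]  # atoms 0-1 form molecule 0, atoms 2-3 form molecule 1
print(radius_neighbors(pos, batch))  # [(1, 0), (0, 1)]
```

Atoms 0 and 2 share coordinates but are not connected because the batch vector places them in different molecules, and atoms 2 and 3 are farther apart than the cutoff.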


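Returns:

  • y (torch.Tensor) – The energies or properties for each molecule.

  • dy (torch.Tensor, optional) – The negative derivative of the energies with respect to the atomic positions (i.e., the forces), or None if derivative is False.

Return type:

Tuple[Tensor, Optional[Tensor]]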
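To make the two return values concrete, the sketch below replaces the network with a hypothetical toy energy: each same-molecule atom pair contributes 0.5 * (r - 1)^2, split evenly between the two atoms. Per-atom values are aggregated into one value per molecule through the batch vector according to reduce_op, and dy is approximated as the negative gradient of the summed energy with respect to the positions. A PyTorch model would typically obtain dy with automatic differentiation; central finite differences are used here only so the sketch runs with the standard library. All function names are illustrative assumptions.

```python
import math
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def toy_atom_energies(pos: Sequence[Sequence[float]], batch: Sequence[int]) -> List[float]:
    """Hypothetical stand-in for the network: each same-molecule pair (i, j)
    contributes 0.5 * (r_ij - 1)^2, split evenly between atoms i and j."""
    energies = [0.0] * len(pos)
    for i in range(len(pos)):
        for j in range(len(pos)):
            if i != j and batch[i] == batch[j]:
                r = math.dist(pos[i], pos[j])
                energies[i] += 0.25 * (r - 1.0) ** 2
    return energies


def readout(values: Sequence[float], batch: Sequence[int], reduce_op: str = "sum") -> List[float]:
    """Aggregate per-atom values into one value per molecule via the batch vector."""
    num_molecules = max(batch) + 1
    totals = [0.0] * num_molecules
    counts = [0] * num_molecules
    for value, b in zip(values, batch):
        totals[b] += value
        counts[b] += 1
    if reduce_op == "mean":
        return [t / c for t, c in zip(totals, counts)]
    if reduce_op == "sum":
        return totals
    raise ValueError(f"unsupported reduce_op: {reduce_op!r}")


def finite_difference_forces(pos: Sequence[Vec3], batch: Sequence[int], h: float = 1e-5) -> List[Vec3]:
    """Approximate dy = -d(sum of y)/d(pos) with central differences."""

    def total_energy(p: Sequence[Sequence[float]]) -> float:
        return sum(readout(toy_atom_energies(p, batch), batch, "sum"))

    forces: List[Vec3] = []
    for i in range(len(pos)):
        f_i = []
        for k in range(3):
            plus = [list(p) for p in pos]
            minus = [list(p) for p in pos]
            plus[i][k] += h
            minus[i][k] -= h
            f_i.append(-(total_energy(plus) - total_energy(minus)) / (2 * h))
        forces.append((f_i[0], f_i[1], f_i[2]))
    return forces


pos = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0)]
batch = [0, 0]
print(readout(toy_atom_energies(pos, batch), batch))  # [0.125]
print(finite_difference_forces(pos, batch))
```

For two atoms 1.5 apart, the single molecule's energy is 0.125 and the forces come out close to (0.5, 0, 0) and (-0.5, 0, 0): they pull the stretched pair back toward r = 1, which is the sign convention implied by returning the negative derivative.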


reset_parameters()

Resets the parameters of the module.