torch_geometric.nn.models.ViSNet
- class ViSNet(lmax: int = 1, vecnorm_type: Optional[str] = None, trainable_vecnorm: bool = False, num_heads: int = 8, num_layers: int = 6, hidden_channels: int = 128, num_rbf: int = 32, trainable_rbf: bool = False, max_z: int = 100, cutoff: float = 5.0, max_num_neighbors: int = 32, vertex: bool = False, atomref: Optional[Tensor] = None, reduce_op: str = 'sum', mean: float = 0.0, std: float = 1.0, derivative: bool = False)
Bases: torch.nn.Module
A PyTorch module that implements the equivariant vector-scalar interactive graph neural network (ViSNet) from the “Enhancing Geometric Representations for Molecules with Equivariant Vector-Scalar Interactive Message Passing” paper.
- Parameters:
lmax (int, optional) – The maximum degree of the spherical harmonics. (default: 1)
vecnorm_type (str, optional) – The type of normalization to apply to the vectors. (default: None)
trainable_vecnorm (bool, optional) – Whether the normalization weights are trainable. (default: False)
num_heads (int, optional) – The number of attention heads. (default: 8)
num_layers (int, optional) – The number of layers in the network. (default: 6)
hidden_channels (int, optional) – The number of hidden channels in the node embeddings. (default: 128)
num_rbf (int, optional) – The number of radial basis functions. (default: 32)
trainable_rbf (bool, optional) – Whether the radial basis function parameters are trainable. (default: False)
max_z (int, optional) – The maximum atomic number. (default: 100)
cutoff (float, optional) – The cutoff distance for interatomic interactions. (default: 5.0)
max_num_neighbors (int, optional) – The maximum number of neighbors considered for each atom. (default: 32)
vertex (bool, optional) – Whether to use vertex geometric features. (default: False)
atomref (torch.Tensor, optional) – A tensor of atom reference values, or None if not provided. (default: None)
reduce_op (str, optional) – The type of reduction operation to apply ("sum" or "mean"). (default: "sum")
mean (float, optional) – The mean of the output distribution. (default: 0.0)
std (float, optional) – The standard deviation of the output distribution. (default: 1.0)
derivative (bool, optional) – Whether to compute the derivative of the output with respect to the positions. (default: False)
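Because `forward` takes no `edge_index`, the atoms that interact are determined from `pos` and `batch` using `cutoff` and `max_num_neighbors`. The stdlib-only sketch below illustrates that kind of radius-based neighbor search on toy data. It is an O(N²) illustration, not the library's implementation: the helper name `radius_neighbors` is hypothetical, self-pairs are excluded for simplicity, and keeping the *closest* candidates when an atom has more than `max_num_neighbors` is an assumption that the library's radius search does not necessarily share.

```python
import math
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def radius_neighbors(
    pos: Sequence[Vec3],
    batch: Sequence[int],
    cutoff: float = 5.0,
    max_num_neighbors: int = 32,
) -> List[Tuple[int, int, float]]:
    """Return (source, target, distance) edges between distinct atoms of the
    same molecule that lie within `cutoff` of each other.

    Hypothetical helper: for each target atom, at most `max_num_neighbors`
    of the closest candidates are kept.
    """
    edges: List[Tuple[int, int, float]] = []
    for i, p_i in enumerate(pos):
        candidates: List[Tuple[float, int]] = []
        for j, p_j in enumerate(pos):
            # Skip the atom itself and atoms that belong to other molecules.
            if i == j or batch[i] != batch[j]:
                continue
            d = math.dist(p_i, p_j)
            if d <= cutoff:
                candidates.append((d, j))
        # Sort by distance so that truncation keeps the nearest neighbors.
        candidates.sort()
        for d, j in candidates[:max_num_neighbors]:
            edges.append((j, i, d))
    return edges


# Toy batch: a three-atom molecule (index 0) and an isolated atom (index 1).
pos = [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0), (10.0, 0.0, 0.0)]
batch = [0, 0, 0, 1]
for src, dst, d in radius_neighbors(pos, batch):
    print(f"{src} -> {dst}: {d:.3f}")
```

In this toy batch the isolated atom receives no edges, because it shares no molecule index with any other atom; lowering `max_num_neighbors` to 1 leaves each atom of the first molecule with only its nearest neighbor.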
- forward(z: Tensor, pos: Tensor, batch: Tensor) → Tuple[Tensor, Optional[Tensor]]
Computes the energies or other properties for a batch of molecules and, if derivative=True, the forces.
- Parameters:
z (torch.Tensor) – The atomic numbers.
pos (torch.Tensor) – The coordinates of the atoms.
batch (torch.Tensor) – A batch vector, which assigns each node to a specific example.
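The `batch` vector maps each atom to the index of the molecule it belongs to, which is what allows atom-level contributions to be aggregated into one output per molecule. Assuming, as the parameter name suggests, that the constructor's `reduce_op` selects how that aggregation is done, the stdlib-only sketch below shows a `"sum"` or `"mean"` reduction over a batch vector, similar in spirit to a scatter reduction. The function name `reduce_per_molecule` is hypothetical, and molecule indices are assumed to run contiguously from 0; this is not PyG code.

```python
from typing import List, Sequence


def reduce_per_molecule(
    atom_values: Sequence[float],
    batch: Sequence[int],
    reduce_op: str = "sum",
) -> List[float]:
    """Aggregate per-atom values into one value per molecule.

    `batch[i]` is the index of the molecule that atom `i` belongs to;
    indices are assumed to be 0..num_molecules-1.
    """
    if reduce_op not in ("sum", "mean"):
        raise ValueError(f"unsupported reduce_op: {reduce_op!r}")
    num_molecules = max(batch) + 1
    totals = [0.0] * num_molecules
    counts = [0] * num_molecules
    # Accumulate each atom's value into the slot of its molecule.
    for value, mol in zip(atom_values, batch):
        totals[mol] += value
        counts[mol] += 1
    if reduce_op == "sum":
        return totals
    # Guard against molecule indices that have no atoms.
    return [t / c if c else 0.0 for t, c in zip(totals, counts)]


# Five atoms: the first three form molecule 0, the last two molecule 1.
values = [1.0, 2.0, 3.0, 10.0, 20.0]
batch = [0, 0, 0, 1, 1]
print(reduce_per_molecule(values, batch, "sum"))   # [6.0, 30.0]
print(reduce_per_molecule(values, batch, "mean"))  # [2.0, 15.0]
```

A sum makes the molecular output grow with the number of atoms (appropriate for extensive quantities such as total energy), while a mean yields a size-independent value.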
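- Returns:
y (torch.Tensor) – The energies or properties for each molecule.
dy (torch.Tensor, optional) – The negative derivative of the energies with respect to the atomic positions (i.e., the forces) if derivative=True, otherwise None.
- Return type:
Tuple[torch.Tensor, Optional[torch.Tensor]]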
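With `derivative=True`, the second return value is the negative gradient of the predicted energy with respect to `pos`, i.e. one force vector per atom. The network obtains this by differentiating through the model; the stdlib-only sketch below instead uses a hypothetical toy energy (a single harmonic bond with made-up constants `k` and `r0`) to show what that quantity is, and how a central finite-difference estimate can be used to sanity-check an analytic force. All function names here are illustrative and not part of PyG.

```python
import math
from typing import Callable, List, Sequence

Point = Sequence[float]


def bond_energy(pos: Sequence[Point], k: float = 1.0, r0: float = 1.0) -> float:
    """Toy energy of one harmonic bond between atoms 0 and 1."""
    r = math.dist(pos[0], pos[1])
    return 0.5 * k * (r - r0) ** 2


def analytic_forces(pos: Sequence[Point], k: float = 1.0, r0: float = 1.0) -> List[List[float]]:
    """Forces F_i = -dE/dpos_i for the toy bond energy."""
    r = math.dist(pos[0], pos[1])
    # dE/dr = k (r - r0) and dr/dpos_0 = (pos_0 - pos_1) / r.
    coeff = k * (r - r0) / r
    f0 = [-coeff * (a - b) for a, b in zip(pos[0], pos[1])]
    # Newton's third law: the second atom feels the opposite force.
    f1 = [-c for c in f0]
    return [f0, f1]


def finite_difference_forces(
    pos: Sequence[Point],
    energy_fn: Callable[[Sequence[Point]], float],
    h: float = 1e-5,
) -> List[List[float]]:
    """Central-difference estimate of -dE/dpos for every atom and axis."""
    forces = []
    for i in range(len(pos)):
        row = []
        for axis in range(3):
            plus = [list(p) for p in pos]
            minus = [list(p) for p in pos]
            plus[i][axis] += h
            minus[i][axis] -= h
            d_energy = energy_fn(plus) - energy_fn(minus)
            row.append(-d_energy / (2 * h))
        forces.append(row)
    return forces


# A stretched bond (r = 1.5 > r0 = 1.0) pulls the two atoms together.
pos = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0)]
print(analytic_forces(pos))
print(finite_difference_forces(pos, bond_energy))
```

Both calls should agree to within finite-difference error, giving a force of about +0.5 along x on atom 0 and -0.5 on atom 1. The same finite-difference check can be applied to any differentiable energy model whose forces need validating.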