torch_geometric.nn.conv.FusedGATConv
- class FusedGATConv(*args, **kwargs)
Bases: GATConv
The fused graph attention operator from the “Understanding GNN Computational Graph: A Coordinated Computation, IO, and Memory Perspective” paper.
FusedGATConv is an optimized version of GATConv based on the dgNN package. It fuses the message passing computation for faster execution and a lower memory footprint.
Note
This implementation is based on the dgNN package. See the dgNN repository for installation instructions.
- forward(x: Tensor, csr: Tuple[Tensor, Tensor], csc: Tuple[Tensor, Tensor], perm: Tensor) → Tensor
Runs the forward pass of the module.
- Parameters:
x (torch.Tensor) – The node features.
csr ((torch.Tensor, torch.Tensor)) – The CSR representation of the graph, given as a tuple (rowptr, col).
csc ((torch.Tensor, torch.Tensor)) – The CSC representation of the graph, given as a tuple (row, colptr).
perm (torch.Tensor) – Permutation tensor that maps the CSR representation to the CSC representation.
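The CSC form groups edges by their destination node, which is a natural layout for aggregating incoming messages. To make "fusing the message passing computation" concrete, here is a conceptual pure-Python sketch of single-head GAT attention in which the attention scores, their softmax, and the weighted sum are computed together, one destination node at a time. Only deg(i) scores exist at once, whereas an unfused implementation typically materializes per-edge score and message tensors of size E. This is an illustration of the idea only, not dgNN's CUDA kernel. The function name fused_gat_aggregate, the list-based inputs (with h assumed to be already multiplied by the weight matrix), the LeakyReLU slope of 0.2, and the omission of self-loops, bias, dropout, and multiple heads are all assumptions made for brevity.

```python
import math
from typing import List


def leaky_relu(v: float, slope: float = 0.2) -> float:
    return v if v > 0 else slope * v


def fused_gat_aggregate(
    h: List[List[float]],
    a_src: List[float],
    a_dst: List[float],
    colptr: List[int],
    row: List[int],
) -> List[List[float]]:
    """Hypothetical single-head GAT aggregation over a CSC graph, one target node at a time."""
    dim = len(h[0])
    # Per-node halves of the attention logit: O(N) memory, computed once.
    s_src = [sum(a * v for a, v in zip(a_src, hj)) for hj in h]
    s_dst = [sum(a * v for a, v in zip(a_dst, hi)) for hi in h]

    out = []
    for i in range(len(colptr) - 1):
        nbrs = row[colptr[i]:colptr[i + 1]]  # sources j of edges j -> i
        if not nbrs:
            out.append([0.0] * dim)  # no incoming edges; self-loops are not added here
            continue
        # Score, softmax, and weighted sum are computed together for node i,
        # so only deg(i) scores are alive at any time.
        scores = [leaky_relu(s_src[j] + s_dst[i]) for j in nbrs]
        m = max(scores)  # subtract the max for numerical stability
        w = [math.exp(e - m) for e in scores]
        z = sum(w)
        acc = [0.0] * dim
        for wk, j in zip(w, nbrs):
            acc = [x + wk * v for x, v in zip(acc, h[j])]
        out.append([x / z for x in acc])
    return out
```

With all attention weights set to zero, every score is equal, so each node receives the plain mean of its incoming neighbors' features. For the edges 0→2, 1→2, and 2→0 (colptr = [0, 1, 1, 3], row = [2, 0, 1]) and features [[1.0], [2.0], [3.0]], the result is [[3.0], [0.0], [1.5]].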
Note
Use the to_graph_format() method to obtain the (csr, csc, perm) graph format from an existing edge_index representation.
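The relationship between the three pieces of the (csr, csc, perm) triple can be illustrated without dgNN. Below is a hypothetical pure-Python helper, to_csr_csc, that builds them from an edge_index given as a pair of (source, target) lists. It is not PyG's to_graph_format(), which operates on tensors. The sketch assumes that perm[k] is the position, in CSR order, of the k-th edge in CSC order, so per-edge values stored in CSR order are reordered into CSC order by indexing with perm.

```python
from typing import List, Tuple


def to_csr_csc(
    edge_index: Tuple[List[int], List[int]], num_nodes: int
) -> Tuple[Tuple[List[int], List[int]], Tuple[List[int], List[int]], List[int]]:
    """Hypothetical helper: edge_index -> ((rowptr, col), (row, colptr), perm)."""
    src, dst = edge_index
    num_edges = len(src)

    # CSR: order edges by (source, target) and count entries per source node.
    csr_order = sorted(range(num_edges), key=lambda e: (src[e], dst[e]))
    csr_row = [src[e] for e in csr_order]  # source node of each CSR entry
    col = [dst[e] for e in csr_order]
    rowptr = [0] * (num_nodes + 1)
    for r in csr_row:
        rowptr[r + 1] += 1
    for i in range(num_nodes):
        rowptr[i + 1] += rowptr[i]  # prefix sum turns counts into offsets

    # CSC: order the CSR positions by (target, source); perm[k] is the CSR
    # position of the k-th edge in CSC order.
    perm = sorted(range(num_edges), key=lambda k: (col[k], csr_row[k]))
    row = [csr_row[k] for k in perm]
    colptr = [0] * (num_nodes + 1)
    for c in col:
        colptr[c + 1] += 1
    for i in range(num_nodes):
        colptr[i + 1] += colptr[i]

    return (rowptr, col), (row, colptr), perm
```

For the three edges 0→2, 1→2, and 2→0 on three nodes, the helper gives rowptr = [0, 1, 2, 3], col = [2, 2, 0], row = [2, 0, 1], colptr = [0, 1, 1, 3], and perm = [2, 0, 1]. For example, values [10, 20, 30] stored in CSR order become [30, 10, 20] in CSC order.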
- reset_parameters()
Resets all learnable parameters of the module.