torch_geometric.nn.conv.FusedGATConv

class FusedGATConv(*args, **kwargs)

Bases: GATConv

The fused graph attention operator from the “Understanding GNN Computational Graph: A Coordinated Computation, IO, and Memory Perspective” paper.

FusedGATConv is an optimized version of GATConv, built on the dgNN package, that fuses the message-passing computation for faster execution and a lower memory footprint.
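To make "fusion" concrete, here is a minimal pure-Python sketch of the memory argument. It is not dgNN's kernel and simplifies heavily: scalar node features, a single attention head, no linear projection, hypothetical scalar attention weights `a_src`/`a_dst`, and each node attending over the neighbor list stored in its CSR row (an assumption for the sketch). The unfused version materializes a score and a weight for every edge before aggregating; the fused version handles one neighborhood at a time, so its temporary storage is bounded by the largest degree rather than by the number of edges.

```python
import math


def leaky_relu(v, slope=0.2):
    return v if v > 0 else slope * v


def unfused_gat(x, rowptr, col, a_src, a_dst):
    """Three separate passes; per-edge scores and weights exist for all E edges at once."""
    n = len(rowptr) - 1
    # Pass 1: one attention logit per edge (E values kept in memory).
    scores = []
    for i in range(n):
        for k in range(rowptr[i], rowptr[i + 1]):
            scores.append(leaky_relu(a_dst * x[i] + a_src * x[col[k]]))
    # Pass 2: softmax over each node's neighborhood (another E values).
    alpha = [0.0] * len(scores)
    for i in range(n):
        lo, hi = rowptr[i], rowptr[i + 1]
        if lo == hi:
            continue
        m = max(scores[lo:hi])
        denom = sum(math.exp(s - m) for s in scores[lo:hi])
        for k in range(lo, hi):
            alpha[k] = math.exp(scores[k] - m) / denom
    # Pass 3: attention-weighted sum of neighbor features.
    out = [0.0] * n
    for i in range(n):
        for k in range(rowptr[i], rowptr[i + 1]):
            out[i] += alpha[k] * x[col[k]]
    return out


def fused_gat(x, rowptr, col, a_src, a_dst):
    """One pass per node; only that node's neighborhood scores are stored at a time."""
    n = len(rowptr) - 1
    out = [0.0] * n
    for i in range(n):
        nbrs = col[rowptr[i]:rowptr[i + 1]]
        if not nbrs:
            continue
        # Score, normalize, and aggregate this neighborhood before moving on.
        local = [leaky_relu(a_dst * x[i] + a_src * x[j]) for j in nbrs]
        m = max(local)
        weights = [math.exp(s - m) for s in local]
        denom = sum(weights)
        out[i] = sum(w * x[j] for w, j in zip(weights, nbrs)) / denom
    return out


# Small example: node 0 -> {1, 2}, node 1 -> {0}, node 2 has no neighbors.
rowptr, col = [0, 2, 3, 3], [1, 2, 0]
x = [1.0, -0.5, 2.0]
print(unfused_gat(x, rowptr, col, 0.7, 0.3))
print(fused_gat(x, rowptr, col, 0.7, 0.3))
```

Both functions compute the same outputs (up to floating-point rounding); only the amount of intermediate per-edge state differs.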

Note

This implementation is based on the dgNN package. See the dgNN repository for installation instructions.

forward(x: Tensor, csr: Tuple[Tensor, Tensor], csc: Tuple[Tensor, Tensor], perm: Tensor) → Tensor

Runs the forward pass of the module.

Parameters:
  • x (torch.Tensor) – The node features.

  • csr ((torch.Tensor, torch.Tensor)) – The CSR representation of the graph, given as a tuple (rowptr, col).

  • csc ((torch.Tensor, torch.Tensor)) – The CSC representation of the graph, given as a tuple (row, colptr).

  • perm (torch.Tensor) – Permutation tensor to map the CSR representation to the CSC representation.
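The csr and csc tuples store the same edge set sorted two different ways. The pure-Python sketch below uses a small hand-built example (plain lists stand in for the tensors) to show how each layout is read and how perm lines the two orders up. Here perm[k] is taken to be the CSR position of the k-th edge in CSC order; that reading, and the helper names `csr_edges` and `csc_edges`, are assumptions for illustration.

```python
def csr_edges(rowptr, col):
    """Expand a CSR pair (rowptr, col) into (row, col) edges in CSR order."""
    return [(i, col[k]) for i in range(len(rowptr) - 1)
            for k in range(rowptr[i], rowptr[i + 1])]


def csc_edges(row, colptr):
    """Expand a CSC pair (row, colptr) into (row, col) edges in CSC order."""
    return [(row[k], j) for j in range(len(colptr) - 1)
            for k in range(colptr[j], colptr[j + 1])]


# Hand-built example: edges (0,1), (0,2), (1,2), (2,0) on 3 nodes.
csr = ([0, 2, 3, 4], [1, 2, 2, 0])  # (rowptr, col)
csc = ([2, 0, 0, 1], [0, 1, 2, 4])  # (row, colptr)
perm = [3, 0, 1, 2]                 # CSC position -> CSR position (assumed convention)

rowptr, col = csr
row, colptr = csc
by_row = csr_edges(rowptr, col)
by_col = csc_edges(row, colptr)

# Each row's columns (and each column's rows) form one contiguous slice.
print(col[rowptr[0]:rowptr[1]])  # [1, 2]
print(row[colptr[2]:colptr[3]])  # [0, 1]
# perm matches every CSC edge to the same edge in CSR order.
print(all(by_col[k] == by_row[perm[k]] for k in range(len(perm))))  # True
```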

Note

Use the to_graph_format() method to obtain the (csr, csc, perm) graph format from an existing edge_index representation.

reset_parameters()

Resets all learnable parameters of the module.

static to_graph_format(edge_index: Tensor, size: Optional[Tuple[int, int]] = None) → Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tensor]

Converts an edge_index representation of a graph into the (csr, csc, perm) input format expected by FusedGATConv.

Parameters:
  • edge_index (torch.Tensor) – The edge indices.

  • size ((int, int), optional) – The size of the adjacency matrix described by edge_index, i.e., the number of nodes in each dimension. (default: None)
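As a rough picture of what the conversion computes, here is a pure-Python sketch working on plain lists. It assumes that edge_index[0] supplies the CSR rows and edge_index[1] the columns, and that size=None means a square size inferred from the largest node index. `to_graph_format_sketch` and `_indptr` are hypothetical names; the real static method operates on tensors and may differ in details such as dtype handling.

```python
from typing import List, Optional, Tuple


def _indptr(keys: List[int], n: int) -> List[int]:
    """Pointer array of length n + 1: ptr[i] is the number of keys smaller than i."""
    ptr = [0] * (n + 1)
    for k in keys:
        ptr[k + 1] += 1  # raises IndexError if a key does not fit in size n
    for i in range(n):
        ptr[i + 1] += ptr[i]
    return ptr


def to_graph_format_sketch(edge_index: Tuple[List[int], List[int]],
                           size: Optional[Tuple[int, int]] = None):
    src, dst = edge_index
    if size is None:
        # Assumption: infer a square size from the largest node index.
        n = max(max(src, default=-1), max(dst, default=-1)) + 1
        size = (n, n)
    num_rows, num_cols = size
    # CSR order: sort edge ids by (row, col).
    csr_ids = sorted(range(len(src)), key=lambda e: (src[e], dst[e]))
    rowptr = _indptr([src[e] for e in csr_ids], num_rows)
    col = [dst[e] for e in csr_ids]
    # CSC order: sort CSR positions by (col, row); the sorted positions form perm.
    perm = sorted(range(len(csr_ids)),
                  key=lambda p: (dst[csr_ids[p]], src[csr_ids[p]]))
    row = [src[csr_ids[p]] for p in perm]
    colptr = _indptr([dst[csr_ids[p]] for p in perm], num_cols)
    return (rowptr, col), (row, colptr), perm


edge_index = ([2, 0, 1, 0], [0, 2, 2, 1])  # unsorted edges (2,0), (0,2), (1,2), (0,1)
print(to_graph_format_sketch(edge_index))
# (([0, 2, 3, 4], [1, 2, 2, 0]), ([2, 0, 0, 1], [0, 1, 2, 4]), [3, 0, 1, 2])
print(to_graph_format_sketch(edge_index, size=(5, 5))[0][0])  # [0, 2, 3, 4, 4, 4]
```

Passing an explicit size pads the pointer arrays for trailing nodes that have no edges, so isolated nodes still get a (zero-length) neighborhood.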