torch_geometric.nn.dense.DenseSAGEConv

class DenseSAGEConv(in_channels: int, out_channels: int, normalize: bool = False, bias: bool = True)

Bases: Module

See torch_geometric.nn.conv.SAGEConv.
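For reference, SAGEConv with mean aggregation computes \(\mathbf{x}'_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot \mathrm{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j\). Written for a single dense graph with binary adjacency \(\mathbf{A}\), this becomes the expression below. The bias is omitted. The clamp of each degree to at least 1 is an assumption about how isolated or padded nodes are handled; this page does not state it.

\[
\mathbf{X}' = \mathbf{X}\mathbf{W}_1 + \mathbf{D}^{-1}\mathbf{A}\mathbf{X}\mathbf{W}_2,
\qquad
\mathbf{D} = \operatorname{diag}\Big(\max\big(1, \textstyle\sum_{j} A_{ij}\big)\Big)
\]

If `normalize=True`, each row of \(\mathbf{X}'\) is additionally \(\ell_2\)-normalized.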

Note

DenseSAGEConv expects binary adjacency matrices. To use weighted dense adjacency matrices, use torch_geometric.nn.dense.DenseGraphConv instead.
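With a binary adjacency, each row sum of \(\mathbf{A}\) is exactly a node's neighbour count, so dividing \(\mathbf{A}\mathbf{X}\) by it yields a plain neighbour mean. The following is a minimal sketch of that dense mean aggregation in plain PyTorch. It is written under the assumption that the layer follows SAGEConv's mean formulation. The function `dense_sage_sketch` and its explicit weight arguments are hypothetical. They illustrate the computation and are not the torch_geometric implementation.

```python
import torch
import torch.nn.functional as F


def dense_sage_sketch(x, adj, w_rel, w_root, bias=None, normalize=False):
    """Hypothetical dense GraphSAGE (mean) layer, for illustration only.

    x:      [B, N, F] node features
    adj:    [B, N, N] binary adjacency; adj[b, i, j] = 1 if j is a neighbour of i
    w_rel:  [F, F_out] weight for the averaged neighbour features
    w_root: [F, F_out] weight for each node's own features
    """
    # Sum neighbour features, then divide by the neighbour count (row sum).
    # clamp(min=1) avoids dividing by zero for isolated or padded nodes.
    deg = adj.sum(dim=-1, keepdim=True).clamp(min=1)
    agg = torch.matmul(adj, x) / deg
    out = agg @ w_rel + x @ w_root
    if bias is not None:
        out = out + bias
    if normalize:
        # Optional L2 normalisation of each output feature vector.
        out = F.normalize(out, p=2.0, dim=-1)
    return out


# Toy graph: node 0 is linked to nodes 1 and 2 (B=1, N=3, F=1).
x = torch.tensor([[[1.0], [2.0], [4.0]]])
adj = torch.tensor([[[0.0, 1.0, 1.0],
                     [1.0, 0.0, 0.0],
                     [1.0, 0.0, 0.0]]])
# Identity for w_rel and zeros for w_root isolate the neighbour mean.
out = dense_sage_sketch(x, adj, w_rel=torch.eye(1), w_root=torch.zeros(1, 1))
print(out.squeeze())  # tensor([3., 1., 1.])
```

Node 0 averages its neighbours' features (2 + 4) / 2 = 3. Nodes 1 and 2 each see only node 0. If `adj` held arbitrary edge weights instead of 0/1 entries, the row sum would no longer be a neighbour count, and the result would not be a plain mean.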

reset_parameters()

Resets all learnable parameters of the module.

forward(x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) → Tensor

Forward pass.

Parameters:
  • x (torch.Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).

  • adj (torch.Tensor) – Adjacency tensor \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). Its batch dimension is broadcastable, so a single adjacency matrix can be shared by all graphs in the batch.

  • mask (torch.Tensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)