class ASAPooling(in_channels: int, ratio: Union[float, int] = 0.5, GNN: Optional[Callable] = None, dropout: float = 0.0, negative_slope: float = 0.2, add_self_loops: bool = False, **kwargs)

Bases: Module

The Adaptive Structure Aware Pooling operator from the “ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations” paper.

Parameters:
  • in_channels (int) – Size of each input sample.

  • ratio (float or int) – Graph pooling ratio. If ratio is a float, it is used to compute \(k = \lceil \mathrm{ratio} \cdot N \rceil\), where \(N\) is the number of nodes in each graph; if it is an int, it is the value of \(k\) itself. (default: 0.5)

  • GNN (torch.nn.Module, optional) – A graph neural network layer class (not an instance), instantiated as GNN(in_channels, in_channels, **kwargs) and used to capture intra-cluster properties. It is especially helpful for graphs with higher-degree neighborhoods. It should be one of torch_geometric.nn.conv.GraphConv, torch_geometric.nn.conv.GCNConv, or any GNN that supports the edge_weight parameter. (default: None)

  • dropout (float, optional) – Dropout probability applied to the normalized attention coefficients, which exposes each node to a stochastically sampled neighborhood during training. (default: 0.0)

  • negative_slope (float, optional) – Negative slope of the LeakyReLU activation used when computing attention scores. (default: 0.2)

  • add_self_loops (bool, optional) – If set to True, adds self-loops to the connectivity of the coarsened graph. (default: False)

  • **kwargs (optional) – Additional keyword arguments passed to the constructor of the GNN layer.
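To make the float-versus-int behavior of ratio concrete, here is a minimal Python sketch of how the number of kept nodes per graph could be derived. The helper num_kept_nodes is hypothetical (it is not part of PyTorch Geometric), and clamping an int ratio to the graph size is an assumption of this sketch.

```python
import math
from typing import Union


def num_kept_nodes(ratio: Union[float, int], num_nodes: int) -> int:
    """Number of nodes kept for one graph with `num_nodes` nodes.

    Hypothetical helper, not part of PyTorch Geometric.
    A float `ratio` is a fraction: k = ceil(ratio * N).
    An int `ratio` is k itself, clamped to N (assumption of this sketch).
    """
    if isinstance(ratio, bool):
        # bool is a subclass of int in Python; reject it explicitly.
        raise TypeError("ratio must be a float or an int, not a bool")
    if isinstance(ratio, float):
        return math.ceil(ratio * num_nodes)
    if isinstance(ratio, int):
        return min(ratio, num_nodes)
    raise TypeError("ratio must be a float or an int")


print(num_kept_nodes(0.5, 7))   # 4 (ceil of 3.5)
print(num_kept_nodes(1.0, 10))  # 10 (float: keep every node)
print(num_kept_nodes(1, 10))    # 1 (int: keep exactly one node)
print(num_kept_nodes(3, 2))     # 2 (int ratio clamped to the graph size)
```

Note that ratio=1.0 keeps every node while ratio=1 keeps a single node, so the numeric type of the argument matters.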


reset_parameters()

Resets all learnable parameters of the module.

forward(x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None, batch: Optional[Tensor] = None) → Tuple[Tensor, Tensor, Optional[Tensor], Tensor, Tensor]

Forward pass.

Parameters:
  • x (torch.Tensor) – The node feature matrix of shape [N, in_channels].

  • edge_index (torch.Tensor) – The edge indices of shape [2, num_edges].

  • edge_weight (torch.Tensor, optional) – The edge weights. (default: None)

  • batch (torch.Tensor, optional) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)
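The batch vector simply stores one graph id per node. The plain-Python sketch below, with hypothetical helper names, shows how such a vector could be built when several graphs are concatenated into one mini-batch. It assumes that each graph's nodes are stored contiguously and that batch=None means all nodes belong to a single graph.

```python
from typing import List, Optional


def make_batch_vector(num_nodes_per_graph: List[int]) -> List[int]:
    """Builds b in {0, ..., B-1}^N from per-graph node counts.

    Hypothetical helper. Nodes of graph i are assumed to be stored
    contiguously, as when several graphs are concatenated into one batch.
    """
    batch: List[int] = []
    for graph_id, n in enumerate(num_nodes_per_graph):
        batch.extend([graph_id] * n)
    return batch


def default_batch(num_nodes: int, batch: Optional[List[int]] = None) -> List[int]:
    # Assumption: batch=None means every node belongs to one graph (id 0).
    return [0] * num_nodes if batch is None else batch


print(make_batch_vector([3, 2]))  # [0, 0, 0, 1, 1]
print(default_batch(4))           # [0, 0, 0, 0]
```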

Return types:
  • x (torch.Tensor): The pooled node embeddings.

  • edge_index (torch.Tensor): The coarsened edge indices.

  • edge_weight (torch.Tensor, optional): The coarsened edge weights.

  • batch (torch.Tensor): The coarsened batch vector.

  • index (torch.Tensor): The indices of the top-\(k\) nodes that are kept after pooling.
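The returned index and batch values are related: the coarsened batch vector is the original batch vector gathered at index. The plain-Python sketch below illustrates a per-graph top-k selection under stated assumptions: the per-node scores are given as input (ASAP learns its own scores, which this sketch does not reproduce), nodes within each graph are ranked by descending score, and the helper topk_per_graph is hypothetical.

```python
import math
from typing import Dict, List, Tuple, Union


def topk_per_graph(
    score: List[float],
    batch: List[int],
    ratio: Union[float, int],
) -> Tuple[List[int], List[int]]:
    """Returns (index, new_batch) for a per-graph top-k node selection.

    Hypothetical helper: `score` holds one precomputed score per node.
    """
    # Group node ids by the graph they belong to.
    graphs: Dict[int, List[int]] = {}
    for node, g in enumerate(batch):
        graphs.setdefault(g, []).append(node)

    index: List[int] = []
    for g in sorted(graphs):
        nodes = graphs[g]
        n = len(nodes)
        # Float ratio: fraction of nodes; int ratio: k itself, clamped to n.
        k = math.ceil(ratio * n) if isinstance(ratio, float) else min(ratio, n)
        # Highest-scoring nodes first; the stable sort breaks ties by node id.
        ranked = sorted(nodes, key=lambda i: -score[i])
        index.extend(ranked[:k])

    # The coarsened batch vector is the original one gathered at `index`.
    new_batch = [batch[i] for i in index]
    return index, new_batch


score = [0.1, 0.9, 0.4, 0.7, 0.2]
batch = [0, 0, 0, 1, 1]
index, new_batch = topk_per_graph(score, batch, ratio=0.5)
print(index)      # [1, 2, 3]
print(new_batch)  # [0, 0, 1]
```

The coarsened edge indices and edge weights are not modeled here; in ASAP they are derived from the learned cluster assignments rather than by simply dropping edges incident to removed nodes.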