torch_geometric.nn.pool.ASAPooling
- class ASAPooling(in_channels: int, ratio: Union[float, int] = 0.5, GNN: Optional[Callable] = None, dropout: float = 0.0, negative_slope: float = 0.2, add_self_loops: bool = False, **kwargs)
Bases: torch.nn.Module
The Adaptive Structure Aware Pooling operator from the “ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations” paper.
- Parameters:
in_channels (int) – Size of each input sample.
ratio (float or int) – Graph pooling ratio, which is used to compute \(k = \lceil \mathrm{ratio} \cdot N \rceil\), or the value of \(k\) itself, depending on whether the type of ratio is float or int. (default: 0.5)
GNN (torch.nn.Module, optional) – A graph neural network layer for using intra-cluster properties. Especially helpful for graphs with a higher degree of neighborhood (one of torch_geometric.nn.conv.GraphConv, torch_geometric.nn.conv.GCNConv, or any GNN which supports the edge_weight parameter). (default: None)
dropout (float, optional) – Dropout probability of the normalized attention coefficients, which exposes each node to a stochastically sampled neighborhood during training. (default: 0)
negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)
add_self_loops (bool, optional) – If set to True, will add self-loops to the new graph connectivity. (default: False)
**kwargs (optional) – Additional parameters for initializing the graph neural network layer.
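The type-dependent meaning of ratio can be illustrated with a small plain-Python sketch. The helper name num_kept is hypothetical and not part of PyTorch Geometric; clamping an integer ratio to the node count \(N\) is an assumption of this sketch rather than documented behavior.

```python
import math
from typing import Union


def num_kept(ratio: Union[float, int], num_nodes: int) -> int:
    """Hypothetical helper: number of nodes k kept for one graph with N nodes.

    A float ratio is read as a fraction, k = ceil(ratio * N).
    An int ratio is read as k itself; clamping it to N is an assumption here.
    """
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(ratio, bool):
        raise TypeError("ratio must be a float or an int, not a bool")
    if isinstance(ratio, float):
        return math.ceil(ratio * num_nodes)
    if isinstance(ratio, int):
        return min(ratio, num_nodes)
    raise TypeError("ratio must be a float or an int")


print(num_kept(0.5, 7))   # fraction: ceil(3.5) = 4 nodes kept
print(num_kept(3, 10))    # absolute: 3 nodes kept
```

Because of the ceiling, any positive float ratio keeps at least one node of a non-empty graph. Note also that 0.5 and 0.5000001 can give different k for the same N, so an integer ratio is the more predictable choice when an exact pooled size is required.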
- forward(x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None, batch: Optional[Tensor] = None) → Tuple[Tensor, Tensor, Optional[Tensor], Tensor, Tensor]
Forward pass.
- Parameters:
x (torch.Tensor) – The node feature matrix.
edge_index (torch.Tensor) – The edge indices.
edge_weight (torch.Tensor, optional) – The edge weights. (default: None)
batch (torch.Tensor, optional) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)
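The batch vector \(\mathbf{b}\) can be pictured with plain Python lists in place of tensors. This sketch assumes the nodes of the \(B\) graphs are stacked contiguously, one graph after another, as in usual mini-batching. The helpers make_batch and nodes_per_graph are hypothetical illustrations, not library functions.

```python
from collections import Counter
from typing import List


def make_batch(graph_sizes: List[int]) -> List[int]:
    """Hypothetical helper: batch vector for graphs stacked in order.

    Node i belongs to graph batch[i], with graph ids 0, ..., B-1.
    """
    batch: List[int] = []
    for graph_id, size in enumerate(graph_sizes):
        batch.extend([graph_id] * size)
    return batch


def nodes_per_graph(batch: List[int]) -> List[int]:
    """Hypothetical helper: number of nodes N_g contributed by each graph g."""
    if not batch:
        return []
    counts = Counter(batch)
    num_graphs = max(batch) + 1
    return [counts.get(g, 0) for g in range(num_graphs)]


b = make_batch([3, 2])      # two graphs with 3 and 2 nodes
print(b)                    # [0, 0, 0, 1, 1]
print(nodes_per_graph(b))   # [3, 2]
```

The per-graph counts \(N_g\) matter because the pooling ratio is applied to each example separately rather than to the total number of nodes in the mini-batch.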
- Return types:
x (torch.Tensor): The pooled node embeddings.
edge_index (torch.Tensor): The coarsened edge indices.
edge_weight (torch.Tensor, optional): The coarsened edge weights.
batch (torch.Tensor): The coarsened batch vector.
index (torch.Tensor): The indices of the top-\(k\) nodes that are kept after pooling.
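The relationship between the returned index and the coarsened batch vector can be sketched in plain Python, assuming each node already has a score (in ASAP these come from the learned cluster fitness; here they are simply given). The helper topk_per_graph is hypothetical. Ordering kept nodes by descending score within each graph and breaking ties by node index are choices of this sketch and may not match the library's output order. The sketch also does not build the coarsened edge_index or edge weights.

```python
import math
from typing import List, Tuple


def topk_per_graph(
    score: List[float], batch: List[int], ratio: float
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: keep ceil(ratio * N_g) top-scoring nodes per graph g.

    Returns (index, new_batch), where new_batch[j] == batch[index[j]].
    """
    index: List[int] = []
    for g in sorted(set(batch)):
        nodes = [i for i, b in enumerate(batch) if b == g]
        k = math.ceil(ratio * len(nodes))
        # Highest score first; ties broken by the lower node index.
        ranked = sorted(nodes, key=lambda i: (-score[i], i))
        index.extend(ranked[:k])
    # The coarsened batch vector is the old one gathered at the kept nodes.
    new_batch = [batch[i] for i in index]
    return index, new_batch


idx, nb = topk_per_graph([0.1, 0.9, 0.5, 0.3, 0.8], [0, 0, 0, 1, 1], 0.5)
print(idx)  # [1, 2, 4]
print(nb)   # [0, 0, 1]
```

Gathering the batch vector at index is what keeps the pooled nodes assigned to their original examples.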