torch_geometric.nn.pool.ClusterPooling
- class ClusterPooling(in_channels: int, edge_score_method: str = 'tanh', dropout: float = 0.0, threshold: Optional[float] = None)
Bases: Module
The cluster pooling operator from the “Edge-Based Graph Component Pooling” paper.
ClusterPooling computes a score for each edge. Based on the selected edges, graph clusters are calculated and compressed to one node using the injective "sum" aggregation function. Edges are remapped based on the nodes created by each cluster and the original edges.
- Parameters:
in_channels (int) – Size of each input sample.
edge_score_method (str, optional) – The function to apply to compute the edge score from raw edge scores ("tanh", "sigmoid", "log_softmax"). (default: "tanh")
dropout (float, optional) – The probability with which to drop edge scores during training. (default: 0.0)
threshold (float, optional) – The threshold of edge scores. If set to None, will be automatically inferred depending on edge_score_method. (default: None)
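The clustering idea described above can be illustrated with a small dependency-free sketch. It is not the library's implementation: the function name `cluster_and_sum`, the hand-picked scores, and the threshold are all hypothetical. In the real operator the edge scores are learned, and the aggregation may differ in detail from the plain sum used here.

```python
from typing import Dict, List, Tuple


def cluster_and_sum(
    features: List[List[float]],
    edges: List[Tuple[int, int]],
    scores: List[float],
    threshold: float,
) -> Tuple[List[List[float]], List[Tuple[int, int]], List[int]]:
    """Illustrative only: merge nodes joined by edges scoring above `threshold`."""
    n = len(features)
    parent = list(range(n))  # union-find forest, one tree per cluster

    def find(i: int) -> int:
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    # Select edges by score and merge their endpoints into one cluster.
    for (u, v), s in zip(edges, scores):
        if s > threshold:
            parent[find(u)] = find(v)

    # Relabel cluster roots to consecutive ids 0..C-1.
    label: Dict[int, int] = {}
    cluster = [label.setdefault(find(i), len(label)) for i in range(n)]

    # Compress each cluster to one node by summing its members' features.
    dim = len(features[0])
    pooled = [[0.0] * dim for _ in range(len(label))]
    for i, c in enumerate(cluster):
        for d in range(dim):
            pooled[c][d] += features[i][d]

    # Remap the original edges onto clusters, dropping intra-cluster edges.
    new_edges = sorted(
        {(cluster[u], cluster[v]) for u, v in edges if cluster[u] != cluster[v]}
    )
    return pooled, new_edges, cluster


# Path graph 0-1-2-3 with made-up edge scores; only (0, 1) and (2, 3) pass.
features = [[1.0, 0.0], [2.0, 1.0], [3.0, 0.0], [4.0, 2.0]]
edges = [(0, 1), (1, 2), (2, 3)]
scores = [0.9, -0.3, 0.4]
pooled, new_edges, cluster = cluster_and_sum(features, edges, scores, threshold=0.0)
print(cluster)    # [0, 0, 1, 1]
print(pooled)     # [[3.0, 1.0], [7.0, 2.0]]
print(new_edges)  # [(0, 1)]
```

In this toy case, nodes 0 and 1 collapse into one cluster and nodes 2 and 3 into another, and the single low-scoring edge (1, 2) survives as the edge between the two pooled nodes.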
- forward(x: Tensor, edge_index: Tensor, batch: Tensor) → Tuple[Tensor, Tensor, Tensor, UnpoolInfo]
Forward pass.
- Parameters:
x (torch.Tensor) – The node features.
edge_index (torch.Tensor) – The edge indices.
batch (torch.Tensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.
- Return types:
x (torch.Tensor) - The pooled node features.
edge_index (torch.Tensor) - The coarsened edge indices.
batch (torch.Tensor) - The coarsened batch vector.
unpool_info (UnpoolInfo) - Information that can be consumed for unpooling.