torch_geometric.nn.pool.ClusterPooling

class ClusterPooling(in_channels: int, edge_score_method: str = 'tanh', dropout: float = 0.0, threshold: Optional[float] = None)

Bases: Module

The cluster pooling operator from the “Edge-Based Graph Component Pooling” paper.

ClusterPooling computes a score for each edge. Edges whose score exceeds the threshold are selected, and each cluster of nodes connected by selected edges (a graph component) is compressed into a single node using the injective "sum" aggregation function. The edges of the pooled graph are then obtained by remapping the original edges onto the nodes created for each cluster.
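The pooling steps described above can be sketched in plain Python as follows. This is a minimal illustration, not PyG's tensor implementation. It assumes the (already transformed) edge scores are passed in, whereas the module computes them with its learnable parameters. It finds clusters with a union-find over the selected edges and aggregates features with an unweighted sum. The names `cluster_pool` and `find` are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple


def find(parent: List[int], i: int) -> int:
    # Union-find root lookup with path halving.
    while parent[i] != i:
        parent[i] = parent[parent[i]]
        i = parent[i]
    return i


def cluster_pool(
    x: Sequence[Sequence[float]],
    edges: Sequence[Tuple[int, int]],
    scores: Sequence[float],
    batch: Sequence[int],
    threshold: float,
):
    n = len(x)
    parent = list(range(n))
    # 1. Merge the endpoints of every edge whose score exceeds the threshold.
    for (u, v), s in zip(edges, scores):
        if s > threshold:
            ru, rv = find(parent, u), find(parent, v)
            if ru != rv:
                parent[ru] = rv
    # 2. Relabel each component root to a consecutive cluster id 0..C-1.
    root_to_id: Dict[int, int] = {}
    cluster: List[int] = []
    for i in range(n):
        root = find(parent, i)
        cluster.append(root_to_id.setdefault(root, len(root_to_id)))
    num_clusters = len(root_to_id)
    # 3. Sum-aggregate the features of all nodes in each cluster.
    dim = len(x[0]) if n else 0
    x_out = [[0.0] * dim for _ in range(num_clusters)]
    for i, feats in enumerate(x):
        for d, value in enumerate(feats):
            x_out[cluster[i]][d] += value
    # 4. Remap edges onto clusters; drop intra-cluster edges and duplicates.
    edges_out = sorted(
        {(cluster[u], cluster[v]) for u, v in edges if cluster[u] != cluster[v]}
    )
    # 5. Each cluster inherits the graph id of its nodes
    #    (no edge crosses graphs, so a cluster never spans two graphs).
    batch_out = [0] * num_clusters
    for i, b in enumerate(batch):
        batch_out[cluster[i]] = b
    return x_out, edges_out, batch_out, cluster


x_out, edges_out, batch_out, cluster = cluster_pool(
    x=[[1.0], [2.0], [3.0], [4.0], [5.0]],
    edges=[(0, 1), (1, 0), (1, 2), (2, 1), (3, 4), (4, 3)],
    scores=[0.9, 0.9, -0.2, -0.2, 0.7, 0.7],
    batch=[0, 0, 0, 1, 1],
    threshold=0.0,
)
print(cluster)    # [0, 0, 1, 2, 2]
print(x_out)      # [[3.0], [3.0], [9.0]]
print(edges_out)  # [(0, 1), (1, 0)]
print(batch_out)  # [0, 0, 1]
```

In this two-graph example only the edges between nodes 0 and 1 and between nodes 3 and 4 score above the threshold. Nodes 0 and 1 merge, node 2 stays on its own, and nodes 3 and 4 merge, so five nodes are pooled into three.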

Parameters:
  • in_channels (int) – Size of each input sample.

  • edge_score_method (str, optional) – The function applied to the raw edge scores to obtain the final edge scores ("tanh", "sigmoid", or "log_softmax"). (default: "tanh")

  • dropout (float, optional) – The probability with which to drop edge scores during training. (default: 0.0)

  • threshold (float, optional) – The threshold applied to the edge scores when selecting edges to merge. If set to None, it is inferred automatically from edge_score_method. (default: None)
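To see why a threshold of None has to be resolved per method, the sketch below applies each of the three documented edge_score_method options to a list of raw scores. `apply_edge_score_method` is a hypothetical helper. Normalizing log_softmax over all scores passed in is an assumption made for illustration, and the library may normalize over a different set of edges.

```python
import math
from typing import List


def apply_edge_score_method(raw: List[float], method: str) -> List[float]:
    # Hypothetical helper: transform raw edge scores with one documented method.
    if method == "tanh":
        return [math.tanh(s) for s in raw]  # values in (-1, 1)
    if method == "sigmoid":
        return [1.0 / (1.0 + math.exp(-s)) for s in raw]  # values in (0, 1)
    if method == "log_softmax":
        # Assumption: normalize over every score passed in (numerically stable form).
        m = max(raw)
        log_z = m + math.log(sum(math.exp(s - m) for s in raw))
        return [s - log_z for s in raw]  # values <= 0
    raise ValueError(f"unknown edge_score_method: {method!r}")


raw = [-2.0, 0.0, 3.0]
for method in ("tanh", "sigmoid", "log_softmax"):
    print(method, apply_edge_score_method(raw, method))
```

The three methods produce values on different ranges: tanh in (-1, 1), sigmoid in (0, 1), and log_softmax at most 0. A threshold that suits one method can therefore select all or none of the edges under another.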

reset_parameters()

Resets all learnable parameters of the module.

forward(x: Tensor, edge_index: Tensor, batch: Tensor) → Tuple[Tensor, Tensor, Tensor, UnpoolInfo]

Forward pass.

Parameters:
  • x (torch.Tensor) – The node features.

  • edge_index (torch.Tensor) – The edge indices.

  • batch (torch.Tensor) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.
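The batch vector follows PyG's convention of combining several graphs into one disjoint graph. Node features are concatenated, the edge indices of later graphs are shifted by the number of preceding nodes, and `batch` records which graph each node came from. The plain-Python `collate` helper below is hypothetical and only illustrates this layout. In practice, `torch_geometric.loader.DataLoader` assembles these tensors for you.

```python
from typing import List, Tuple

# A graph as (node features, list of directed edges); illustrative only.
Graph = Tuple[List[List[float]], List[Tuple[int, int]]]


def collate(graphs: List[Graph]):
    x: List[List[float]] = []
    src: List[int] = []
    dst: List[int] = []
    batch: List[int] = []
    offset = 0
    for graph_id, (feats, edges) in enumerate(graphs):
        x.extend(feats)
        # Shift node ids so they index into the concatenated feature list.
        for u, v in edges:
            src.append(u + offset)
            dst.append(v + offset)
        # Every node of this graph is assigned the same example id.
        batch.extend([graph_id] * len(feats))
        offset += len(feats)
    edge_index = [src, dst]  # two rows: source nodes, target nodes
    return x, edge_index, batch


g0 = ([[0.1], [0.2], [0.3]], [(0, 1), (1, 2)])
g1 = ([[0.4], [0.5]], [(0, 1)])
x, edge_index, batch = collate([g0, g1])
print(edge_index)  # [[0, 1, 3], [1, 2, 4]]
print(batch)       # [0, 0, 0, 1, 1]
```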

Return types:
  • x (torch.Tensor) – The pooled node features.

  • edge_index (torch.Tensor) – The coarsened edge indices.

  • batch (torch.Tensor) – The coarsened batch vector.

  • unpool_info (UnpoolInfo) – Information that can be consumed for unpooling.
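Pooling discards which nodes were merged, so some node-to-cluster mapping has to be retained to undo it. One simple unpooling strategy, sketched below, copies each pooled feature vector back to every node of its cluster. It assumes a per-node cluster assignment is available (the hypothetical `cluster` list here). Both helpers are illustrative and are not necessarily how any PyG unpooling routine works.

```python
from typing import List


def unpool_features(x_pooled: List[List[float]], cluster: List[int]) -> List[List[float]]:
    # Node i receives a copy of the pooled features of its cluster, cluster[i].
    return [list(x_pooled[c]) for c in cluster]


def batch_is_consistent(batch: List[int], batch_pooled: List[int], cluster: List[int]) -> bool:
    # Every original node must belong to the same example as its pooled node.
    return all(batch_pooled[c] == b for c, b in zip(cluster, batch))


cluster = [0, 0, 1, 2, 2]  # hypothetical assignment of 5 nodes to 3 clusters
x_pooled = [[3.0], [3.0], [9.0]]
print(unpool_features(x_pooled, cluster))  # [[3.0], [3.0], [3.0], [9.0], [9.0]]
print(batch_is_consistent([0, 0, 0, 1, 1], [0, 0, 1], cluster))  # True
```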