torch_geometric.nn.pool.avg_pool
- avg_pool(cluster: Tensor, data: Data, transform: Optional[Callable] = None) → Data
Pools and coarsens a graph given by the torch_geometric.data.Data object according to the clustering defined in cluster. All nodes within the same cluster are represented as one node, whose features are the average of the features of all nodes in that cluster. See torch_geometric.nn.pool.max_pool() for more details.
- Parameters:
cluster (torch.Tensor) – The cluster vector \(\mathbf{c} \in \{ 0, \ldots, N - 1 \}^N\), which assigns each node to a specific cluster.
data (Data) – Graph data object.
transform (callable, optional) – A function/transform that takes in the coarsened and pooled torch_geometric.data.Data object and returns a transformed version. (default: None)
- Return type: Data
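To make the averaging concrete: if \(\mathcal{C}_k = \{ i : c_i = k \}\) is the set of nodes assigned to cluster \(k\), the pooled feature vector is \(\mathbf{x}'_k = \frac{1}{|\mathcal{C}_k|} \sum_{i \in \mathcal{C}_k} \mathbf{x}_i\). The plain-Python sketch below illustrates this on lists rather than tensors and is not PyG's implementation. The name avg_pool_sketch is hypothetical, and two behaviors are assumptions for illustration: cluster ids are relabeled to consecutive indices in sorted order, and coarsened edges drop self-loops and duplicates.

```python
def avg_pool_sketch(cluster, x, edge_index):
    """Hypothetical illustration of cluster-wise average pooling.

    cluster:    list[int], cluster id for each of the N nodes
    x:          list[list[float]], one feature vector per node
    edge_index: list[tuple[int, int]], directed edges (src, dst)
    """
    # Assumption: relabel (possibly non-consecutive) cluster ids to 0..K-1 in sorted order.
    ids = sorted(set(cluster))
    remap = {c: k for k, c in enumerate(ids)}
    assign = [remap[c] for c in cluster]

    # Sum features per cluster and count members, then divide to get the mean.
    dim = len(x[0])
    sums = [[0.0] * dim for _ in ids]
    counts = [0] * len(ids)
    for node, k in enumerate(assign):
        counts[k] += 1
        for d in range(dim):
            sums[k][d] += x[node][d]
    x_pooled = [[s / counts[k] for s in sums[k]] for k in range(len(ids))]

    # Map each edge to its endpoint clusters; assumption: drop self-loops and duplicates.
    coarse_edges = sorted(
        {(assign[u], assign[v]) for u, v in edge_index if assign[u] != assign[v]}
    )
    return x_pooled, coarse_edges


# Usage: five nodes grouped into clusters 0 and 5.
cluster = [0, 0, 5, 5, 5]
x = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0], [3.0, 3.0], [6.0, 0.0]]
edge_index = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)]
x_pooled, edges = avg_pool_sketch(cluster, x, edge_index)
print(x_pooled)  # [[2.0, 3.0], [3.0, 1.0]]
print(edges)     # [(0, 1), (1, 0)]
```

With the real function, the same cluster vector would be passed as a torch.Tensor together with a Data object whose x and edge_index hold these values.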