# Source code for torch_geometric.nn.pool.glob

from typing import Optional

from torch import Tensor

from torch_geometric.utils import scatter

def global_add_pool(x: Tensor, batch: Optional[Tensor],
                    size: Optional[int] = None) -> Tensor:
    r"""Returns batch-wise graph-level-outputs by adding node features
    across the node dimension.

    For a single graph :math:`\mathcal{G}_i`, its output is computed by

    .. math::
        \mathbf{r}_i = \sum_{n=1}^{N_i} \mathbf{x}_n.

    Functional method of the
    :class:`~torch_geometric.nn.aggr.SumAggregation` module.

    Args:
        x (torch.Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (torch.Tensor, optional): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
            each node to a specific example.
        size (int, optional): The number of examples :math:`B`.
            Automatically calculated if not given. (default: :obj:`None`)

    Returns:
        torch.Tensor: The summed features. If :obj:`batch` is :obj:`None`,
        all nodes are treated as a single graph, and for 1-D and 2-D inputs
        the reduced node dimension is kept with size :obj:`1`. Otherwise
        features are summed per example via
        :func:`~torch_geometric.utils.scatter`, with :obj:`size` passed as
        the output dimension size.
    """
    # 1-D inputs are reduced along their only axis; otherwise the node
    # dimension is the second-to-last axis.
    dim = -1 if isinstance(x, Tensor) and x.dim() == 1 else -2

    if batch is None:
        # Single-graph shortcut: keep the reduced axis for 1-D/2-D inputs
        # so the output has the same rank as a one-example batched result.
        return x.sum(dim=dim, keepdim=x.dim() <= 2)
    return scatter(x, batch, dim=dim, dim_size=size, reduce='sum')

def global_mean_pool(x: Tensor, batch: Optional[Tensor],
                     size: Optional[int] = None) -> Tensor:
    r"""Returns batch-wise graph-level-outputs by averaging node features
    across the node dimension.

    For a single graph :math:`\mathcal{G}_i`, its output is computed by

    .. math::
        \mathbf{r}_i = \frac{1}{N_i} \sum_{n=1}^{N_i} \mathbf{x}_n.

    Functional method of the
    :class:`~torch_geometric.nn.aggr.MeanAggregation` module.

    Args:
        x (torch.Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (torch.Tensor, optional): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
            each node to a specific example.
        size (int, optional): The number of examples :math:`B`.
            Automatically calculated if not given. (default: :obj:`None`)

    Returns:
        torch.Tensor: The averaged features. If :obj:`batch` is
        :obj:`None`, all nodes are treated as a single graph, and for 1-D
        and 2-D inputs the reduced node dimension is kept with size
        :obj:`1`. Otherwise features are averaged per example via
        :func:`~torch_geometric.utils.scatter`, with :obj:`size` passed as
        the output dimension size.
    """
    # 1-D inputs are reduced along their only axis; otherwise the node
    # dimension is the second-to-last axis.
    dim = -1 if isinstance(x, Tensor) and x.dim() == 1 else -2

    if batch is None:
        # Single-graph shortcut: keep the reduced axis for 1-D/2-D inputs
        # so the output has the same rank as a one-example batched result.
        return x.mean(dim=dim, keepdim=x.dim() <= 2)
    return scatter(x, batch, dim=dim, dim_size=size, reduce='mean')

def global_max_pool(x: Tensor, batch: Optional[Tensor],
                    size: Optional[int] = None) -> Tensor:
    r"""Returns batch-wise graph-level-outputs by taking the channel-wise
    maximum across the node dimension.

    For a single graph :math:`\mathcal{G}_i`, its output is computed by

    .. math::
        \mathbf{r}_i = \mathrm{max}_{n=1}^{N_i} \, \mathbf{x}_n.

    Functional method of the
    :class:`~torch_geometric.nn.aggr.MaxAggregation` module.

    Args:
        x (torch.Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (torch.Tensor, optional): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
            each node to a specific example.
        size (int, optional): The number of examples :math:`B`.
            Automatically calculated if not given. (default: :obj:`None`)

    Returns:
        torch.Tensor: The channel-wise maxima. If :obj:`batch` is
        :obj:`None`, all nodes are treated as a single graph, and for 1-D
        and 2-D inputs the reduced node dimension is kept with size
        :obj:`1`. Otherwise maxima are taken per example via
        :func:`~torch_geometric.utils.scatter`, with :obj:`size` passed as
        the output dimension size.
    """
    # 1-D inputs are reduced along their only axis; otherwise the node
    # dimension is the second-to-last axis.
    dim = -1 if isinstance(x, Tensor) and x.dim() == 1 else -2

    if batch is None:
        # Tensor.max(dim=...) returns a (values, indices) pair; only the
        # values are the pooled output. Keep the reduced axis for 1-D/2-D
        # inputs so the output rank matches a one-example batched result.
        return x.max(dim=dim, keepdim=x.dim() <= 2)[0]
    return scatter(x, batch, dim=dim, dim_size=size, reduce='max')