# Source code for torch_geometric.nn.conv.eg_conv

from typing import List, Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor, SparseTensor, torch_sparse
from torch_geometric.utils import add_remaining_self_loops, scatter, spmm

class EGConv(MessagePassing):
    r"""The Efficient Graph Convolution from the `"Adaptive Filters and
    Aggregator Fusion for Efficient Graph Convolutions"
    <https://arxiv.org/abs/2104.01481>`_ paper.

    Its node-wise formulation is given by:

    .. math::
        \mathbf{x}_i^{\prime} = {\LARGE ||}_{h=1}^H \sum_{\oplus \in
        \mathcal{A}} \sum_{b = 1}^B w_{i, h, \oplus, b} \;
        \underset{j \in \mathcal{N}(i) \cup \{i\}}{\bigoplus}
        \mathbf{W}_b \mathbf{x}_{j}

    with :math:`\mathbf{W}_b` denoting a basis weight,
    :math:`\oplus` denoting an aggregator, and :math:`w` denoting per-vertex
    weighting coefficients across different heads, bases and aggregators.

    EGC retains :math:`\mathcal{O}(|\mathcal{V}|)` memory usage, making it a
    sensible alternative to :class:`~torch_geometric.nn.conv.GCNConv`,
    :class:`~torch_geometric.nn.conv.SAGEConv` or
    :class:`~torch_geometric.nn.conv.GINConv`.

    .. note::
        For an example of using :obj:`EGConv`, see `examples/egc.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/
        examples/egc.py>`_.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        aggregators (List[str], optional): Aggregators to be used.
            Supported aggregators are :obj:`"sum"`, :obj:`"mean"`,
            :obj:`"symnorm"`, :obj:`"max"`, :obj:`"min"`, :obj:`"std"`,
            :obj:`"var"`.
            Multiple aggregators can be used to improve the performance.
            (default: :obj:`["symnorm"]`)
        num_heads (int, optional): Number of heads :math:`H` to use. Must
            have :obj:`out_channels % num_heads == 0`. It is recommended to
            set :obj:`num_heads >= num_bases`. (default: :obj:`8`)
        num_bases (int, optional): Number of basis weights :math:`B` to use.
            (default: :obj:`4`)
        cached (bool, optional): If set to :obj:`True`, the layer will cache
            the computation of the edge index with added self loops on first
            execution, along with caching the calculation of the symmetric
            normalized edge weights if the :obj:`"symnorm"` aggregator is
            being used. This parameter should only be set to :obj:`True` in
            transductive learning scenarios. (default: :obj:`False`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          edge indices :math:`(2, |\mathcal{E}|)`
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
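
    A minimal usage sketch (the graph below, with 4 nodes, 16 input features
    and 32 output channels, is illustrative only):

    .. code-block:: python

        import torch
        from torch_geometric.nn import EGConv

        x = torch.randn(4, 16)  # node features: [num_nodes, in_channels]
        edge_index = torch.tensor([[0, 1, 1, 2],
                                   [1, 0, 2, 3]])

        # out_channels=32 is divisible by the default num_heads=8:
        conv = EGConv(16, 32, aggregators=['symnorm', 'max'])
        out = conv(x, edge_index)  # [num_nodes, out_channels] = [4, 32]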
"""

    _cached_edge_index: Optional[Tuple[Tensor, OptTensor]]
    _cached_adj_t: Optional[SparseTensor]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        aggregators: List[str] = ['symnorm'],
        num_heads: int = 8,
        num_bases: int = 4,
        cached: bool = False,
        add_self_loops: bool = True,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(node_dim=0, **kwargs)

        if out_channels % num_heads != 0:
            raise ValueError(f"'out_channels' (got {out_channels}) must be "
                             f"divisible by the number of heads "
                             f"(got {num_heads})")

        for a in aggregators:
            if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var', 'std']:
                raise ValueError(f"Unsupported aggregator: '{a}'")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.num_bases = num_bases
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.aggregators = aggregators

        self.bases_lin = Linear(in_channels,
                                (out_channels // num_heads) * num_bases,
                                bias=False, weight_initializer='glorot')
        self.comb_lin = Linear(in_channels,
                               num_heads * num_bases * len(aggregators))

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        super().reset_parameters()
        self.bases_lin.reset_parameters()
        self.comb_lin.reset_parameters()
        zeros(self.bias)
        self._cached_edge_index = None
        self._cached_adj_t = None

    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
        symnorm_weight: OptTensor = None
        if "symnorm" in self.aggregators:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, symnorm_weight = gcn_norm(  # yapf: disable
                        edge_index, None, num_nodes=x.size(self.node_dim),
                        add_self_loops=self.add_self_loops, flow=self.flow,
                        dtype=x.dtype)
                    if self.cached:
                        self._cached_edge_index = (edge_index, symnorm_weight)
                else:
                    edge_index, symnorm_weight = cache

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, None, num_nodes=x.size(self.node_dim),
                        add_self_loops=self.add_self_loops, flow=self.flow,
                        dtype=x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        elif self.add_self_loops:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if self.cached and cache is not None:
                    edge_index = cache[0]
                else:
                    edge_index, _ = add_remaining_self_loops(edge_index)
                    if self.cached:
                        self._cached_edge_index = (edge_index, None)

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if self.cached and cache is not None:
                    edge_index = cache
                else:
                    edge_index = torch_sparse.fill_diag(edge_index, 1.0)
                    if self.cached:
                        self._cached_adj_t = edge_index

        # [num_nodes, (out_channels // num_heads) * num_bases]
        bases = self.bases_lin(x)
        # [num_nodes, num_heads * num_bases * num_aggrs]
        weightings = self.comb_lin(x)

        # [num_nodes, num_aggregators, (out_channels // num_heads) * num_bases]
        # propagate_type: (x: Tensor, symnorm_weight: OptTensor)
        aggregated = self.propagate(edge_index, x=bases,
                                    symnorm_weight=symnorm_weight)

        weightings = weightings.view(-1, self.num_heads,
                                     self.num_bases * len(self.aggregators))
        aggregated = aggregated.view(
            -1,
            len(self.aggregators) * self.num_bases,
            self.out_channels // self.num_heads,
        )

        # [num_nodes, num_heads, out_channels // num_heads]
        out = torch.matmul(weightings, aggregated)
        out = out.view(-1, self.out_channels)

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    def aggregate(self, inputs: Tensor, index: Tensor,
                  dim_size: Optional[int] = None,
                  symnorm_weight: OptTensor = None) -> Tensor:

        outs = []
        for aggr in self.aggregators:
            if aggr == 'symnorm':
                assert symnorm_weight is not None
                out = scatter(inputs * symnorm_weight.view(-1, 1), index, 0,
                              dim_size, reduce='sum')
            elif aggr in ['var', 'std']:
                # Var[x] = E[x^2] - E[x]^2, computed via two mean reductions.
                mean = scatter(inputs, index, 0, dim_size, reduce='mean')
                mean_squares = scatter(inputs * inputs, index, 0, dim_size,
                                       reduce='mean')
                out = mean_squares - mean * mean
                if aggr == 'std':
                    out = out.clamp(min=1e-5).sqrt()
            else:
                out = scatter(inputs, index, 0, dim_size, reduce=aggr)

            outs.append(out)

        return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0]

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        if len(self.aggregators) > 1 and 'symnorm' in self.aggregators:
            # Strip the symmetric normalization weights for all other
            # aggregators, which operate on the unweighted adjacency:
            adj_t_2 = adj_t.set_value(None)
        else:
            adj_t_2 = adj_t

        outs = []
        for aggr in self.aggregators:
            if aggr == 'symnorm':
                out = spmm(adj_t, x, reduce='sum')
            elif aggr in ['var', 'std']:
                mean = spmm(adj_t_2, x, reduce='mean')
                mean_sq = spmm(adj_t_2, x * x, reduce='mean')
                out = mean_sq - mean * mean
                if aggr == 'std':
                    out = torch.sqrt(out.relu_() + 1e-5)
            else:
                out = spmm(adj_t_2, x, reduce=aggr)

            outs.append(out)

        return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0]