Source code for torch_geometric.nn.conv.eg_conv

from typing import List, Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Parameter
from torch_geometric.nn.inits import zeros
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.utils import add_remaining_self_loops
from torch_scatter import scatter
from torch_sparse import SparseTensor, matmul, fill_diag


class EGConv(MessagePassing):
    r"""The Efficient Graph Convolution from the `"Adaptive Filters and
    Aggregator Fusion for Efficient Graph Convolutions"
    <https://arxiv.org/abs/2104.01481>`_ paper.

    Its node-wise formulation is given by:

    .. math::
        \mathbf{x}_i^{\prime} = {\LARGE ||}_{h=1}^H \sum_{\oplus \in
        \mathcal{A}} \sum_{b = 1}^B w_{i, h, \oplus, b} \;
        \underset{j \in \mathcal{N}(i) \cup \{i\}}{\bigoplus}
        \mathbf{\Theta}_b \mathbf{x}_{j}

    with :math:`\mathbf{\Theta}_b` denoting a basis weight,
    :math:`\oplus` denoting an aggregator, and :math:`w` denoting per-vertex
    weighting coefficients across different heads, bases and aggregators.

    EGC retains :math:`\mathcal{O}(|\mathcal{V}|)` memory usage, making it a
    sensible alternative to :class:`~torch_geometric.nn.conv.GCNConv`,
    :class:`~torch_geometric.nn.conv.SAGEConv` or
    :class:`~torch_geometric.nn.conv.GINConv`.

    .. note::
        For an example of using :obj:`EGConv`, see `examples/egc.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/
        examples/egc.py>`_.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        aggregators (List[str], optional): Aggregators to be used.
            Supported aggregators are :obj:`"sum"`, :obj:`"mean"`,
            :obj:`"symnorm"`, :obj:`"max"`, :obj:`"min"`, :obj:`"std"`,
            :obj:`"var"`.
            Multiple aggregators can be used to improve the performance.
            (default: :obj:`["symnorm"]`)
        num_heads (int, optional): Number of heads :math:`H` to use. Must
            have :obj:`out_channels % num_heads == 0`. It is recommended to
            set :obj:`num_heads >= num_bases`. (default: :obj:`8`)
        num_bases (int, optional): Number of basis weights :math:`B` to use.
            (default: :obj:`4`)
        cached (bool, optional): If set to :obj:`True`, the layer will cache
            the computation of the edge index with added self loops on first
            execution, along with caching the calculation of the symmetric
            normalized edge weights if the :obj:`"symnorm"` aggregator is
            being used. This parameter should only be set to :obj:`True` in
            transductive learning scenarios. (default: :obj:`False`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    _cached_edge_index: Optional[Tuple[Tensor, OptTensor]]
    _cached_adj_t: Optional[SparseTensor]

    def __init__(self, in_channels: int, out_channels: int,
                 aggregators: List[str] = ["symnorm"], num_heads: int = 8,
                 num_bases: int = 4, cached: bool = False,
                 add_self_loops: bool = True, bias: bool = True, **kwargs):
        super(EGConv, self).__init__(node_dim=0, **kwargs)

        if out_channels % num_heads != 0:
            raise ValueError(
                'out_channels must be divisible by the number of heads')

        for a in aggregators:
            if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var',
                         'std']:
                raise ValueError(f"Unsupported aggregator: '{a}'")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.num_bases = num_bases
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.aggregators = aggregators

        self.bases_lin = Linear(in_channels,
                                (out_channels // num_heads) * num_bases,
                                bias=False, weight_initializer='glorot')
        self.comb_lin = Linear(in_channels,
                               num_heads * num_bases * len(aggregators))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        self.bases_lin.reset_parameters()
        self.comb_lin.reset_parameters()
        zeros(self.bias)
        self._cached_adj_t = None
        self._cached_edge_index = None

    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
        """"""
        symnorm_weight: OptTensor = None
        if "symnorm" in self.aggregators:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, symnorm_weight = gcn_norm(  # yapf: disable
                        edge_index, None, num_nodes=x.size(self.node_dim),
                        improved=False, add_self_loops=self.add_self_loops)
                    if self.cached:
                        self._cached_edge_index = (edge_index, symnorm_weight)
                else:
                    edge_index, symnorm_weight = cache

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, None, num_nodes=x.size(self.node_dim),
                        improved=False, add_self_loops=self.add_self_loops)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        elif self.add_self_loops:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if self.cached and cache is not None:
                    edge_index = cache[0]
                else:
                    edge_index, _ = add_remaining_self_loops(edge_index)
                    if self.cached:
                        self._cached_edge_index = (edge_index, None)

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if self.cached and cache is not None:
                    edge_index = cache
                else:
                    edge_index = fill_diag(edge_index, 1.0)
                    if self.cached:
                        self._cached_adj_t = edge_index

        # [num_nodes, (out_channels // num_heads) * num_bases]
        bases = self.bases_lin(x)
        # [num_nodes, num_heads * num_bases * num_aggrs]
        weightings = self.comb_lin(x)

        # [num_nodes, num_aggregators, (out_channels // num_heads) * num_bases]
        # propagate_type: (x: Tensor, symnorm_weight: OptTensor)
        aggregated = self.propagate(edge_index, x=bases,
                                    symnorm_weight=symnorm_weight, size=None)

        weightings = weightings.view(-1, self.num_heads,
                                     self.num_bases * len(self.aggregators))
        aggregated = aggregated.view(
            -1,
            len(self.aggregators) * self.num_bases,
            self.out_channels // self.num_heads,
        )

        # [num_nodes, num_heads, out_channels // num_heads]
        out = torch.matmul(weightings, aggregated)
        out = out.view(-1, self.out_channels)

        if self.bias is not None:
            out += self.bias

        return out

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    def aggregate(self, inputs: Tensor, index: Tensor,
                  dim_size: Optional[int] = None,
                  symnorm_weight: OptTensor = None) -> Tensor:
        outs = []
        for aggr in self.aggregators:
            if aggr == 'symnorm':
                assert symnorm_weight is not None
                out = scatter(inputs * symnorm_weight.view(-1, 1), index, 0,
                              None, dim_size, reduce='sum')
            elif aggr == 'var' or aggr == 'std':
                mean = scatter(inputs, index, 0, None, dim_size,
                               reduce='mean')
                mean_squares = scatter(inputs * inputs, index, 0, None,
                                       dim_size, reduce='mean')
                out = mean_squares - mean * mean
                if aggr == 'std':
                    out = torch.sqrt(out.relu_() + 1e-5)
            else:
                out = scatter(inputs, index, 0, None, dim_size, reduce=aggr)
            outs.append(out)

        return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0]

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        adj_t_2 = adj_t
        if len(self.aggregators) > 1 and 'symnorm' in self.aggregators:
            adj_t_2 = adj_t.set_value(None)

        outs = []
        for aggr in self.aggregators:
            if aggr == 'symnorm':
                out = matmul(adj_t, x, reduce='sum')
            elif aggr in ['var', 'std']:
                mean = matmul(adj_t_2, x, reduce='mean')
                mean_sq = matmul(adj_t_2, x * x, reduce='mean')
                out = mean_sq - mean * mean
                if aggr == 'std':
                    out = torch.sqrt(out.relu_() + 1e-5)
            else:
                out = matmul(adj_t_2, x, reduce=aggr)
            outs.append(out)

        return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0]

    def __repr__(self):
        return '{}({}, {}, aggregators={})'.format(self.__class__.__name__,
                                                   self.in_channels,
                                                   self.out_channels,
                                                   self.aggregators)
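A minimal usage sketch (not part of the module source above), illustrating the interface described in the class docstring: the graph tensors and sizes below are hypothetical, chosen only so that out_channels is divisible by num_heads.

# Minimal usage sketch, assuming a small random graph; not part of the module above.
import torch
from torch_geometric.nn import EGConv

x = torch.randn(6, 16)                            # 6 nodes, 16 input features
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 2, 3, 4, 5, 0]])   # a directed 6-cycle

conv = EGConv(in_channels=16, out_channels=32,
              aggregators=['symnorm', 'max'],     # multiple aggregators are fused
              num_heads=8, num_bases=4)           # 32 % 8 == 0 as required

out = conv(x, edge_index)
print(out.shape)  # torch.Size([6, 32]): H = 8 heads of 32 // 8 = 4 channels each

Passing a torch_sparse.SparseTensor adjacency instead of edge_index takes the message_and_aggregate path, and setting cached=True reuses the normalized graph across calls in transductive settings.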