Source code for torch_geometric.nn.conv.eg_conv

from typing import List, Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor, SparseTensor, torch_sparse
from torch_geometric.utils import add_remaining_self_loops, scatter, spmm


class EGConv(MessagePassing):
    r"""The Efficient Graph Convolution from the `"Adaptive Filters and
    Aggregator Fusion for Efficient Graph Convolutions"
    <https://arxiv.org/abs/2104.01481>`_ paper.

    Its node-wise formulation is given by:

    .. math::
        \mathbf{x}_i^{\prime} = {\LARGE ||}_{h=1}^H \sum_{\oplus \in
        \mathcal{A}} \sum_{b = 1}^B w_{i, h, \oplus, b} \;
        \underset{j \in \mathcal{N}(i) \cup \{i\}}{\bigoplus}
        \mathbf{W}_b \mathbf{x}_{j}

    with :math:`\mathbf{W}_b` denoting a basis weight,
    :math:`\oplus` denoting an aggregator, and :math:`w` denoting per-vertex
    weighting coefficients across different heads, bases and aggregators.

    EGC retains :math:`\mathcal{O}(|\mathcal{V}|)` memory usage, making it a
    sensible alternative to :class:`~torch_geometric.nn.conv.GCNConv`,
    :class:`~torch_geometric.nn.conv.SAGEConv` or
    :class:`~torch_geometric.nn.conv.GINConv`.

    .. note::
        For an example of using :obj:`EGConv`, see `examples/egc.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/
        examples/egc.py>`_.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        aggregators (List[str], optional): Aggregators to be used.
            Supported aggregators are :obj:`"sum"`, :obj:`"mean"`,
            :obj:`"symnorm"`, :obj:`"max"`, :obj:`"min"`, :obj:`"std"`,
            :obj:`"var"`.
            Multiple aggregators can be used to improve the performance.
            Must contain at least one entry.
            (default: :obj:`["symnorm"]`)
        num_heads (int, optional): Number of heads :math:`H` to use. Must have
            :obj:`out_channels % num_heads == 0`. It is recommended to set
            :obj:`num_heads >= num_bases`. (default: :obj:`8`)
        num_bases (int, optional): Number of basis weights :math:`B` to use.
            (default: :obj:`4`)
        cached (bool, optional): If set to :obj:`True`, the layer will cache
            the computation of the edge index with added self loops on first
            execution, along with caching the calculation of the symmetric
            normalized edge weights if the :obj:`"symnorm"` aggregator is
            being used. This parameter should only be set to :obj:`True` in
            transductive learning scenarios. (default: :obj:`False`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          edge indices :math:`(2, |\mathcal{E}|)`
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
    """

    _cached_edge_index: Optional[Tuple[Tensor, OptTensor]]
    _cached_adj_t: Optional[SparseTensor]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        aggregators: List[str] = ['symnorm'],
        num_heads: int = 8,
        num_bases: int = 4,
        cached: bool = False,
        add_self_loops: bool = True,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(node_dim=0, **kwargs)

        if out_channels % num_heads != 0:
            raise ValueError(f"'out_channels' (got {out_channels}) must be "
                             f"divisible by the number of heads "
                             f"(got {num_heads})")

        # Fail early: with no aggregator there is nothing to combine and the
        # aggregation step would otherwise die on an empty list much later.
        if len(aggregators) == 0:
            raise ValueError("At least one aggregator must be specified")

        for a in aggregators:
            if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var', 'std']:
                raise ValueError(f"Unsupported aggregator: '{a}'")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.num_bases = num_bases
        self.cached = cached
        self.add_self_loops = add_self_loops
        # Store a private copy so that neither the shared default list of the
        # signature nor a list owned by the caller is aliased by the module.
        self.aggregators = list(aggregators)

        # Shared basis transformations W_b (no bias, as in the formulation).
        self.bases_lin = Linear(in_channels,
                                (out_channels // num_heads) * num_bases,
                                bias=False, weight_initializer='glorot')
        # Per-node combination weights over heads, bases and aggregators.
        self.comb_lin = Linear(in_channels,
                               num_heads * num_bases * len(self.aggregators))

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters and drops any cached graph."""
        super().reset_parameters()
        self.bases_lin.reset_parameters()
        self.comb_lin.reset_parameters()
        zeros(self.bias)
        self._cached_adj_t = None
        self._cached_edge_index = None

    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor or SparseTensor): The graph
                connectivity.

        Returns:
            torch.Tensor: The updated node features, reshaped to
            :obj:`[-1, out_channels]`.
        """
        symnorm_weight: OptTensor = None
        if "symnorm" in self.aggregators:
            # `gcn_norm` both adds self-loops (if requested) and computes the
            # symmetric normalization, so it covers both needs at once.
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, symnorm_weight = gcn_norm(  # yapf: disable
                        edge_index, None, num_nodes=x.size(self.node_dim),
                        improved=False, add_self_loops=self.add_self_loops,
                        flow=self.flow, dtype=x.dtype)
                    if self.cached:
                        self._cached_edge_index = (edge_index, symnorm_weight)
                else:
                    edge_index, symnorm_weight = cache

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, None, num_nodes=x.size(self.node_dim),
                        improved=False, add_self_loops=self.add_self_loops,
                        flow=self.flow, dtype=x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        elif self.add_self_loops:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if self.cached and cache is not None:
                    edge_index = cache[0]
                else:
                    # Pass the node count explicitly (as the symnorm branch
                    # does): otherwise it is inferred from `edge_index` and
                    # trailing isolated nodes would not receive a self-loop.
                    edge_index, _ = add_remaining_self_loops(
                        edge_index, num_nodes=x.size(self.node_dim))
                    if self.cached:
                        self._cached_edge_index = (edge_index, None)

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if self.cached and cache is not None:
                    edge_index = cache
                else:
                    edge_index = torch_sparse.fill_diag(edge_index, 1.0)
                    if self.cached:
                        self._cached_adj_t = edge_index

        # [num_nodes, (out_channels // num_heads) * num_bases]
        bases = self.bases_lin(x)
        # [num_nodes, num_heads * num_bases * num_aggrs]
        weightings = self.comb_lin(x)

        # [num_nodes, num_aggregators, (out_channels // num_heads) * num_bases]
        # propagate_type: (x: Tensor, symnorm_weight: OptTensor)
        aggregated = self.propagate(edge_index, x=bases,
                                    symnorm_weight=symnorm_weight)

        weightings = weightings.view(-1, self.num_heads,
                                     self.num_bases * len(self.aggregators))
        aggregated = aggregated.view(
            -1,
            len(self.aggregators) * self.num_bases,
            self.out_channels // self.num_heads,
        )

        # [num_nodes, num_heads, out_channels // num_heads]
        out = torch.matmul(weightings, aggregated)
        out = out.view(-1, self.out_channels)

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(self, x_j: Tensor) -> Tensor:
        r"""Passes source node features through unchanged."""
        return x_j

    def aggregate(self, inputs: Tensor, index: Tensor,
                  dim_size: Optional[int] = None,
                  symnorm_weight: OptTensor = None) -> Tensor:
        r"""Applies every configured aggregator to the incoming messages.

        Returns the single aggregation result if only one aggregator is
        configured, otherwise the results stacked along dimension :obj:`1`.
        """
        outs = []
        for aggr in self.aggregators:
            if aggr == 'symnorm':
                assert symnorm_weight is not None
                out = scatter(inputs * symnorm_weight.view(-1, 1), index, 0,
                              dim_size, reduce='sum')
            elif aggr == 'var' or aggr == 'std':
                # Var[x] = E[x^2] - E[x]^2
                mean = scatter(inputs, index, 0, dim_size, reduce='mean')
                mean_squares = scatter(inputs * inputs, index, 0, dim_size,
                                       reduce='mean')
                out = mean_squares - mean * mean
                if aggr == 'std':
                    # Clamp before the square root: the difference above can
                    # be slightly negative due to rounding.
                    out = out.clamp(min=1e-5).sqrt()
            else:
                out = scatter(inputs, index, 0, dim_size, reduce=aggr)

            outs.append(out)

        return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0]

    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
        r"""Fused message passing and aggregation for sparse adjacencies.

        Returns the single aggregation result if only one aggregator is
        configured, otherwise the results stacked along dimension :obj:`1`.
        """
        # When "symnorm" is mixed with other aggregators, the other
        # aggregators must not see the values stored in `adj_t`, so they
        # operate on a copy with the values dropped or reset to one.
        adj_t_2 = adj_t
        if len(self.aggregators) > 1 and 'symnorm' in self.aggregators:
            if isinstance(adj_t, SparseTensor):
                adj_t_2 = adj_t.set_value(None)
            else:
                adj_t_2 = adj_t.clone()
                adj_t_2.values().fill_(1.0)

        outs = []
        for aggr in self.aggregators:
            if aggr == 'symnorm':
                out = spmm(adj_t, x, reduce='sum')
            elif aggr in ['var', 'std']:
                # Var[x] = E[x^2] - E[x]^2
                mean = spmm(adj_t_2, x, reduce='mean')
                mean_sq = spmm(adj_t_2, x * x, reduce='mean')
                out = mean_sq - mean * mean
                if aggr == 'std':
                    # NOTE: stabilised as sqrt(relu(var) + 1e-5), which is
                    # not numerically identical to the clamp(min=1e-5) used
                    # in `aggregate`. Left as is to keep existing results.
                    out = torch.sqrt(out.relu_() + 1e-5)
            else:
                out = spmm(adj_t_2, x, reduce=aggr)

            outs.append(out)

        return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0]

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, aggregators={self.aggregators})')