Source code for torch_geometric.nn.conv.general_conv

import torch
from typing import Optional, Tuple, Union
from torch_geometric.typing import OptPairTensor, Adj, Size

from torch import Tensor
from torch_geometric.nn.dense.linear import Linear
from torch.nn import Parameter
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot
from torch_geometric.utils import softmax


class GeneralConv(MessagePassing):
    r"""A general GNN layer adapted from the `"Design Space for Graph Neural
    Networks" <https://arxiv.org/abs/2011.08843>`_ paper.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        out_channels (int): Size of each output sample.
        in_edge_channels (int, optional): Size of each input edge.
            (default: :obj:`None`)
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"add"`)
        skip_linear (bool, optional): Whether to apply a linear transformation
            in the skip connection. (default: :obj:`False`)
        directed_msg (bool, optional): If set to :obj:`True`, message passing
            is directed; otherwise, messages are computed from both the source
            and target node features (bi-directed). (default: :obj:`True`)
        heads (int, optional): Number of message passing ensembles.
            If :obj:`heads > 1`, the GNN layer will output an ensemble of
            multiple messages.
            If attention is used (:obj:`attention=True`), this corresponds to
            multi-head attention. (default: :obj:`1`)
        attention (bool, optional): Whether to add attention to the message
            computation. (default: :obj:`False`)
        attention_type (str, optional): Type of attention: :obj:`"additive"`,
            :obj:`"dot_product"`. (default: :obj:`"additive"`)
        l2_normalize (bool, optional): If set to :obj:`True`, output features
            will be :math:`\ell_2`-normalized, *i.e.*,
            :math:`\frac{\mathbf{x}^{\prime}_i}
            {\| \mathbf{x}^{\prime}_i \|_2}`.
            (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: Optional[int],
                 in_edge_channels: Optional[int] = None, aggr: str = 'add',
                 skip_linear: bool = False, directed_msg: bool = True,
                 heads: int = 1, attention: bool = False,
                 attention_type: str = 'additive', l2_normalize: bool = False,
                 bias: bool = True, **kwargs):  # yapf: disable
        kwargs.setdefault('aggr', aggr)
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_edge_channels = in_edge_channels
        self.aggr = aggr
        self.skip_linear = skip_linear
        self.directed_msg = directed_msg
        self.heads = heads
        self.attention = attention
        self.attention_type = attention_type
        self.normalize_l2 = l2_normalize

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        if self.directed_msg:
            self.lin_msg = Linear(in_channels[0], out_channels * self.heads,
                                  bias=bias)
        else:
            self.lin_msg = Linear(in_channels[0], out_channels * self.heads,
                                  bias=bias)
            self.lin_msg_i = Linear(in_channels[0], out_channels * self.heads,
                                    bias=bias)

        if self.skip_linear or self.in_channels != self.out_channels:
            self.lin_self = Linear(in_channels[1], out_channels, bias=bias)
        else:
            self.lin_self = torch.nn.Identity()

        if self.in_edge_channels is not None:
            self.lin_edge = Linear(in_edge_channels, out_channels * self.heads,
                                   bias=bias)

        # todo: A general torch_geometric.nn.AttentionLayer
        if self.attention:
            if self.attention_type == 'additive':
                self.att_msg = Parameter(
                    torch.Tensor(1, self.heads, self.out_channels))
            elif self.attention_type == 'dot_product':
                self.scaler = torch.sqrt(
                    torch.tensor(out_channels, dtype=torch.float))
            else:
                raise ValueError('attention_type: {} not supported'.format(
                    self.attention_type))

        self.reset_parameters()
    def reset_parameters(self):
        self.lin_msg.reset_parameters()
        # `lin_self` may be `torch.nn.Identity`, which has no parameters to reset.
        if hasattr(self.lin_self, 'reset_parameters'):
            self.lin_self.reset_parameters()
        if self.in_edge_channels is not None:
            self.lin_edge.reset_parameters()
        if self.attention and self.attention_type == 'additive':
            glorot(self.att_msg)
    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None, edge_feature: Tensor = None) -> Tensor:
        """"""
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)
        x_self = x[1]

        # propagate_type: (x: OptPairTensor)
        out = self.propagate(edge_index, x=x, size=size,
                             edge_feature=edge_feature)
        out = out.mean(dim=1)  # todo: other approach to aggregate heads
        out += self.lin_self(x_self)
        if self.normalize_l2:
            out = F.normalize(out, p=2, dim=-1)
        return out
    def message_basic(self, x_i: Tensor, x_j: Tensor, edge_feature: Tensor):
        if self.directed_msg:
            x_j = self.lin_msg(x_j)
        else:
            x_j = self.lin_msg(x_j) + self.lin_msg_i(x_i)
        if edge_feature is not None:
            x_j = x_j + self.lin_edge(edge_feature)
        return x_j
    def message(self, x_i: Tensor, x_j: Tensor, edge_index_i: Tensor,
                size_i: Tensor, edge_feature: Tensor) -> Tensor:
        x_j_out = self.message_basic(x_i, x_j, edge_feature)
        x_j_out = x_j_out.view(-1, self.heads, self.out_channels)
        if self.attention:
            if self.attention_type == 'dot_product':
                x_i_out = self.message_basic(x_j, x_i, edge_feature)
                x_i_out = x_i_out.view(-1, self.heads, self.out_channels)
                alpha = (x_i_out * x_j_out).sum(dim=-1) / self.scaler
            else:
                alpha = (x_j_out * self.att_msg).sum(dim=-1)
            alpha = F.leaky_relu(alpha, negative_slope=0.2)
            alpha = softmax(alpha, edge_index_i, num_nodes=size_i)
            alpha = alpha.view(-1, self.heads, 1)
            return x_j_out * alpha
        else:
            return x_j_out

    def __repr__(self) -> str:
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
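
A minimal usage sketch (not part of the module above), showing how GeneralConv can be applied to a small random graph with multi-head additive attention and edge features. The shapes, channel sizes, and hyper-parameters below are illustrative assumptions, not values prescribed by the paper or the library.

import torch
from torch_geometric.nn import GeneralConv

num_nodes, num_edges = 10, 40
x = torch.randn(num_nodes, 16)                         # node features
edge_index = torch.randint(num_nodes, (2, num_edges))  # random connectivity
edge_attr = torch.randn(num_edges, 8)                  # edge features

# Four attention heads; per-head messages are averaged inside forward().
conv = GeneralConv(in_channels=16, out_channels=32, in_edge_channels=8,
                   heads=4, attention=True, attention_type='additive')
out = conv(x, edge_index, edge_feature=edge_attr)      # shape: [num_nodes, 32]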
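
A second sketch for the bipartite setting implied by the tuple form of in_channels, where source and target nodes carry features of different dimensionality; again, all sizes are made up for illustration.

import torch
from torch_geometric.nn import GeneralConv

x_src = torch.randn(8, 16)                 # 8 source nodes, 16 features each
x_dst = torch.randn(4, 32)                 # 4 target nodes, 32 features each
edge_index = torch.tensor([[0, 1, 2, 7],   # source node indices
                           [0, 1, 2, 3]])  # target node indices

conv = GeneralConv(in_channels=(16, 32), out_channels=64)
out = conv((x_src, x_dst), edge_index, size=(8, 4))  # shape: [4, 64]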