from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import glorot
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
from torch_geometric.utils import softmax

class GeneralConv(MessagePassing):
r"""A general GNN layer adapted from the `"Design Space for Graph Neural
Networks" <https://arxiv.org/abs/2011.08843>`_ paper.
Args:
in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
derive the size from the first input(s) to the forward method.
A tuple corresponds to the sizes of source and target
dimensionalities.
out_channels (int): Size of each output sample.
in_edge_channels (int, optional): Size of each input edge.
(default: :obj:`None`)
aggr (string, optional): The aggregation scheme to use
(:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
(default: :obj:`"mean"`)
skip_linear (bool, optional): Whether apply linear function in skip
connection. (default: :obj:`False`)
directed_msg (bool, optional): If message passing is directed;
otherwise, message passing is bi-directed. (default: :obj:`True`)
heads (int, optional): Number of message passing ensembles.
If :obj:`heads > 1`, the GNN layer will output an ensemble of
multiple messages.
If attention is used (:obj:`attention=True`), this corresponds to
multi-head attention. (default: :obj:`1`)
attention (bool, optional): Whether to add attention to message
computation. (default: :obj:`False`)
attention_type (str, optional): Type of attention: :obj:`"additive"`,
:obj:`"dot_product"`. (default: :obj:`"additive"`)
l2_normalize (bool, optional): If set to :obj:`True`, output features
will be :math:`\ell_2`-normalized, *i.e.*,
:math:`\frac{\mathbf{x}^{\prime}_i}
{\| \mathbf{x}^{\prime}_i \|_2}`.
(default: :obj:`False`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""

    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: Optional[int],
                 in_edge_channels: Optional[int] = None,
                 aggr: str = 'add', skip_linear: bool = False,
                 directed_msg: bool = True, heads: int = 1,
                 attention: bool = False, attention_type: str = 'additive',
                 l2_normalize: bool = False, bias: bool = True,
                 **kwargs):  # yapf: disable
        kwargs.setdefault('aggr', aggr)
        super().__init__(node_dim=0, **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_edge_channels = in_edge_channels
        self.aggr = aggr
        self.skip_linear = skip_linear
        self.directed_msg = directed_msg
        self.heads = heads
        self.attention = attention
        self.attention_type = attention_type
        self.normalize_l2 = l2_normalize

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)
        self.lin_msg = Linear(in_channels[0], out_channels * self.heads,
                              bias=bias)
        if not self.directed_msg:
            # Bi-directed message passing additionally transforms the central
            # node features and adds them to the message (see `message_basic`).
            self.lin_msg_i = Linear(in_channels[0], out_channels * self.heads,
                                    bias=bias)
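        # Skip connection: the central node features are transformed when the
        # dimensionalities differ (or when explicitly requested), and passed
        # through unchanged otherwise.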
        if self.skip_linear or self.in_channels != self.out_channels:
            self.lin_self = Linear(in_channels[1], out_channels, bias=bias)
        else:
            self.lin_self = torch.nn.Identity()

        if self.in_edge_channels is not None:
            self.lin_edge = Linear(in_edge_channels,
                                   out_channels * self.heads, bias=bias)
# todo: A general torch_geometric.nn.AttentionLayer
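        # Additive attention learns a per-head scoring vector (`att_msg`,
        # initialized with Glorot below); dot-product attention is
        # parameter-free and only rescales scores by sqrt(out_channels).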
        if self.attention:
            if self.attention_type == 'additive':
                self.att_msg = Parameter(
                    torch.Tensor(1, self.heads, self.out_channels))
            elif self.attention_type == 'dot_product':
                self.scaler = torch.sqrt(
                    torch.tensor(out_channels, dtype=torch.float))
            else:
                raise ValueError('attention_type: {} not supported'.format(
                    self.attention_type))

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_msg.reset_parameters()
        if not self.directed_msg:
            self.lin_msg_i.reset_parameters()
        # `lin_self` may be a parameter-free `torch.nn.Identity` module:
        if hasattr(self.lin_self, 'reset_parameters'):
            self.lin_self.reset_parameters()
        if self.in_edge_channels is not None:
            self.lin_edge.reset_parameters()
        if self.attention and self.attention_type == 'additive':
            glorot(self.att_msg)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None,
                edge_feature: Optional[Tensor] = None) -> Tensor:
""""""
if isinstance(x, Tensor):
x: OptPairTensor = (x, x)
x_self = x[1]
        # propagate_type: (x: OptPairTensor, edge_feature: OptTensor)
        out = self.propagate(edge_index, x=x, size=size,
                             edge_feature=edge_feature)
        out = out.mean(dim=1)  # todo: other approach to aggregate heads
        out += self.lin_self(x_self)
        if self.normalize_l2:
            out = F.normalize(out, p=2, dim=-1)
        return out

    def message_basic(self, x_i: Tensor, x_j: Tensor,
                      edge_feature: Optional[Tensor]) -> Tensor:
        if self.directed_msg:
            x_j = self.lin_msg(x_j)
        else:
            x_j = self.lin_msg(x_j) + self.lin_msg_i(x_i)
        if edge_feature is not None:
            x_j = x_j + self.lin_edge(edge_feature)
        return x_j
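
    # Called by `propagate`: `x_i` holds the features of each edge's central
    # (target) node and `x_j` those of its neighboring (source) node.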
    def message(self, x_i: Tensor, x_j: Tensor, edge_index_i: Tensor,
                size_i: Optional[int],
                edge_feature: Optional[Tensor]) -> Tensor:
x_j_out = self.message_basic(x_i, x_j, edge_feature)
x_j_out = x_j_out.view(-1, self.heads, self.out_channels)
        if self.attention:
            if self.attention_type == 'dot_product':
                # Score each edge by the scaled dot-product between the
                # transformed central node and the transformed neighbor.
                x_i_out = self.message_basic(x_j, x_i, edge_feature)
                x_i_out = x_i_out.view(-1, self.heads, self.out_channels)
                alpha = (x_i_out * x_j_out).sum(dim=-1) / self.scaler
            else:
                # Additive (GAT-style) attention: score messages against a
                # learnable per-head vector, followed by a LeakyReLU.
                alpha = (x_j_out * self.att_msg).sum(dim=-1)
                alpha = F.leaky_relu(alpha, negative_slope=0.2)
            # Normalize scores over each node's incoming edges, per head.
            alpha = softmax(alpha, edge_index_i, num_nodes=size_i)
            alpha = alpha.view(-1, self.heads, 1)
            return x_j_out * alpha
        else:
            return x_j_out

    def __repr__(self) -> str:
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
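

# Minimal usage sketch (illustrative; not part of the original module): runs
# a single attentive GeneralConv pass over a toy 4-node cycle graph. The
# channel sizes (16 in, 32 out) and the 2-head additive attention are
# arbitrary example choices.
if __name__ == '__main__':
    x = torch.randn(4, 16)  # 4 nodes with 16 features each
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])  # directed 4-cycle
    conv = GeneralConv(16, 32, heads=2, attention=True)
    out = conv(x, edge_index)
    print(out.shape)  # torch.Size([4, 32]); head outputs are mean-pooled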