# Source code for torch_geometric.nn.conv.gatv2_conv

from typing import Union, Tuple, Optional
from torch_geometric.typing import (Adj, Size, OptTensor, PairTensor)

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

from torch_geometric.nn.inits import glorot, zeros

class GATv2Conv(MessagePassing):
    r"""The GATv2 operator from the `"How Attentive are Graph Attention
    Networks?" <https://arxiv.org/abs/2105.14491>`_ paper, which fixes the
    static attention problem of the standard
    :class:`~torch_geometric.nn.conv.GATConv` layer: since the linear layers
    in the standard GAT are applied right after each other, the ranking of
    attended nodes is unconditioned on the query node. In contrast, in GATv2,
    every node can attend to any other node.

    .. math::
        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta}
        [\mathbf{x}_i \, \Vert \, \mathbf{x}_j]
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta}
        [\mathbf{x}_i \, \Vert \, \mathbf{x}_k]
        \right)\right)}.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. Only applied while the
            module is in training mode. (default: :obj:`0`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        share_weights (bool, optional): If set to :obj:`True`, the same
            matrix will be applied to the source and the target node of every
            edge. (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    # Attention coefficients of the most recent `message` call; stashed so
    # that `forward` can optionally return them, then reset to None.
    _alpha: OptTensor

    def __init__(self, in_channels: int, out_channels: int, heads: int = 1,
                 concat: bool = True, negative_slope: float = 0.2,
                 dropout: float = 0., add_self_loops: bool = True,
                 bias: bool = True, share_weights: bool = False, **kwargs):
        # node_dim=0: node features are laid out as (num_nodes, heads, C),
        # so the node axis is dimension 0 rather than the default -2.
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.share_weights = share_weights

        # lin_l transforms source nodes (x_j), lin_r transforms target nodes
        # (x_i); with share_weights both names refer to the same module.
        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias)
        if share_weights:
            self.lin_r = self.lin_l
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias)

        # Attention vector `a`, one per head; broadcast against (E, H, C).
        self.att = Parameter(torch.empty(1, heads, out_channels))

        if bias and concat:
            self.bias = Parameter(torch.empty(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self._alpha = None

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize learnable weights.

        Linear weights and the attention vector use Glorot initialization;
        the output bias (if any) is zeroed. When ``share_weights`` is set,
        ``lin_l`` and ``lin_r`` are the same module and the second call simply
        re-initializes it again. The biases inside ``lin_l``/``lin_r`` are not
        touched here.
        """
        glorot(self.lin_l.weight)
        glorot(self.lin_r.weight)
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                size: Size = None, return_attention_weights: bool = None):
        # type: (Union[Tensor, PairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, PairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, PairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        # type: (Union[Tensor, PairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
        r"""
        Args:
            x (Tensor or (Tensor, Tensor)): Node features, either a single
                2-D tensor or a (source, target) pair of 2-D tensors.
            edge_index (Tensor or SparseTensor): Graph connectivity.
            size ((int, int), optional): Shape of the bipartite adjacency.
            return_attention_weights (bool, optional): If set to a boolean,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge (taken before attention
                dropout). Note that any :obj:`bool` value, including
                :obj:`False`, selects the tuple return, because the return
                type is chosen by the argument's type for TorchScript.
                (default: :obj:`None`)

        Returns:
            The updated node features of shape ``(N, heads * out_channels)``
            if ``concat`` else ``(N, out_channels)``, optionally paired with
            the attention weights as described above.
        """
        H, C = self.heads, self.out_channels

        x_l: OptTensor = None
        x_r: OptTensor = None
        if isinstance(x, Tensor):
            assert x.dim() == 2
            x_l = self.lin_l(x).view(-1, H, C)
            if self.share_weights:
                # Same weights for source and target: reuse the projection.
                x_r = x_l
            else:
                x_r = self.lin_r(x).view(-1, H, C)
        else:
            x_l, x_r = x[0], x[1]
            assert x[0].dim() == 2
            x_l = self.lin_l(x_l).view(-1, H, C)
            if x_r is not None:
                x_r = self.lin_r(x_r).view(-1, H, C)

        # These asserts also refine the Optional types for TorchScript.
        assert x_l is not None
        assert x_r is not None

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                num_nodes = x_l.size(0)
                if x_r is not None:
                    num_nodes = min(num_nodes, x_r.size(0))
                if size is not None:
                    num_nodes = min(size[0], size[1])
                # Drop any existing self-loops first so each node ends up
                # with exactly one.
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                edge_index = set_diag(edge_index)

        # propagate_type: (x: PairTensor)
        out = self.propagate(edge_index, x=(x_l, x_r), size=size)

        # Pick up the coefficients stored by `message` and clear the slot.
        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out += self.bias

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, x_j: Tensor, x_i: Tensor, index: Tensor,
                ptr: OptTensor, size_i: Optional[int]) -> Tensor:
        """Compute per-edge messages weighted by GATv2 attention.

        ``x_j``/``x_i`` are the projected source/target features, each of
        shape ``(E, heads, out_channels)`` (from the ``view`` in
        :meth:`forward`).
        """
        # GATv2: apply the non-linearity *before* the attention vector, so the
        # score depends jointly on query and key.
        x = x_i + x_j
        x = F.leaky_relu(x, self.negative_slope)
        alpha = (x * self.att).sum(dim=-1)
        alpha = softmax(alpha, index, ptr, size_i)
        # Store pre-dropout coefficients for `return_attention_weights`.
        self._alpha = alpha
        # `training` must be passed explicitly: F.dropout defaults to
        # training=True and would otherwise drop attention at inference too.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.unsqueeze(-1)

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)