# Source code for torch_geometric.nn.conv.gatv2_conv

from typing import Union, Tuple, Optional
from torch_geometric.typing import (Adj, Size, OptTensor, PairTensor)

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

from torch_geometric.nn.inits import glorot, zeros

class GATv2Conv(MessagePassing):
    r"""The GATv2 operator from the `"How Attentive are Graph Attention
    Networks?" <https://arxiv.org/abs/2105.14491>`_ paper, which fixes the
    static attention problem of the standard
    :class:`~torch_geometric.conv.GATConv` layer: since the linear layers in
    the standard GAT are applied right after each other, the ranking of
    attended nodes is unconditioned on the query node.
    In contrast, in GATv2, every node can attend to any other node.

    .. math::
        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta}
        [\mathbf{x}_i \, \Vert \, \mathbf{x}_j]
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta}
        [\mathbf{x}_i \, \Vert \, \mathbf{x}_k]
        \right)\right)}.
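
    The key difference to :class:`~torch_geometric.conv.GATConv` is the
    order of operations: GAT scores an edge as
    :math:`\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}[\mathbf{\Theta}
    \mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]\right)`; since this
    is a monotonic function of a sum of separate source and target terms, the
    ranking over neighbors :math:`j` is the same for every query node
    :math:`i` (*static* attention). GATv2 moves the non-linearity between
    :math:`\mathbf{\Theta}` and :math:`\mathbf{a}`, as above, so the ranking
    can depend on the query node (*dynamic* attention).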

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        share_weights (bool, optional): If set to :obj:`True`, the same matrix
            will be applied to the source and the target node of every edge.
            (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
"""
_alpha: OptTensor

    def __init__(self, in_channels: int, out_channels: int, heads: int = 1,
                 concat: bool = True, negative_slope: float = 0.2,
                 dropout: float = 0., add_self_loops: bool = True,
                 bias: bool = True, share_weights: bool = False, **kwargs):
        super(GATv2Conv, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.share_weights = share_weights

        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias,
                            weight_initializer='glorot')
        if share_weights:
            self.lin_r = self.lin_l
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias,
                                weight_initializer='glorot')

        # Attention vector `a`, one per head, shared across all edges.
        self.att = Parameter(torch.Tensor(1, heads, out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self._alpha = None

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                size: Size = None, return_attention_weights: bool = None):
        # type: (Union[Tensor, PairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, PairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, PairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        # type: (Union[Tensor, PairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
r"""
Args:
return_attention_weights (bool, optional): If set to :obj:True,
:obj:(edge_index, attention_weights), holding the computed
attention weights for each edge. (default: :obj:None)
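
        A minimal sketch of retrieving the attention weights (:obj:`conv`,
        :obj:`x` and :obj:`edge_index` are assumed to be defined):

        .. code-block:: python

            out, (edge_index, alpha) = conv(x, edge_index,
                                            return_attention_weights=True)
            # `alpha` holds one coefficient per edge and head.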
"""

x_l: OptTensor = None
x_r: OptTensor = None
if isinstance(x, Tensor):
assert x.dim() == 2
x_l = self.lin_l(x).view(-1, H, C)
if self.share_weights:
x_r = x_l
else:
x_r = self.lin_r(x).view(-1, H, C)
else:
x_l, x_r = x, x
assert x.dim() == 2
x_l = self.lin_l(x_l).view(-1, H, C)
if x_r is not None:
x_r = self.lin_r(x_r).view(-1, H, C)

assert x_l is not None
assert x_r is not None

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                num_nodes = x_l.size(0)
                if x_r is not None:
                    num_nodes = min(num_nodes, x_r.size(0))
                if size is not None:
                    num_nodes = min(size[0], size[1])
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index,
                                               num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                edge_index = set_diag(edge_index)

        # propagate_type: (x: PairTensor)
        out = self.propagate(edge_index, x=(x_l, x_r), size=size)

        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out += self.bias

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, x_j: Tensor, x_i: Tensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
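        # `x_i`/`x_j` are the per-edge target/source node features, shaped
        # [num_edges, heads, channels].  Their sum is equivalent to
        # Theta * [x_i || x_j] from the paper, since `lin_l` and `lin_r`
        # hold the two halves of Theta.  The LeakyReLU is applied *before*
        # the attention vector `att`, which is what makes GATv2 dynamic.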
        x = x_i + x_j
        x = F.leaky_relu(x, self.negative_slope)
        alpha = (x * self.att).sum(dim=-1)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.unsqueeze(-1)

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
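
# A minimal usage sketch (not part of the original module); the toy graph,
# feature size, and head count below are illustrative assumptions.
if __name__ == '__main__':
    x = torch.randn(4, 16)                     # 4 nodes, 16 features each
    edge_index = torch.tensor([[0, 1, 1, 2],   # source nodes
                               [1, 0, 2, 3]])  # target nodes

    conv = GATv2Conv(in_channels=16, out_channels=32, heads=2)
    out = conv(x, edge_index)
    print(out.shape)    # torch.Size([4, 64]), i.e. heads * out_channels

    out, (ei, alpha) = conv(x, edge_index, return_attention_weights=True)
    print(alpha.shape)  # torch.Size([8, 2]): 4 edges + 4 self-loops, 2 heads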