Source code for torch_geometric.nn.conv.gatv2_conv

from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter
from torch_sparse import SparseTensor, set_diag

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import Adj, OptTensor, PairTensor
from torch_geometric.utils import add_self_loops, remove_self_loops, softmax


[docs]class GATv2Conv(MessagePassing): r"""The GATv2 operator from the `"How Attentive are Graph Attention Networks?" <https://arxiv.org/abs/2105.14491>`_ paper, which fixes the static attention problem of the standard :class:`~torch_geometric.conv.GATConv` layer: since the linear layers in the standard GAT are applied right after each other, the ranking of attended nodes is unconditioned on the query node. In contrast, in GATv2, every node can attend to any other node. .. math:: \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j}, where the attention coefficients :math:`\alpha_{i,j}` are computed as .. math:: \alpha_{i,j} = \frac{ \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta} [\mathbf{x}_i \, \Vert \, \mathbf{x}_j] \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta} [\mathbf{x}_i \, \Vert \, \mathbf{x}_k] \right)\right)}. If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`, the attention coefficients :math:`\alpha_{i,j}` are computed as .. math:: \alpha_{i,j} = \frac{ \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta} [\mathbf{x}_i \, \Vert \, \mathbf{x}_j \, \Vert \, \mathbf{e}_{i,j}] \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta} [\mathbf{x}_i \, \Vert \, \mathbf{x}_k \, \Vert \, \mathbf{e}_{i,k}] \right)\right)}. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. out_channels (int): Size of each output sample. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) concat (bool, optional): If set to :obj:`False`, the multi-head attentions are averaged instead of concatenated. (default: :obj:`True`) negative_slope (float, optional): LeakyReLU angle of the negative slope. (default: :obj:`0.2`) dropout (float, optional): Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: :obj:`0`) add_self_loops (bool, optional): If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) edge_dim (int, optional): Edge feature dimensionality (in case there are any). (default: :obj:`None`) fill_value (float or Tensor or str, optional): The way to generate edge features of self-loops (in case :obj:`edge_dim != None`). If given as :obj:`float` or :class:`torch.Tensor`, edge features of self-loops will be directly given by :obj:`fill_value`. If given as :obj:`str`, edge features of self-loops are computed by aggregating all features of edges that point to the specific node, according to a reduce operation. (:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"mean"`) bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. (default: :obj:`True`) share_weights (bool, optional): If set to :obj:`True`, the same matrix will be applied to the source and the target node of every edge. (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)`, edge features :math:`(|\mathcal{E}|, D)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, H * F_{out})` or :math:`((|\mathcal{V}_t|, H * F_{out})` if bipartite. If :obj:`return_attention_weights=True`, then :math:`((|\mathcal{V}|, H * F_{out}), ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))` or :math:`((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))` if bipartite """ _alpha: OptTensor def __init__( self, in_channels: Union[int, Tuple[int, int]], out_channels: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, add_self_loops: bool = True, edge_dim: Optional[int] = None, fill_value: Union[float, Tensor, str] = 'mean', bias: bool = True, share_weights: bool = False, **kwargs, ): super().__init__(node_dim=0, **kwargs) self.in_channels = in_channels self.out_channels = out_channels self.heads = heads self.concat = concat self.negative_slope = negative_slope self.dropout = dropout self.add_self_loops = add_self_loops self.edge_dim = edge_dim self.fill_value = fill_value self.share_weights = share_weights if isinstance(in_channels, int): self.lin_l = Linear(in_channels, heads * out_channels, bias=bias, weight_initializer='glorot') if share_weights: self.lin_r = self.lin_l else: self.lin_r = Linear(in_channels, heads * out_channels, bias=bias, weight_initializer='glorot') else: self.lin_l = Linear(in_channels[0], heads * out_channels, bias=bias, weight_initializer='glorot') if share_weights: self.lin_r = self.lin_l else: self.lin_r = Linear(in_channels[1], heads * out_channels, bias=bias, weight_initializer='glorot') self.att = Parameter(torch.Tensor(1, heads, out_channels)) if edge_dim is not None: self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False, weight_initializer='glorot') else: self.lin_edge = None if bias and concat: self.bias = Parameter(torch.Tensor(heads * out_channels)) elif bias and not concat: self.bias = Parameter(torch.Tensor(out_channels)) else: self.register_parameter('bias', None) self._alpha = None self.reset_parameters()
[docs] def reset_parameters(self): self.lin_l.reset_parameters() self.lin_r.reset_parameters() if self.lin_edge is not None: self.lin_edge.reset_parameters() glorot(self.att) zeros(self.bias)
[docs] def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, edge_attr: OptTensor = None, return_attention_weights: bool = None): # type: (Union[Tensor, PairTensor], Tensor, OptTensor, NoneType) -> Tensor # noqa # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, NoneType) -> Tensor # noqa # type: (Union[Tensor, PairTensor], Tensor, OptTensor, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, bool) -> Tuple[Tensor, SparseTensor] # noqa r""" Args: return_attention_weights (bool, optional): If set to :obj:`True`, will additionally return the tuple :obj:`(edge_index, attention_weights)`, holding the computed attention weights for each edge. (default: :obj:`None`) """ H, C = self.heads, self.out_channels x_l: OptTensor = None x_r: OptTensor = None if isinstance(x, Tensor): assert x.dim() == 2 x_l = self.lin_l(x).view(-1, H, C) if self.share_weights: x_r = x_l else: x_r = self.lin_r(x).view(-1, H, C) else: x_l, x_r = x[0], x[1] assert x[0].dim() == 2 x_l = self.lin_l(x_l).view(-1, H, C) if x_r is not None: x_r = self.lin_r(x_r).view(-1, H, C) assert x_l is not None assert x_r is not None if self.add_self_loops: if isinstance(edge_index, Tensor): num_nodes = x_l.size(0) if x_r is not None: num_nodes = min(num_nodes, x_r.size(0)) edge_index, edge_attr = remove_self_loops( edge_index, edge_attr) edge_index, edge_attr = add_self_loops( edge_index, edge_attr, fill_value=self.fill_value, num_nodes=num_nodes) elif isinstance(edge_index, SparseTensor): if self.edge_dim is None: edge_index = set_diag(edge_index) else: raise NotImplementedError( "The usage of 'edge_attr' and 'add_self_loops' " "simultaneously is currently not yet supported for " "'edge_index' in a 'SparseTensor' form") # propagate_type: (x: PairTensor, edge_attr: OptTensor) out = self.propagate(edge_index, x=(x_l, x_r), edge_attr=edge_attr, size=None) alpha = self._alpha self._alpha = None if self.concat: out = out.view(-1, self.heads * self.out_channels) else: out = out.mean(dim=1) if self.bias is not None: out += self.bias if isinstance(return_attention_weights, bool): assert alpha is not None if isinstance(edge_index, Tensor): return out, (edge_index, alpha) elif isinstance(edge_index, SparseTensor): return out, edge_index.set_value(alpha, layout='coo') else: return out
def message(self, x_j: Tensor, x_i: Tensor, edge_attr: OptTensor, index: Tensor, ptr: OptTensor, size_i: Optional[int]) -> Tensor: x = x_i + x_j if edge_attr is not None: if edge_attr.dim() == 1: edge_attr = edge_attr.view(-1, 1) assert self.lin_edge is not None edge_attr = self.lin_edge(edge_attr) edge_attr = edge_attr.view(-1, self.heads, self.out_channels) x += edge_attr x = F.leaky_relu(x, self.negative_slope) alpha = (x * self.att).sum(dim=-1) alpha = softmax(alpha, index, ptr, size_i) self._alpha = alpha alpha = F.dropout(alpha, p=self.dropout, training=self.training) return x_j * alpha.unsqueeze(-1) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, heads={self.heads})')