# Source code for torch_geometric.nn.conv.gat_conv

import typing
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import (
    Adj,
    NoneType,
    OptPairTensor,
    OptTensor,
    Size,
    SparseTensor,
    torch_sparse,
)
from torch_geometric.utils import (
    add_self_loops,
    is_torch_sparse_tensor,
    remove_self_loops,
    softmax,
)
from torch_geometric.utils.sparse import set_sparse_value

if typing.TYPE_CHECKING:
    from typing import overload
else:
    from torch.jit import _overload_method as overload


class GATConv(MessagePassing):
    r"""The graph attentional operator from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}_{s}\mathbf{x}_{i} +
        \sum_{j \in \mathcal{N}(i)}
        \alpha_{i,j}\mathbf{\Theta}_{t}\mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathrm{LeakyReLU}\left(
        \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i
        + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_j
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathrm{LeakyReLU}\left(
        \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i
        + \mathbf{a}^{\top}_{t}\mathbf{\Theta}_{t}\mathbf{x}_k
        \right)\right)}.

    If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`,
    the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathrm{LeakyReLU}\left(
        \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i
        + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_j
        + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,j}
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathrm{LeakyReLU}\left(
        \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i
        + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_k
        + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,k}
        \right)\right)}.

    If the graph is not bipartite, :math:`\mathbf{\Theta}_{s} =
    \mathbf{\Theta}_{t}`.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities in case of a bipartite graph.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        edge_dim (int, optional): Edge feature dimensionality (in case
            there are any). If not set, edge features passed to
            :meth:`forward` do not contribute to the attention coefficients.
            (default: :obj:`None`)
        fill_value (float or torch.Tensor or str, optional): The way to
            generate edge features of self-loops (in case
            :obj:`edge_dim != None`).
            If given as :obj:`float` or :class:`torch.Tensor`, edge features of
            self-loops will be directly given by :obj:`fill_value`.
            If given as :obj:`str`, edge features of self-loops are computed by
            aggregating all features of edges that point to the specific node,
            according to a reduce operation. (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"mean"`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})` or
          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
          if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge features :math:`(|\mathcal{E}|, D)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|, H * F_{out})` or
          :math:`(|\mathcal{V_t}|, H * F_{out})` if bipartite.
          If :obj:`concat=False`, heads are averaged and the feature
          dimension is :math:`F_{out}` instead of :math:`H * F_{out}`.
          If :obj:`return_attention_weights=True`, then
          :math:`((|\mathcal{V}|, H * F_{out}),
          ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))`
          or :math:`((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|),
          (|\mathcal{E}|, H)))` if bipartite, where the returned edge indices
          include any self-loops inserted by the layer.
    """
    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
        add_self_loops: bool = True,
        edge_dim: Optional[int] = None,
        fill_value: Union[float, Tensor, str] = 'mean',
        bias: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        # `node_dim=0`: node features are stored as (num_nodes, heads, C),
        # so the node axis is the first one.
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.edge_dim = edge_dim
        self.fill_value = fill_value

        # In case we are operating in bipartite graphs, we apply separate
        # transformations 'lin_src' and 'lin_dst' to source and target nodes:
        self.lin = self.lin_src = self.lin_dst = None
        if isinstance(in_channels, int):
            self.lin = Linear(in_channels, heads * out_channels, bias=False,
                              weight_initializer='glorot')
        else:
            self.lin_src = Linear(in_channels[0], heads * out_channels, False,
                                  weight_initializer='glorot')
            self.lin_dst = Linear(in_channels[1], heads * out_channels, False,
                                  weight_initializer='glorot')

        # The learnable parameters to compute attention coefficients:
        self.att_src = Parameter(torch.empty(1, heads, out_channels))
        self.att_dst = Parameter(torch.empty(1, heads, out_channels))

        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False,
                                   weight_initializer='glorot')
            self.att_edge = Parameter(torch.empty(1, heads, out_channels))
        else:
            self.lin_edge = None
            self.register_parameter('att_edge', None)

        # The bias matches the output width: H * C when heads are
        # concatenated, C when they are averaged.
        if bias and concat:
            self.bias = Parameter(torch.empty(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-initializes all learnable parameters: the linear
        transformations via their own :meth:`reset_parameters`, the attention
        vectors via Glorot initialization, and the bias with zeros."""
        super().reset_parameters()
        if self.lin is not None:
            self.lin.reset_parameters()
        if self.lin_src is not None:
            self.lin_src.reset_parameters()
        if self.lin_dst is not None:
            self.lin_dst.reset_parameters()
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()
        glorot(self.att_src)
        glorot(self.att_dst)
        glorot(self.att_edge)
        zeros(self.bias)

    @overload
    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
        return_attention_weights: NoneType = None,
    ) -> Tensor:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        size: Size = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: SparseTensor,
        edge_attr: OptTensor = None,
        size: Size = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, SparseTensor]:
        pass

    def forward(  # noqa: F811
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
        return_attention_weights: Optional[bool] = None,
    ) -> Union[
            Tensor,
            Tuple[Tensor, Tuple[Tensor, Tensor]],
            Tuple[Tensor, SparseTensor],
    ]:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node
                features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)
            size ((int, int), optional): The shape of the adjacency matrix.
                (default: :obj:`None`)
            return_attention_weights (bool, optional): If set to any
                :obj:`bool` value (note: :obj:`False` included), will
                additionally return the attention weights computed for each
                edge. For a :class:`torch.Tensor` :obj:`edge_index`, they are
                returned as the tuple :obj:`(edge_index, attention_weights)`
                (for a :pytorch:`torch.sparse` tensor, the first entry is the
                sparse adjacency holding the weights as values). For a
                :class:`SparseTensor` :obj:`edge_index`, a
                :class:`SparseTensor` whose values are the attention weights
                is returned instead of a tuple. The returned edge indices
                include any self-loops added by the layer.
                (default: :obj:`None`)
        """
        # NOTE: attention weights will be returned whenever
        # `return_attention_weights` is set to a value, regardless of its
        # actual value (might be `True` or `False`). This is a current somewhat
        # hacky workaround to allow for TorchScript support via the
        # `torch.jit._overload` decorator, as we can only change the output
        # arguments conditioned on type (`None` or `bool`), not based on its
        # actual value.

        H, C = self.heads, self.out_channels

        # We first transform the input node features. If a tuple is passed, we
        # transform source and target node features via separate weights:
        if isinstance(x, Tensor):
            assert x.dim() == 2, "Static graphs not supported in 'GATConv'"

            if self.lin is not None:
                x_src = x_dst = self.lin(x).view(-1, H, C)
            else:
                # If the module is initialized as bipartite, transform source
                # and destination node features separately:
                assert self.lin_src is not None and self.lin_dst is not None
                x_src = self.lin_src(x).view(-1, H, C)
                x_dst = self.lin_dst(x).view(-1, H, C)

        else:  # Tuple of source and target node features:
            x_src, x_dst = x
            assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"

            if self.lin is not None:
                # If the module is initialized as non-bipartite, we expect that
                # source and destination node features have the same shape and
                # that their transformations are shared:
                x_src = self.lin(x_src).view(-1, H, C)
                if x_dst is not None:
                    x_dst = self.lin(x_dst).view(-1, H, C)
            else:
                assert self.lin_src is not None and self.lin_dst is not None

                x_src = self.lin_src(x_src).view(-1, H, C)
                if x_dst is not None:
                    x_dst = self.lin_dst(x_dst).view(-1, H, C)

        x = (x_src, x_dst)

        # Next, we compute node-level attention coefficients, both for source
        # and target nodes (if present); each has shape (num_nodes, H):
        alpha_src = (x_src * self.att_src).sum(dim=-1)
        alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)
        alpha = (alpha_src, alpha_dst)

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                # We only want to add self-loops for nodes that appear both as
                # source and target nodes:
                num_nodes = x_src.size(0)
                if x_dst is not None:
                    num_nodes = min(num_nodes, x_dst.size(0))
                num_nodes = min(size) if size is not None else num_nodes
                # Drop existing self-loops first so every node ends up with
                # exactly one (whose features are derived from `fill_value`).
                edge_index, edge_attr = remove_self_loops(
                    edge_index, edge_attr)
                edge_index, edge_attr = add_self_loops(
                    edge_index, edge_attr, fill_value=self.fill_value,
                    num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                if self.edge_dim is None:
                    edge_index = torch_sparse.set_diag(edge_index)
                else:
                    raise NotImplementedError(
                        "The usage of 'edge_attr' and 'add_self_loops' "
                        "simultaneously is currently not yet supported for "
                        "'edge_index' in a 'SparseTensor' form")

        # edge_updater_type: (alpha: OptPairTensor, edge_attr: OptTensor)
        alpha = self.edge_updater(edge_index, alpha=alpha, edge_attr=edge_attr,
                                  size=size)

        # propagate_type: (x: OptPairTensor, alpha: Tensor)
        out = self.propagate(edge_index, x=x, alpha=alpha, size=size)

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out = out + self.bias

        if isinstance(return_attention_weights, bool):
            if isinstance(edge_index, Tensor):
                if is_torch_sparse_tensor(edge_index):
                    # TODO TorchScript requires to return a tuple
                    adj = set_sparse_value(edge_index, alpha)
                    return out, (adj, alpha)
                else:
                    return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def edge_update(self, alpha_j: Tensor, alpha_i: OptTensor,
                    edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                    dim_size: Optional[int]) -> Tensor:
        r"""Computes normalized per-edge, per-head attention coefficients.

        Sums the source/target node-level scores (plus an edge-feature score
        when :obj:`edge_dim` was set), applies LeakyReLU, softmax-normalizes
        over edges that share the same :obj:`index`, and applies dropout.
        Presumably :obj:`alpha_j` / :obj:`alpha_i` are shaped
        (num_edges, heads) after being gathered per edge by
        :class:`MessagePassing`.
        """
        # Given edge-level attention coefficients for source and target nodes,
        # we simply need to sum them up to "emulate" concatenation:
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
        if index.numel() == 0:
            # No edges: nothing to normalize.
            return alpha
        if edge_attr is not None and self.lin_edge is not None:
            if edge_attr.dim() == 1:
                edge_attr = edge_attr.view(-1, 1)
            edge_attr = self.lin_edge(edge_attr)
            edge_attr = edge_attr.view(-1, self.heads, self.out_channels)
            alpha_edge = (edge_attr * self.att_edge).sum(dim=-1)
            alpha = alpha + alpha_edge

        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, index, ptr, dim_size)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return alpha

    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
        r"""Scales each source-node message by its per-head attention
        coefficient (broadcast over the channel dimension)."""
        return alpha.unsqueeze(-1) * x_j

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')