# Source code for torch_geometric.nn.conv.transformer_conv

import math
import typing
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (
    Adj,
    NoneType,
    OptTensor,
    PairTensor,
    SparseTensor,
)
from torch_geometric.utils import softmax

if typing.TYPE_CHECKING:
    from typing import overload
else:
    from torch.jit import _overload_method as overload


class TransformerConv(MessagePassing):
    r"""The graph transformer operator from the `"Masked Label Prediction:
    Unified Message Passing Model for Semi-Supervised Classification"
    <https://arxiv.org/abs/2009.03509>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed via
    multi-head dot product attention:

    .. math::
        \alpha_{i,j} = \textrm{softmax} \left(
        \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}
        {\sqrt{d}} \right)

    and :math:`d` is the per-head dimensionality :obj:`out_channels`.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        beta (bool, optional): If set, will combine aggregation and
            skip information via

            .. math::
                \mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +
                (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
                \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_j \right)}_{=\mathbf{m}_i}

            with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}
            [ \mathbf{m}_i, \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i -
            \mathbf{W}_1 \mathbf{x}_i ])` (default: :obj:`False`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        edge_dim (int, optional): Edge feature dimensionality (in case
            there are any). Edge features are added to the keys after
            linear transformation, that is, prior to computing the
            attention dot product. They are also added to final values
            after the same linear transformation. The model is:

            .. math::
                \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
                \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(
                \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}
                \right),

            where the attention coefficients :math:`\alpha_{i,j}` are now
            computed via:

            .. math::
                \alpha_{i,j} = \textrm{softmax} \left(
                \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}
                (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}
                {\sqrt{d}} \right)

            If :obj:`edge_dim` is :obj:`None`, any :obj:`edge_attr` passed to
            :meth:`forward` is ignored. (default: :obj:`None`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add the transformed root node features to the output and the
            option :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    _alpha: OptTensor

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        beta: bool = False,
        dropout: float = 0.,
        edge_dim: Optional[int] = None,
        bias: bool = True,
        root_weight: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        # The beta gate mixes the root term with the aggregated messages, so
        # it is disabled whenever there is no root term.
        self.beta = beta and root_weight
        self.root_weight = root_weight
        self.concat = concat
        self.dropout = dropout
        self.edge_dim = edge_dim
        # Attention coefficients written by the latest ``message`` call and
        # read (then cleared) by ``forward``.
        self._alpha = None

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        # Keys and values are projected from the source features (x[0]);
        # queries from the target features (x[1]).
        self.lin_key = Linear(in_channels[0], heads * out_channels)
        self.lin_query = Linear(in_channels[1], heads * out_channels)
        self.lin_value = Linear(in_channels[0], heads * out_channels)
        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
        else:
            # ``register_parameter`` returns ``None``: this registers an empty
            # ``lin_edge`` slot and leaves the attribute set to ``None``.
            self.lin_edge = self.register_parameter('lin_edge', None)

        if concat:
            self.lin_skip = Linear(in_channels[1], heads * out_channels,
                                   bias=bias)
            if self.beta:
                self.lin_beta = Linear(3 * heads * out_channels, 1,
                                       bias=False)
            else:
                self.lin_beta = self.register_parameter('lin_beta', None)
        else:
            self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)
            if self.beta:
                self.lin_beta = Linear(3 * out_channels, 1, bias=False)
            else:
                self.lin_beta = self.register_parameter('lin_beta', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        super().reset_parameters()
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        # Guard on the submodules themselves, matching ``forward``/``message``.
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()
        self.lin_skip.reset_parameters()
        if self.lin_beta is not None:
            self.lin_beta.reset_parameters()

    @overload
    def forward(
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: NoneType = None,
    ) -> Tensor:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: SparseTensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, SparseTensor]:
        pass

    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: Optional[bool] = None,
    ) -> Union[
            Tensor,
            Tuple[Tensor, Tuple[Tensor, Tensor]],
            Tuple[Tensor, SparseTensor],
    ]:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node
                features. A single tensor is used as both source and target
                features; a tuple is interpreted as ``(source, target)``.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features. Only used
                when the layer was constructed with :obj:`edge_dim`.
                (default: :obj:`None`)
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. Note that any boolean value
                (including :obj:`False`) enables this; leave it as
                :obj:`None` to return only the node features.
                (default: :obj:`None`)
        """
        H, C = self.heads, self.out_channels

        if isinstance(x, Tensor):
            x = (x, x)

        query = self.lin_query(x[1]).view(-1, H, C)
        key = self.lin_key(x[0]).view(-1, H, C)
        value = self.lin_value(x[0]).view(-1, H, C)

        # propagate_type: (query: Tensor, key:Tensor, value: Tensor,
        # edge_attr: OptTensor)
        out = self.propagate(edge_index, query=query, key=key, value=value,
                             edge_attr=edge_attr)

        # ``message`` stashed the (pre-dropout) attention coefficients.
        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.root_weight:
            x_r = self.lin_skip(x[1])
            if self.lin_beta is not None:
                # Gate input is [m_i, W1 x_i, m_i - W1 x_i].
                beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))
                beta = beta.sigmoid()
                out = beta * x_r + (1 - beta) * out
            else:
                out = out + x_r

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        r"""Computes the attention-weighted message for every edge.

        Attention logits are the per-head dot product of ``query_i`` and
        ``key_j`` scaled by ``1 / sqrt(out_channels)``, normalized with
        :func:`torch_geometric.utils.softmax` grouped by ``index``. The
        normalized coefficients are stored in ``self._alpha`` *before*
        dropout so that :meth:`forward` can return them.
        """
        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,
                                                      self.out_channels)
            # Edge embeddings enter the attention logits via the keys ...
            key_j = key_j + edge_attr
        else:
            # Without ``lin_edge`` there is no projection to (heads,
            # out_channels), so raw edge features cannot be combined with the
            # messages; ignore them instead of broadcasting a mismatched
            # tensor onto ``value_j``.
            edge_attr = None

        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        out = value_j
        if edge_attr is not None:
            # ... and are also added to the values.
            out = out + edge_attr

        out = out * alpha.view(-1, self.heads, 1)
        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')