# Source code for torch_geometric.nn.conv.transformer_conv

import math
from typing import Union, Tuple, Optional
from torch_geometric.typing import PairTensor, Adj, OptTensor

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax

class TransformerConv(MessagePassing):
    r"""The graph transformer operator from the `"Masked Label Prediction:
    Unified Message Passing Model for Semi-Supervised Classification"
    <https://arxiv.org/abs/2009.03509>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed via
    multi-head dot product attention:

    .. math::
        \alpha_{i,j} = \textrm{softmax} \left(
        \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}
        {\sqrt{d}} \right)

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        beta (bool, optional): If set, will combine aggregation and
            skip information via

            .. math::
                \mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +
                (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
                \alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i}

            with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}
            [ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1
            \mathbf{x}_i - \mathbf{m}_i ])` (default: :obj:`False`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        edge_dim (int, optional): Edge feature dimensionality (in case
            there are any). Edge features are added to the keys after
            linear transformation, that is, prior to computing the
            attention dot product. They are also added to final values
            after the same linear transformation. The model is:

            .. math::
                \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
                \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(
                \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}
                \right),

            where the attention coefficients :math:`\alpha_{i,j}` are now
            computed via:

            .. math::
                \alpha_{i,j} = \textrm{softmax} \left(
                \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}
                (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}
                {\sqrt{d}} \right)

            (default :obj:`None`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add the transformed root node features to the output and the
            option :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    _alpha: OptTensor

    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, heads: int = 1, concat: bool = True,
                 beta: bool = False, dropout: float = 0.,
                 edge_dim: Optional[int] = None, bias: bool = True,
                 root_weight: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        # The beta gate mixes the skip connection with the aggregated
        # messages, so it is meaningless without a skip connection.
        self.beta = beta and root_weight
        self.root_weight = root_weight
        self.concat = concat
        self.dropout = dropout
        self.edge_dim = edge_dim

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        # Keys/values are computed from source nodes (in_channels[0]),
        # queries and the skip connection from target nodes (in_channels[1]).
        self.lin_key = Linear(in_channels[0], heads * out_channels)
        self.lin_query = Linear(in_channels[1], heads * out_channels)
        self.lin_value = Linear(in_channels[0], heads * out_channels)
        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
        else:
            self.lin_edge = self.register_parameter('lin_edge', None)

        # Width of the layer output: heads are either concatenated or
        # averaged, and the skip/beta layers must match that width.
        skip_channels = heads * out_channels if concat else out_channels
        self.lin_skip = Linear(in_channels[1], skip_channels, bias=bias)
        if self.beta:
            self.lin_beta = Linear(3 * skip_channels, 1, bias=False)
        else:
            self.lin_beta = self.register_parameter('lin_beta', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the parameters of every linear sub-layer."""
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        # Test for the layer itself (as `message` does) rather than the
        # truthiness of `edge_dim`, so the guard matches `__init__`.
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()
        self.lin_skip.reset_parameters()
        if self.lin_beta is not None:
            self.lin_beta.reset_parameters()

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                edge_attr: OptTensor = None):
        """Run attention-weighted message passing and add the skip term.

        Args:
            x: Node features, either a single tensor used for both source
                and target nodes or a ``(source, target)`` pair.
            edge_index: Graph connectivity forwarded to :meth:`propagate`.
            edge_attr: Optional edge features forwarded to :meth:`message`.

        Returns:
            The updated target node features; the last dimension is
            ``heads * out_channels`` if ``concat`` is set and
            ``out_channels`` otherwise.
        """
        if isinstance(x, Tensor):
            x: PairTensor = (x, x)

        # propagate_type: (x: PairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.root_weight:
            x_r = self.lin_skip(x[1])
            if self.lin_beta is not None:
                # Gate input is [m, W_1 x, m - W_1 x] along the feature dim.
                beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))
                beta = beta.sigmoid()
                out = beta * x_r + (1 - beta) * out
            else:
                out += x_r

        return out

    def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        """Compute the attention-weighted value for every edge.

        Returns:
            A tensor viewed as ``[-1, heads, out_channels]`` holding the
            per-edge, per-head messages to be aggregated.
        """
        query = self.lin_query(x_i).view(-1, self.heads, self.out_channels)
        key = self.lin_key(x_j).view(-1, self.heads, self.out_channels)

        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,
                                                      self.out_channels)
            key += edge_attr

        # Scaled dot-product attention, normalized over the edges that
        # share the same `index` entry.
        alpha = (query * key).sum(dim=-1) / math.sqrt(self.out_channels)
        alpha = softmax(alpha, index, ptr, size_i)
        # Only drop attention coefficients while the module is in training
        # mode; evaluation must stay deterministic.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        out = self.lin_value(x_j).view(-1, self.heads, self.out_channels)
        if edge_attr is not None:
            out += edge_attr

        out *= alpha.view(-1, self.heads, 1)
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')