# Source code for torch_geometric.nn.conv.transformer_conv
import math
import typing
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (
Adj,
NoneType,
OptTensor,
PairTensor,
SparseTensor,
)
from torch_geometric.utils import softmax
# Static type checkers see the standard ``typing.overload``. At runtime the
# decorator is TorchScript's ``_overload_method`` instead — presumably so the
# overloaded ``forward`` signatures below stay usable under ``torch.jit``.
if typing.TYPE_CHECKING:
    from typing import overload
else:
    from torch.jit import _overload_method as overload
class TransformerConv(MessagePassing):
    r"""The graph transformer operator from the `"Masked Label Prediction:
    Unified Message Passing Model for Semi-Supervised Classification"
    <https://arxiv.org/abs/2009.03509>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed via
    multi-head dot product attention:

    .. math::
        \alpha_{i,j} = \textrm{softmax} \left(
        \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}
        {\sqrt{d}} \right)

    and :math:`d` is :obj:`out_channels`, the per-head dimensionality.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        beta (bool, optional): If set, will combine aggregation and
            skip information via

            .. math::
                \mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +
                (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
                \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j} \right)}_{=\mathbf{m}_i}

            with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}
            [ \mathbf{m}_i, \mathbf{W}_1 \mathbf{x}_i,
            \mathbf{m}_i - \mathbf{W}_1 \mathbf{x}_i ])`
            (default: :obj:`False`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        edge_dim (int, optional): Edge feature dimensionality (in case
            there are any). Edge features are added to the keys after
            linear transformation, that is, prior to computing the
            attention dot product. They are also added to final values
            after the same linear transformation. The model is:

            .. math::
                \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
                \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(
                \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}
                \right),

            where the attention coefficients :math:`\alpha_{i,j}` are now
            computed via:

            .. math::
                \alpha_{i,j} = \textrm{softmax} \left(
                \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}
                (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}
                {\sqrt{d}} \right)

            If :obj:`None`, any :obj:`edge_attr` passed to :meth:`forward`
            is ignored. (default: :obj:`None`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add the transformed root node features to the output and the
            option :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    # Softmax-normalized attention coefficients (before dropout) written by
    # ``message`` during ``propagate`` and consumed/cleared by ``forward``.
    _alpha: OptTensor

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        beta: bool = False,
        dropout: float = 0.,
        edge_dim: Optional[int] = None,
        bias: bool = True,
        root_weight: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        # Messages are shaped [num_edges, heads, out_channels] (see
        # ``message``), so the node axis to aggregate over is dimension 0.
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        # The gate mixes the root term with the aggregation, so it is only
        # meaningful (and only enabled) when ``root_weight`` is set.
        self.beta = beta and root_weight
        self.root_weight = root_weight
        self.concat = concat
        self.dropout = dropout
        self.edge_dim = edge_dim
        self._alpha = None

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        # Keys/values come from source nodes (x[0]); queries from targets.
        self.lin_key = Linear(in_channels[0], heads * out_channels)
        self.lin_query = Linear(in_channels[1], heads * out_channels)
        self.lin_value = Linear(in_channels[0], heads * out_channels)
        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
        else:
            # Registering ``None`` keeps ``self.lin_edge`` a valid attribute.
            self.register_parameter('lin_edge', None)

        # The skip projection must match the width of the aggregated output:
        # heads * out_channels when heads are concatenated, else out_channels.
        skip_channels = heads * out_channels if concat else out_channels
        self.lin_skip = Linear(in_channels[1], skip_channels, bias=bias)
        if self.beta:
            # Gate input is the concatenation [out, x_r, out - x_r].
            self.lin_beta = Linear(3 * skip_channels, 1, bias=False)
        else:
            self.register_parameter('lin_beta', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-initializes all learnable parameters of the layer."""
        super().reset_parameters()
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        # Check the modules themselves rather than the constructor flags so
        # this always agrees with what ``__init__`` actually created.
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()
        self.lin_skip.reset_parameters()
        if self.lin_beta is not None:
            self.lin_beta.reset_parameters()

    @overload
    def forward(
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: NoneType = None,
    ) -> Tensor:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: SparseTensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, SparseTensor]:
        pass

    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: Optional[bool] = None,
    ) -> Union[
            Tensor,
            Tuple[Tensor, Tuple[Tensor, Tensor]],
            Tuple[Tensor, SparseTensor],
    ]:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node
                features. A single tensor is used as both source and target
                features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features. Required
                when the layer was built with :obj:`edge_dim`; ignored
                otherwise. (default: :obj:`None`)
            return_attention_weights (bool, optional): If set to a
                :obj:`bool` (note: either :obj:`True` or :obj:`False`),
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)

        Returns:
            The updated node features, of width :obj:`heads * out_channels`
            if :obj:`concat` else :obj:`out_channels`. Optionally paired with
            the attention weights as described above.
        """
        H, C = self.heads, self.out_channels

        if isinstance(x, Tensor):
            x = (x, x)

        query = self.lin_query(x[1]).view(-1, H, C)
        key = self.lin_key(x[0]).view(-1, H, C)
        value = self.lin_value(x[0]).view(-1, H, C)

        # propagate_type: (query: Tensor, key:Tensor, value: Tensor,
        #                  edge_attr: OptTensor)
        out = self.propagate(edge_index, query=query, key=key, value=value,
                             edge_attr=edge_attr)

        # Pick up the coefficients ``message`` stashed, and clear the stash
        # so no stale tensor outlives this call.
        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, H * C)
        else:
            out = out.mean(dim=1)

        if self.root_weight:
            x_r = self.lin_skip(x[1])
            if self.lin_beta is not None:
                beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))
                beta = beta.sigmoid()
                out = beta * x_r + (1 - beta) * out
            else:
                out = out + x_r

        # Dispatch is on the *type* of ``return_attention_weights``, not its
        # value, matching the typed overloads above: ``False`` also returns
        # the attention weights.
        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        r"""Computes attention-weighted per-edge messages.

        ``query_i`` / ``key_j`` / ``value_j`` are the ``[num_edges, heads,
        out_channels]`` projections gathered at each edge's target (``_i``)
        and source (``_j``) node. Side effect: stores the normalized
        (pre-dropout) attention coefficients in ``self._alpha``.

        Raises:
            ValueError: If the layer has an edge projection (``edge_dim`` was
                given) but no ``edge_attr`` is supplied.
        """
        out = value_j
        if self.lin_edge is not None:
            if edge_attr is None:
                raise ValueError("'edge_attr' is required when the layer "
                                 "was created with 'edge_dim'")
            edge_emb = self.lin_edge(edge_attr).view(-1, self.heads,
                                                     self.out_channels)
            # The projected edge features enter both the keys (attention
            # scores) and the values. Without ``lin_edge`` they are ignored
            # in both places, so raw, unprojected features are never
            # broadcast against the [E, H, C] values.
            key_j = key_j + edge_emb
            out = out + edge_emb

        # Scaled dot-product score per edge and head, normalized over the
        # incoming edges of each target node.
        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        return out * alpha.view(-1, self.heads, 1)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')