# Source code for torch_geometric.nn.conv.transformer_conv

import math
import typing
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (
    Adj,
    NoneType,
    OptTensor,
    PairTensor,
    SparseTensor,
)
from torch_geometric.utils import softmax

if typing.TYPE_CHECKING:
    from typing import overload
else:
    from torch.jit import _overload_method as overload

class TransformerConv(MessagePassing):
    r"""The graph transformer operator from the `"Masked Label Prediction:
    Unified Message Passing Model for Semi-Supervised Classification"
    <https://arxiv.org/abs/2009.03509>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed via

    .. math::
        \alpha_{i,j} = \textrm{softmax} \left(
        \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}
        {\sqrt{d}} \right)

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        beta (bool, optional): If set, will combine aggregation and
            skip information via

            .. math::
                \mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +
                (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
                \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_j \right)}_{=\mathbf{m}_i}

            with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}
            [ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1
            \mathbf{x}_i - \mathbf{m}_i ])` (default: :obj:`False`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        edge_dim (int, optional): Edge feature dimensionality (in case
            there are any). Edge features are added to the keys after
            linear transformation, that is, prior to computing the
            attention dot product. They are also added to the final values
            after the same linear transformation. The model is:

            .. math::
                \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
                \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(
                \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}
                \right),

            where the attention coefficients :math:`\alpha_{i,j}` are now
            computed via:

            .. math::
                \alpha_{i,j} = \textrm{softmax} \left(
                \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}
                (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}
                {\sqrt{d}} \right)

            (default: :obj:`None`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add the transformed root node features to the output and the
            option :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    _alpha: OptTensor

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        beta: bool = False,
        dropout: float = 0.,
        edge_dim: Optional[int] = None,
        bias: bool = True,
        root_weight: bool = True,
        **kwargs,
    ):
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.beta = beta and root_weight
        self.root_weight = root_weight
        self.concat = concat
        self.dropout = dropout
        self.edge_dim = edge_dim
        self._alpha = None

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

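        # Per-head projections: keys and values are computed from the source
        # node features and queries from the target node features. Each maps
        # to `heads * out_channels` features, which are later reshaped to
        # (-1, heads, out_channels) in forward().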
        self.lin_key = Linear(in_channels[0], heads * out_channels)
        self.lin_query = Linear(in_channels[1], heads * out_channels)
        self.lin_value = Linear(in_channels[0], heads * out_channels)
        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
        else:
            self.lin_edge = self.register_parameter('lin_edge', None)

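        # The skip (root) projection has to match the final output width:
        # `heads * out_channels` when the head outputs are concatenated,
        # `out_channels` when they are averaged. The optional beta gate takes
        # the concatenation [out, x_r, out - x_r] as input.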
        if concat:
            self.lin_skip = Linear(in_channels[1], heads * out_channels,
                                   bias=bias)
            if self.beta:
                self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False)
            else:
                self.lin_beta = self.register_parameter('lin_beta', None)
        else:
            self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)
            if self.beta:
                self.lin_beta = Linear(3 * out_channels, 1, bias=False)
            else:
                self.lin_beta = self.register_parameter('lin_beta', None)

        self.reset_parameters()

    def reset_parameters(self):
        super().reset_parameters()
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        if self.edge_dim:
            self.lin_edge.reset_parameters()
        self.lin_skip.reset_parameters()
        if self.beta:
            self.lin_beta.reset_parameters()

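    # Typing overloads: the declared return type depends on whether
    # `return_attention_weights` is set and on whether `edge_index` is a
    # dense `Tensor` or a `SparseTensor`.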
    @overload
    def forward(
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        return_attention_weights: NoneType = None,
    ) -> Tensor:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: SparseTensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, SparseTensor]:
        pass

    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: Optional[bool] = None,
    ) -> Union[
            Tensor,
            Tuple[Tensor, Tuple[Tensor, Tensor]],
            Tuple[Tensor, SparseTensor],
    ]:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node
                features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
        """

        H, C = self.heads, self.out_channels

        if isinstance(x, Tensor):
            x = (x, x)

        query = self.lin_query(x[1]).view(-1, H, C)
        key = self.lin_key(x[0]).view(-1, H, C)
        value = self.lin_value(x[0]).view(-1, H, C)

        # propagate_type: (query: Tensor, key: Tensor, value: Tensor,
        #                  edge_attr: OptTensor)
        out = self.propagate(edge_index, query=query, key=key, value=value,
                             edge_attr=edge_attr)

        alpha = self._alpha
        self._alpha = None

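        # Combine the per-head outputs: either concatenate them along the
        # feature dimension or average over the heads.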
        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

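        # Skip connection with the transformed root (target) features,
        # optionally gated by the learned coefficient beta.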
        if self.root_weight:
            x_r = self.lin_skip(x[1])
            if self.lin_beta is not None:
                beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))
                beta = beta.sigmoid()
                out = beta * x_r + (1 - beta) * out
            else:
                out = out + x_r

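        # Only an explicit bool requests the attention weights; the default
        # `None` keeps the return type a plain Tensor (see the overloads
        # above).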
        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:

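        # If edge features are available, project them to
        # (-1, heads, out_channels) and add them to the keys before taking
        # the dot product.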
        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,
                                                      self.out_channels)
            key_j = key_j + edge_attr

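        # Scaled dot-product attention per head, normalized with a softmax
        # over all edges that point to the same target node.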
        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

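        # The message is the (edge-augmented) value, weighted by the
        # attention coefficient of its head.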
        out = value_j
        if edge_attr is not None:
            out = out + edge_attr

        out = out * alpha.view(-1, self.heads, 1)
        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')
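

# ---------------------------------------------------------------------------
# A minimal usage sketch (not part of the original module): it builds a small
# random graph and runs TransformerConv with edge features and the attention
# weights returned. The sizes below (16 nodes, 8 input channels, 4 edge
# features, 2 heads) are illustrative assumptions, not values taken from the
# source above.
#
#     import torch
#     from torch_geometric.nn import TransformerConv
#
#     x = torch.randn(16, 8)                       # node features
#     edge_index = torch.randint(0, 16, (2, 64))   # random edges
#     edge_attr = torch.randn(64, 4)               # edge features
#
#     conv = TransformerConv(in_channels=8, out_channels=32, heads=2,
#                            edge_dim=4, beta=True)
#     out, (ei, alpha) = conv(x, edge_index, edge_attr,
#                             return_attention_weights=True)
#     # out: [16, 64] (heads concatenated), alpha: [64, 2] (one weight per
#     # edge and head)
# ---------------------------------------------------------------------------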