torch_geometric.nn.conv.TransformerConv

class TransformerConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, heads: int = 1, concat: bool = True, beta: bool = False, dropout: float = 0.0, edge_dim: Optional[int] = None, bias: bool = True, root_weight: bool = True, **kwargs)[source]

Bases: MessagePassing

The graph transformer operator from the “Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification” paper.

\[\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},\]

where the attention coefficients \(\alpha_{i,j}\) are computed via multi-head dot product attention:

\[\alpha_{i,j} = \textrm{softmax} \left( \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)} {\sqrt{d}} \right)\]
Parameters:
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • heads (int, optional) – Number of attention heads. (default: 1)

  • concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

  • beta (bool, optional) –

    If set, will combine aggregation and skip information via

    \[\mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i + (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_j \right)}_{=\mathbf{m}_i}\]

    with \(\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top} [ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1 \mathbf{x}_i - \mathbf{m}_i ])\) (default: False)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients, which exposes each node to a stochastically sampled neighborhood during training. (default: 0)

  • edge_dim (int, optional) –

    Edge feature dimensionality (in case there are any). Edge features are added to the keys after linear transformation, that is, prior to computing the attention dot product. They are also added to the final values after the same linear transformation (see the usage sketch after this parameter list). The model is:

    \[\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left( \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij} \right),\]

    where the attention coefficients \(\alpha_{i,j}\) are now computed via:

    \[\alpha_{i,j} = \textrm{softmax} \left( \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})} {\sqrt{d}} \right)\]

    (default: None)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • root_weight (bool, optional) – If set to False, the layer will not add the transformed root node features to the output, and the option beta will be set to False. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
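
A minimal usage sketch (the shapes and hyperparameter values below are illustrative assumptions, not defaults of the layer):

    import torch
    from torch_geometric.nn import TransformerConv

    num_nodes, in_dim, out_dim, num_heads, edge_dim = 10, 16, 32, 4, 8

    x = torch.randn(num_nodes, in_dim)                     # node features
    edge_index = torch.randint(0, num_nodes, (2, 30))      # 30 random edges
    edge_attr = torch.randn(edge_index.size(1), edge_dim)  # per-edge features

    conv = TransformerConv(in_dim, out_dim, heads=num_heads, beta=True,
                           edge_dim=edge_dim)
    out = conv(x, edge_index, edge_attr)
    print(out.shape)  # torch.Size([10, 128]), i.e. heads * out_channels

With concat=False, the head outputs are averaged instead, and the output size is out_channels rather than heads * out_channels.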

forward(x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, SparseTensor], edge_attr: Optional[Tensor] = None, return_attention_weights: Optional[bool] = None) → Tensor[source]
forward(x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Tensor, edge_attr: Optional[Tensor] = None, return_attention_weights: bool = None) → Tuple[Tensor, Tuple[Tensor, Tensor]]
forward(x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: SparseTensor, edge_attr: Optional[Tensor] = None, return_attention_weights: bool = None) → Tuple[Tensor, SparseTensor]

Runs the forward pass of the module.

Parameters:
  • x (torch.Tensor or (torch.Tensor, torch.Tensor)) – The input node features.

  • edge_index (torch.Tensor or SparseTensor) – The edge indices.

  • edge_attr (torch.Tensor, optional) – The edge features. (default: None)

  • return_attention_weights (bool, optional) – If set to True, will additionally return the tuple (edge_index, attention_weights), holding the computed attention weights for each edge. (default: None)
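
Continuing the sketch above, the attention coefficients can be retrieved as follows (when edge_index is a torch.Tensor, the returned alpha has shape [num_edges, heads]):

    out, (edge_index_out, alpha) = conv(x, edge_index, edge_attr,
                                        return_attention_weights=True)
    print(alpha.shape)  # torch.Size([30, 4]): one coefficient per edge and head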

reset_parameters()[source]

Resets all learnable parameters of the module.
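
For example, to re-initialize the layer between independent training runs (continuing the sketch above):

    conv.reset_parameters()  # re-initializes every linear transformation of the layer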