torch_geometric.nn.conv.PointTransformerConv

class PointTransformerConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, pos_nn: Optional[Callable] = None, attn_nn: Optional[Callable] = None, add_self_loops: bool = True, **kwargs)

Bases: MessagePassing

The Point Transformer layer from the “Point Transformer” paper.

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j} \left(\mathbf{W}_3 \mathbf{x}_j + \delta_{i,j} \right),\]

where the attention coefficients \(\alpha_{i,j}\) and the positional embedding \(\delta_{i,j}\) are computed as

\[\alpha_{i,j}= \textrm{softmax} \left( \gamma_\mathbf{\Theta} (\mathbf{W}_1 \mathbf{x}_i - \mathbf{W}_2 \mathbf{x}_j + \delta_{i,j}) \right)\]

and

\[\delta_{i,j}= h_{\mathbf{\Theta}}(\mathbf{p}_i - \mathbf{p}_j),\]

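with \(\gamma_\mathbf{\Theta}\) and \(h_\mathbf{\Theta}\) denoting neural networks (e.g., MLPs), and \(\mathbf{P} \in \mathbb{R}^{N \times D}\) holding the position \(\mathbf{p}_i\) of each point. The softmax normalizes over all \(j \in \mathcal{N}(i) \cup \{ i \}\), separately for each output channel, so \(\alpha_{i,j}\) is a vector of per-channel weights.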
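To make the equations concrete, here is a minimal, dependency-free sketch that evaluates \(\mathbf{x}^{\prime}_i\) for a single target node using plain Python lists. It is not PyG's implementation (which works on batched tensors via message passing); the helper names `matvec`, `identity`, and `point_transformer_node` are hypothetical, \(h_\mathbf{\Theta}\) is supplied by the caller, and \(\gamma_\mathbf{\Theta}\) defaults to the identity, which is assumed to correspond to `attn_nn=None`.

```python
import math
from typing import Callable, Dict, List, Sequence

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, v: Sequence[float]) -> Vector:
    """Multiply a (rows x cols) matrix, given as a list of rows, by a vector."""
    return [sum(w_rc * v_c for w_rc, v_c in zip(row, v)) for row in w]


def identity(v: Vector) -> Vector:
    return v


def point_transformer_node(
    i: int,
    neighbors: Sequence[int],
    x: Sequence[Vector],
    pos: Sequence[Vector],
    w1: Matrix,
    w2: Matrix,
    w3: Matrix,
    h: Callable[[Vector], Vector],
    gamma: Callable[[Vector], Vector] = identity,
) -> Vector:
    """Evaluate x'_i for target node i; `neighbors` is N(i) and must not contain i."""
    js = list(neighbors) + [i]  # the sum runs over N(i) ∪ {i}
    q = matvec(w1, x[i])  # W_1 x_i
    logits: Dict[int, Vector] = {}
    values: Dict[int, Vector] = {}
    for j in js:
        # δ_{i,j} = h(p_i - p_j)
        delta = h([p_ic - p_jc for p_ic, p_jc in zip(pos[i], pos[j])])
        k = matvec(w2, x[j])  # W_2 x_j
        # γ(W_1 x_i - W_2 x_j + δ_{i,j}), one logit per output channel
        logits[j] = gamma([q_c - k_c + d_c for q_c, k_c, d_c in zip(q, k, delta)])
        # W_3 x_j + δ_{i,j}
        values[j] = [v_c + d_c for v_c, d_c in zip(matvec(w3, x[j]), delta)]

    out = []
    for c in range(len(q)):
        # Softmax over j for this channel only; subtract the max for numerical stability.
        m = max(logits[j][c] for j in js)
        weights = {j: math.exp(logits[j][c] - m) for j in js}
        z = sum(weights.values())
        out.append(sum(weights[j] / z * values[j][c] for j in js))
    return out


# Usage: three co-located points, so δ_{i,j} = h(0) = 0 for a bias-free linear h.
H = [[0.5, 0.0, 1.0], [0.0, 2.0, 0.0]]  # h: R^3 -> R^2, hypothetical weights
zero = [[0.0, 0.0], [0.0, 0.0]]
eye = [[1.0, 0.0], [0.0, 1.0]]
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 9.0]]
pos = [[0.0, 0.0, 0.0]] * 3
out = point_transformer_node(0, [1, 2], x, pos, zero, zero, eye, lambda d: matvec(H, d))
print(out)  # ≈ [3.0, 5.0]: all logits are 0, so every j gets weight 1/3
```

With \(\mathbf{W}_1 = \mathbf{W}_2 = 0\) and no positional offset, the attention is uniform and the output reduces to the mean of \(\mathbf{W}_3 \mathbf{x}_j\); non-zero weights make the per-channel softmax favor some neighbors over others.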

Parameters:
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • pos_nn (torch.nn.Module, optional) – A neural network \(h_\mathbf{\Theta}\) which maps relative spatial coordinates pos_i - pos_j of shape [-1, 3] to shape [-1, out_channels]. Defaults to a torch.nn.Linear transformation if not specified. (default: None)

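  • attn_nn (torch.nn.Module, optional) – A neural network \(\gamma_\mathbf{\Theta}\) which maps the pre-softmax attention inputs \(\mathbf{W}_1 \mathbf{x}_i - \mathbf{W}_2 \mathbf{x}_j + \delta_{i,j}\) of shape [-1, out_channels] to shape [-1, out_channels]. If not specified, \(\gamma_\mathbf{\Theta}\) is the identity. (default: None)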

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
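With add_self_loops=True, every target node also attends to itself, which is what \(j \in \mathcal{N}(i) \cup \{ i \}\) in the first equation expresses. The following is a minimal sketch of that edge-list transformation on (source, target) pairs, using a hypothetical helper `with_self_loops`. It assumes that existing self-loops are dropped first so that each node ends up with exactly one; treat that as an assumption to verify against the layer's source, not a documented guarantee.

```python
from typing import List, Tuple

# (source j, target i), i.e. one column of edge_index under the default
# source-to-target message flow.
Edge = Tuple[int, int]


def with_self_loops(edges: List[Edge], num_nodes: int) -> List[Edge]:
    """Hypothetical helper: drop existing self-loops, then add (n, n) for every node."""
    kept = [(j, i) for j, i in edges if j != i]
    return kept + [(n, n) for n in range(num_nodes)]


edges = [(1, 0), (2, 0), (0, 0), (0, 1)]
print(with_self_loops(edges, num_nodes=3))
# [(1, 0), (2, 0), (0, 1), (0, 0), (1, 1), (2, 2)]
```

Removing self-loops before adding them keeps a node from contributing its own term to the softmax twice when the input graph already contains (i, i).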

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V}_s|, F_{s}), (|\mathcal{V}_t|, F_{t}))\) if bipartite, positions \((|\mathcal{V}|, 3)\) or \(((|\mathcal{V}_s|, 3), (|\mathcal{V}_t|, 3))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite
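The shape rules above can be summarized as a small check. The helper `expected_output_shape` is hypothetical (not part of PyG) and works on plain shape tuples rather than tensors; it encodes only what the Shapes list states: positions have 3 coordinates, feature and position counts agree per node set, and the output has one row per target node.

```python
from typing import Tuple, Union

Shape = Tuple[int, int]
MaybePair = Union[Shape, Tuple[Shape, Shape]]


def expected_output_shape(
    x_shape: MaybePair, pos_shape: MaybePair, edge_index_shape: Shape, out_channels: int
) -> Shape:
    """Hypothetical helper: validate input shapes and return the output shape."""
    if edge_index_shape[0] != 2:
        raise ValueError("edge_index must have shape (2, |E|)")
    # Normalize the single-graph case to a (source, target) pair of shapes.
    xs, xt = x_shape if isinstance(x_shape[0], tuple) else (x_shape, x_shape)
    ps, pt = pos_shape if isinstance(pos_shape[0], tuple) else (pos_shape, pos_shape)
    for x_part, p_part in ((xs, ps), (xt, pt)):
        if x_part[0] != p_part[0]:
            raise ValueError("features and positions must have the same number of nodes")
        if p_part[1] != 3:
            raise ValueError("positions must have 3 coordinates per node")
    # One output row per target node.
    return (xt[0], out_channels)


print(expected_output_shape((100, 16), (100, 3), (2, 800), 32))  # (100, 32)
print(expected_output_shape(((100, 16), (25, 8)), ((100, 3), (25, 3)), (2, 400), 32))  # (25, 32)
```

In the bipartite case the source and target feature sizes may differ (16 and 8 here), which is why in_channels accepts a tuple; only the target node count determines the number of output rows.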

forward(x: Union[Tensor, Tuple[Tensor, Tensor]], pos: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, SparseTensor]) → Tensor

Runs the forward pass of the module.

reset_parameters()

Resets all learnable parameters of the module.