torch_geometric.nn.conv.PointTransformerConv
- class PointTransformerConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, pos_nn: Optional[Callable] = None, attn_nn: Optional[Callable] = None, add_self_loops: bool = True, **kwargs)
Bases: MessagePassing
The Point Transformer layer from the “Point Transformer” paper.
\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j} \left(\mathbf{W}_3 \mathbf{x}_j + \delta_{i,j} \right),\]where the attention coefficients \(\alpha_{i,j}\) and positional embedding \(\delta_{i,j}\) are computed as
\[\alpha_{i,j}= \textrm{softmax} \left( \gamma_\mathbf{\Theta} (\mathbf{W}_1 \mathbf{x}_i - \mathbf{W}_2 \mathbf{x}_j + \delta_{i,j}) \right)\]and
\[\delta_{i,j}= h_{\mathbf{\Theta}}(\mathbf{p}_i - \mathbf{p}_j),\]with \(\gamma_\mathbf{\Theta}\) and \(h_\mathbf{\Theta}\) denoting neural networks, i.e., MLPs, and \(\mathbf{P} \in \mathbb{R}^{N \times D}\) holding the position \(\mathbf{p}_i\) of each point as its rows.
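The three formulas above can be traced by hand on a toy graph. The following is a minimal sketch in plain PyTorch, not the layer's actual implementation: the names `W1`, `W2`, `W3`, `h`, `gamma` and the `neighbors` dictionary are hypothetical stand-ins for \(\mathbf{W}_1\), \(\mathbf{W}_2\), \(\mathbf{W}_3\), \(h_\mathbf{\Theta}\), \(\gamma_\mathbf{\Theta}\) and \(\mathcal{N}(i) \cup \{ i \}\). It also assumes the softmax is taken over the neighbours of each node, separately for every output channel.

```python
import torch

torch.manual_seed(0)

N, F_in, F_out = 4, 5, 8
x = torch.randn(N, F_in)   # node features x_i
pos = torch.randn(N, 3)    # point positions p_i

# Hypothetical stand-ins for W_1, W_2, W_3, h_Theta and gamma_Theta.
W1 = torch.nn.Linear(F_in, F_out, bias=False)
W2 = torch.nn.Linear(F_in, F_out, bias=False)
W3 = torch.nn.Linear(F_in, F_out, bias=False)
h = torch.nn.Linear(3, F_out)
gamma = torch.nn.Linear(F_out, F_out)

# N(i) ∪ {i} for each node i, written out by hand for a toy graph.
neighbors = {0: [0, 1, 2], 1: [1, 0], 2: [2, 0, 3], 3: [3, 2]}

out = torch.empty(N, F_out)
alphas = {}
with torch.no_grad():
    for i, nbrs in neighbors.items():
        j = torch.tensor(nbrs)
        delta = h(pos[i] - pos[j])                    # delta_{i,j}: [|N(i)|, F_out]
        logits = gamma(W1(x[i]) - W2(x[j]) + delta)   # [|N(i)|, F_out]
        # Assumption: normalise over the neighbours of i, per channel.
        alpha = torch.softmax(logits, dim=0)
        alphas[i] = alpha
        out[i] = (alpha * (W3(x[j]) + delta)).sum(dim=0)
```

Each row of `out` is the weighted sum \(\mathbf{x}^{\prime}_i\), and every column of `alphas[i]` sums to one over the neighbours of node `i`.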
- Parameters:
in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.
out_channels (int) – Size of each output sample.
pos_nn (torch.nn.Module, optional) – A neural network \(h_\mathbf{\Theta}\) which maps relative spatial coordinates pos_i - pos_j of shape [-1, 3] to shape [-1, out_channels]. Will default to a torch.nn.Linear transformation if not further specified. (default: None)
attn_nn (torch.nn.Module, optional) – A neural network \(\gamma_\mathbf{\Theta}\) which maps transformed node features of shape [-1, out_channels] to shape [-1, out_channels]. (default: None)
add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)
**kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
- Shapes:
input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, positions \((|\mathcal{V}|, 3)\) or \(((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)
output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite