Source code for torch_geometric.nn.conv.point_transformer_conv

from typing import Callable, Optional, Tuple, Union

from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import reset
from torch_geometric.typing import (
    Adj,
    OptTensor,
    PairTensor,
    SparseTensor,
    torch_sparse,
)
from torch_geometric.utils import add_self_loops, remove_self_loops, softmax


[docs]class PointTransformerConv(MessagePassing): r"""The Point Transformer layer from the `"Point Transformer" <https://arxiv.org/abs/2012.09164>`_ paper. .. math:: \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j} \left(\mathbf{W}_3 \mathbf{x}_j + \delta_{ij} \right), where the attention coefficients :math:`\alpha_{i,j}` and positional embedding :math:`\delta_{ij}` are computed as .. math:: \alpha_{i,j}= \textrm{softmax} \left( \gamma_\mathbf{\Theta} (\mathbf{W}_1 \mathbf{x}_i - \mathbf{W}_2 \mathbf{x}_j + \delta_{i,j}) \right) and .. math:: \delta_{i,j}= h_{\mathbf{\Theta}}(\mathbf{p}_i - \mathbf{p}_j), with :math:`\gamma_\mathbf{\Theta}` and :math:`h_\mathbf{\Theta}` denoting neural networks, *i.e.* MLPs, and :math:`\mathbf{P} \in \mathbb{R}^{N \times D}` defines the position of each point. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities. out_channels (int): Size of each output sample. pos_nn (torch.nn.Module, optional): A neural network :math:`h_\mathbf{\Theta}` which maps relative spatial coordinates :obj:`pos_j - pos_i` of shape :obj:`[-1, 3]` to shape :obj:`[-1, out_channels]`. Will default to a :class:`torch.nn.Linear` transformation if not further specified. (default: :obj:`None`) attn_nn (torch.nn.Module, optional): A neural network :math:`\gamma_\mathbf{\Theta}` which maps transformed node features of shape :obj:`[-1, out_channels]` to shape :obj:`[-1, out_channels]`. (default: :obj:`None`) add_self_loops (bool, optional) : If set to :obj:`False`, will not add self-loops to the input graph. (default: :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. Shapes: - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` if bipartite, positions :math:`(|\mathcal{V}|, 3)` or :math:`((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))` if bipartite, edge indices :math:`(2, |\mathcal{E}|)` - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or :math:`(|\mathcal{V}_t|, F_{out})` if bipartite """ def __init__(self, in_channels: Union[int, Tuple[int, int]], out_channels: int, pos_nn: Optional[Callable] = None, attn_nn: Optional[Callable] = None, add_self_loops: bool = True, **kwargs): kwargs.setdefault('aggr', 'add') super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.add_self_loops = add_self_loops if isinstance(in_channels, int): in_channels = (in_channels, in_channels) self.pos_nn = pos_nn if self.pos_nn is None: self.pos_nn = Linear(3, out_channels) self.attn_nn = attn_nn self.lin = Linear(in_channels[0], out_channels, bias=False) self.lin_src = Linear(in_channels[0], out_channels, bias=False) self.lin_dst = Linear(in_channels[1], out_channels, bias=False) self.reset_parameters()
[docs] def reset_parameters(self): super().reset_parameters() reset(self.pos_nn) if self.attn_nn is not None: reset(self.attn_nn) self.lin.reset_parameters() self.lin_src.reset_parameters() self.lin_dst.reset_parameters()
[docs] def forward( self, x: Union[Tensor, PairTensor], pos: Union[Tensor, PairTensor], edge_index: Adj, ) -> Tensor: if isinstance(x, Tensor): alpha = (self.lin_src(x), self.lin_dst(x)) x = (self.lin(x), x) else: alpha = (self.lin_src(x[0]), self.lin_dst(x[1])) x = (self.lin(x[0]), x[1]) if isinstance(pos, Tensor): pos = (pos, pos) if self.add_self_loops: if isinstance(edge_index, Tensor): edge_index, _ = remove_self_loops(edge_index) edge_index, _ = add_self_loops( edge_index, num_nodes=min(pos[0].size(0), pos[1].size(0))) elif isinstance(edge_index, SparseTensor): edge_index = torch_sparse.set_diag(edge_index) # propagate_type: (x: PairTensor, pos: PairTensor, alpha: PairTensor) out = self.propagate(edge_index, x=x, pos=pos, alpha=alpha) return out
def message(self, x_j: Tensor, pos_i: Tensor, pos_j: Tensor, alpha_i: Tensor, alpha_j: Tensor, index: Tensor, ptr: OptTensor, size_i: Optional[int]) -> Tensor: delta = self.pos_nn(pos_i - pos_j) alpha = alpha_i - alpha_j + delta if self.attn_nn is not None: alpha = self.attn_nn(alpha) alpha = softmax(alpha, index, ptr, size_i) return alpha * (x_j + delta) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels})')