# Source code for torch_geometric.nn.conv.point_transformer_conv
from typing import Callable, Optional, Tuple, Union
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import reset
from torch_geometric.typing import (
Adj,
OptTensor,
PairTensor,
SparseTensor,
torch_sparse,
)
from torch_geometric.utils import add_self_loops, remove_self_loops, softmax
class PointTransformerConv(MessagePassing):
    r"""The Point Transformer layer from the `"Point Transformer"
    <https://arxiv.org/abs/2012.09164>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \sum_{j \in
        \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j} \left(\mathbf{W}_3
        \mathbf{x}_j + \delta_{i,j} \right),

    where the attention coefficients :math:`\alpha_{i,j}` and
    positional embedding :math:`\delta_{i,j}` are computed as

    .. math::
        \alpha_{i,j}= \textrm{softmax} \left( \gamma_\mathbf{\Theta}
        (\mathbf{W}_1 \mathbf{x}_i - \mathbf{W}_2 \mathbf{x}_j +
        \delta_{i,j}) \right)

    and

    .. math::
        \delta_{i,j}= h_{\mathbf{\Theta}}(\mathbf{p}_i - \mathbf{p}_j),

    with :math:`\gamma_\mathbf{\Theta}` and :math:`h_\mathbf{\Theta}`
    denoting neural networks, *i.e.* MLPs, and
    :math:`\mathbf{P} \in \mathbb{R}^{N \times D}` defines the position of
    each point. The softmax is taken over the incoming edges of each target
    node :math:`i`, independently for every feature channel.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        out_channels (int): Size of each output sample.
        pos_nn (torch.nn.Module, optional): A neural network
            :math:`h_\mathbf{\Theta}` which maps relative spatial coordinates
            :obj:`pos_i - pos_j` (target minus source position) of shape
            :obj:`[-1, 3]` to shape :obj:`[-1, out_channels]`.
            Will default to a
            :class:`torch_geometric.nn.dense.linear.Linear` transformation
            with :obj:`3` input channels if not further specified, in which
            case positions must be 3-dimensional. (default: :obj:`None`)
        attn_nn (torch.nn.Module, optional): A neural network
            :math:`\gamma_\mathbf{\Theta}` which maps transformed
            node features of shape :obj:`[-1, out_channels]`
            to shape :obj:`[-1, out_channels]`. If not given, the raw scores
            are passed to the softmax directly. (default: :obj:`None`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
            The aggregation defaults to :obj:`"add"` unless :obj:`aggr` is
            passed explicitly.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})` or
          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
          if bipartite,
          positions :math:`(|\mathcal{V}|, 3)` or
          :math:`((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))` if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)` or a sparse adjacency
          matrix (:class:`torch_sparse.SparseTensor`)
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
          :math:`(|\mathcal{V}_t|, F_{out})` if bipartite
    """
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, pos_nn: Optional[Callable] = None,
                 attn_nn: Optional[Callable] = None,
                 add_self_loops: bool = True, **kwargs):
        # Sum aggregation realizes the \sum_j of the layer's update rule.
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        # Keep the user-facing values (used by ``__repr__``) before
        # ``in_channels`` is normalized to a (source, target) pair below.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_self_loops = add_self_loops

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.pos_nn = pos_nn
        if self.pos_nn is None:
            # Default positional encoder: assumes 3-D point coordinates.
            self.pos_nn = Linear(3, out_channels)

        self.attn_nn = attn_nn
        # ``lin`` is W_3 (value), ``lin_src`` is W_2 and ``lin_dst`` is W_1
        # of the attention score; source-side maps read in_channels[0],
        # the target-side map reads in_channels[1].
        self.lin = Linear(in_channels[0], out_channels, bias=False)
        self.lin_src = Linear(in_channels[0], out_channels, bias=False)
        self.lin_dst = Linear(in_channels[1], out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Resets ``pos_nn``, the optional ``attn_nn``, and the three linear
        projections, after resetting the base :class:`MessagePassing`."""
        super().reset_parameters()
        reset(self.pos_nn)
        if self.attn_nn is not None:
            reset(self.attn_nn)
        self.lin.reset_parameters()
        self.lin_src.reset_parameters()
        self.lin_dst.reset_parameters()

    def forward(
        self,
        x: Union[Tensor, PairTensor],
        pos: Union[Tensor, PairTensor],
        edge_index: Adj,
    ) -> Tensor:
        """Runs one Point Transformer layer.

        Args:
            x: Node features, or a ``(source, target)`` pair of feature
                matrices for bipartite graphs.
            pos: Node positions, or a ``(source, target)`` pair of position
                matrices for bipartite graphs.
            edge_index: Graph connectivity, either as an edge index tensor
                or as a :class:`torch_sparse.SparseTensor`.

        Returns:
            Updated target-node features with ``out_channels`` columns.
        """
        # Project every node once instead of once per edge. With the default
        # ``source_to_target`` flow, ``alpha[0]`` feeds ``alpha_j`` (W_2 x_j)
        # and ``alpha[1]`` feeds ``alpha_i`` (W_1 x_i). Only the source
        # features are passed through ``lin`` because ``message`` reads
        # ``x_j`` alone.
        if isinstance(x, Tensor):
            alpha = (self.lin_src(x), self.lin_dst(x))
            x = (self.lin(x), x)
        else:
            alpha = (self.lin_src(x[0]), self.lin_dst(x[1]))
            x = (self.lin(x[0]), x[1])

        if isinstance(pos, Tensor):
            pos = (pos, pos)

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                # Remove existing self-loops first so each node gets exactly
                # one; loops exist only for nodes present on both sides.
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(
                    edge_index, num_nodes=min(pos[0].size(0), pos[1].size(0)))
            elif isinstance(edge_index, SparseTensor):
                edge_index = torch_sparse.set_diag(edge_index)

        # propagate_type: (x: PairTensor, pos: PairTensor, alpha: PairTensor)
        out = self.propagate(edge_index, x=x, pos=pos, alpha=alpha)
        return out

    def message(self, x_j: Tensor, pos_i: Tensor, pos_j: Tensor,
                alpha_i: Tensor, alpha_j: Tensor, index: Tensor,
                ptr: OptTensor, size_i: Optional[int]) -> Tensor:
        """Builds the attention-weighted message for every edge ``j -> i``.

        The positional encoding ``pos_nn(pos_i - pos_j)`` is added both to
        the attention score and to the transformed source features.
        """
        delta = self.pos_nn(pos_i - pos_j)
        alpha = alpha_i - alpha_j + delta
        if self.attn_nn is not None:
            alpha = self.attn_nn(alpha)
        # Grouped softmax over all edges sharing the same target ``index``,
        # computed separately per channel (vector attention).
        alpha = softmax(alpha, index, ptr, size_i)
        return alpha * (x_j + delta)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels})')