from typing import Callable, Optional, Union
import torch
from torch import Tensor
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj, OptTensor, PairOptTensor, PairTensor
from torch_geometric.utils import add_self_loops, remove_self_loops
from ..inits import reset
class PointNetConv(MessagePassing):
    r"""The PointNet set layer from the `"PointNet: Deep Learning on Point Sets
    for 3D Classification and Segmentation"
    <https://arxiv.org/abs/1612.00593>`_ and `"PointNet++: Deep Hierarchical
    Feature Learning on Point Sets in a Metric Space"
    <https://arxiv.org/abs/1706.02413>`_ papers

    .. math::
        \mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in
        \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} ( \mathbf{x}_j,
        \mathbf{p}_j - \mathbf{p}_i) \right),

    where :math:`\gamma_{\mathbf{\Theta}}` and :math:`h_{\mathbf{\Theta}}`
    denote neural networks, *i.e.* MLPs, and
    :math:`\mathbf{P} \in \mathbb{R}^{N \times D}` defines the position of
    each point.

    Args:
        local_nn (torch.nn.Module, optional): A neural network
            :math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x` and
            relative spatial coordinates :obj:`pos_j - pos_i` of shape
            :obj:`[-1, in_channels + num_dimensions]` to shape
            :obj:`[-1, out_channels]`, *e.g.*, defined by
            :class:`torch.nn.Sequential`. (default: :obj:`None`)
        global_nn (torch.nn.Module, optional): A neural network
            :math:`\gamma_{\mathbf{\Theta}}` that maps aggregated node features
            of shape :obj:`[-1, out_channels]` to shape :obj:`[-1,
            final_out_channels]`, *e.g.*, defined by
            :class:`torch.nn.Sequential`. (default: :obj:`None`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})` or
          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
          if bipartite,
          positions :math:`(|\mathcal{V}|, 3)` or
          :math:`((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))` if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)`
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
          :math:`(|\mathcal{V}_t|, F_{out})` if bipartite
    """
    def __init__(self, local_nn: Optional[Callable] = None,
                 global_nn: Optional[Callable] = None,
                 add_self_loops: bool = True, **kwargs):
        # Max aggregation is only a default: an explicit `aggr` passed by the
        # caller is left untouched by `setdefault`.
        kwargs.setdefault('aggr', 'max')
        super().__init__(**kwargs)

        self.local_nn = local_nn
        self.global_nn = global_nn
        self.add_self_loops = add_self_loops

        self.reset_parameters()

    def reset_parameters(self):
        """Pass :obj:`local_nn` and :obj:`global_nn` to :func:`reset`."""
        reset(self.local_nn)
        reset(self.global_nn)

    def forward(self, x: Union[OptTensor, PairOptTensor],
                pos: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor:
        """Propagate messages along :obj:`edge_index` and post-process them.

        Args:
            x: Node features, or a :obj:`(source, target)` pair of them.
                A non-tuple value (including :obj:`None`) is expanded to
                :obj:`(x, None)`.
            pos: Node positions, or a :obj:`(source, target)` pair of them.
                A single tensor is used for both sides.
            edge_index: Connectivity, either a :class:`torch.Tensor` or a
                :class:`torch_sparse.SparseTensor`.

        Returns:
            The aggregated messages, passed through :obj:`global_nn` when one
            was provided.
        """
        if not isinstance(x, tuple):
            x: PairOptTensor = (x, None)

        if isinstance(pos, Tensor):
            pos: PairTensor = (pos, pos)

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                # Drop existing self-loops before adding new ones so that no
                # node ends up with a duplicated self-loop. Self-loops are
                # added for the first `min(#source, #target)` node indices.
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(
                    edge_index, num_nodes=min(pos[0].size(0), pos[1].size(0)))
            elif isinstance(edge_index, SparseTensor):
                edge_index = set_diag(edge_index)

        # propagate_type: (x: PairOptTensor, pos: PairTensor)
        out = self.propagate(edge_index, x=x, pos=pos, size=None)

        if self.global_nn is not None:
            out = self.global_nn(out)

        return out

    def message(self, x_j: Optional[Tensor], pos_i: Tensor,
                pos_j: Tensor) -> Tensor:
        """Build the per-edge message from features and relative positions.

        The message starts as :obj:`pos_j - pos_i`. If :obj:`x_j` is given it
        is concatenated in front of that offset along dimension 1, and the
        result is passed through :obj:`local_nn` when one was provided.
        """
        msg = pos_j - pos_i
        if x_j is not None:
            msg = torch.cat([x_j, msg], dim=1)
        if self.local_nn is not None:
            msg = self.local_nn(msg)
        return msg

    def __repr__(self) -> str:
        """Return the class name together with both wrapped networks."""
        return (f'{self.__class__.__name__}(local_nn={self.local_nn}, '
                f'global_nn={self.global_nn})')
# Alias: `PointConv` is bound to the same class object as `PointNetConv`.
# NOTE(review): presumably kept so that code using the older name keeps
# working - confirm against callers before removing.
PointConv = PointNetConv