import torch
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.typing import Adj
class PointGNNConv(MessagePassing):
    r"""The PointGNN operator from the `"Point-GNN: Graph Neural Network for
    3D Object Detection in a Point Cloud" <https://arxiv.org/abs/2003.01251>`_
    paper.

    .. math::

        \Delta \textrm{pos}_i &= h_{\mathbf{\Theta}}(\mathbf{x}_i)

        \mathbf{e}_{j,i} &= f_{\mathbf{\Theta}}(\textrm{pos}_j -
        \textrm{pos}_i + \Delta \textrm{pos}_i, \mathbf{x}_j)

        \mathbf{x}^{\prime}_i &= g_{\mathbf{\Theta}}(\max_{j \in
        \mathcal{N}(i)} \mathbf{e}_{j,i}) + \mathbf{x}_i

    The relative position is used in the message passing step to introduce
    global translation invariance.
    To also counter shifts in the local neighborhood of the center node, the
    authors propose to utilize an alignment offset.
    The graph should be statically constructed using a radius-based cutoff.

    The :math:`\max` aggregation is the default; it can be overridden by
    passing ``aggr`` explicitly via :obj:`**kwargs`.

    Args:
        mlp_h (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}`
            that maps node features of size :math:`F_{in}` to
            three-dimensional coordinate offsets
            :math:`\Delta \textrm{pos}_i`.
        mlp_f (torch.nn.Module): A neural network :math:`f_{\mathbf{\Theta}}`
            that computes :math:`\mathbf{e}_{j,i}` from the features of
            neighbors of size :math:`F_{in}` and the three-dimensional vector
            :math:`\textrm{pos}_j - \textrm{pos}_i + \Delta \textrm{pos}_i`.
            Its input is the concatenation of the offset vector followed by
            the neighbor features, along the last dimension.
        mlp_g (torch.nn.Module): A neural network :math:`g_{\mathbf{\Theta}}`
            that maps the aggregated edge features back to :math:`F_{in}`.
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          positions :math:`(|\mathcal{V}|, 3)`,
          edge indices :math:`(2, |\mathcal{E}|)`
        - **output:** node features :math:`(|\mathcal{V}|, F_{in})`
    """
    def __init__(
        self,
        mlp_h: torch.nn.Module,
        mlp_f: torch.nn.Module,
        mlp_g: torch.nn.Module,
        **kwargs,
    ):
        # Use max aggregation (as in the formula) unless the caller chose
        # a different ``aggr`` explicitly.
        kwargs.setdefault('aggr', 'max')
        super().__init__(**kwargs)

        self.mlp_h = mlp_h
        self.mlp_f = mlp_f
        self.mlp_g = mlp_g

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets the base-class state and re-initializes ``mlp_h``,
        ``mlp_f`` and ``mlp_g`` via :func:`reset`."""
        super().reset_parameters()
        reset(self.mlp_h)
        reset(self.mlp_f)
        reset(self.mlp_g)

    def forward(self, x: Tensor, pos: Tensor, edge_index: Adj) -> Tensor:
        r"""Runs one PointGNN update.

        Args:
            x (torch.Tensor): Node features.
            pos (torch.Tensor): Node positions.
            edge_index (Adj): Graph connectivity.

        Returns:
            torch.Tensor: ``x + mlp_g(aggregated messages)``. Because of the
            residual connection, ``mlp_g`` must produce features whose shape
            is compatible with ``x``.
        """
        # propagate_type: (x: Tensor, pos: Tensor)
        out = self.propagate(edge_index, x=x, pos=pos)
        out = self.mlp_g(out)
        # Residual connection: x'_i = g(aggregated e_{j,i}) + x_i.
        return x + out

    def message(self, pos_j: Tensor, pos_i: Tensor, x_i: Tensor,
                x_j: Tensor) -> Tensor:
        r"""Computes the edge features :math:`\mathbf{e}_{j,i}`.

        The alignment offset :math:`\Delta \textrm{pos}_i` is predicted from
        the center-node features ``x_i`` and added to the relative position
        ``pos_j - pos_i`` before concatenating with the neighbor features.
        """
        # NOTE: x_i is the target-node feature lifted to each edge, so
        # mlp_h is evaluated per edge here rather than once per node.
        delta = self.mlp_h(x_i)
        # mlp_f input layout: [pos_j - pos_i + delta, x_j] on the last dim;
        # delta must therefore have the same last-dim size as pos.
        e = torch.cat([pos_j - pos_i + delta, x_j], dim=-1)
        return self.mlp_f(e)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(\n'
                f'  mlp_h={self.mlp_h},\n'
                f'  mlp_f={self.mlp_f},\n'
                f'  mlp_g={self.mlp_g},\n'
                f')')