# Source code for torch_geometric.nn.conv.ppf_conv

from typing import Callable, Optional, Union

import torch
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.typing import (
    Adj,
    OptTensor,
    PairOptTensor,
    PairTensor,
    SparseTensor,
    torch_sparse,
)
from torch_geometric.utils import add_self_loops, remove_self_loops

def get_angle(v1: Tensor, v2: Tensor) -> Tensor:
    """Return the row-wise unsigned angle (radians) between ``v1`` and ``v2``.

    Both inputs hold one 3-D vector per row (``torch.cross`` along ``dim=1``
    requires that dimension to have size 3). The angle is computed as
    ``atan2(|v1 x v2|, v1 . v2)``. This form is numerically better behaved
    than ``acos`` of a normalized dot product. Because the first argument is
    a norm (never negative), the result lies in ``[0, pi]``.
    """
    # |v1 x v2| = |v1||v2| sin(theta); the shared magnitude cancels in atan2.
    sine_term = torch.cross(v1, v2, dim=1).norm(p=2, dim=1)
    # v1 . v2 = |v1||v2| cos(theta)
    cosine_term = (v1 * v2).sum(dim=1)
    return torch.atan2(sine_term, cosine_term)

def point_pair_features(pos_i: Tensor, pos_j: Tensor, normal_i: Tensor,
                        normal_j: Tensor) -> Tensor:
    """Compute the four point pair features (PPF) for each row pair.

    With ``d = pos_j - pos_i``, the returned tensor has one row per input row
    and four columns, in this order:

    1. ``||d||_2``: the Euclidean distance between the two points.
    2. The angle between ``normal_i`` and ``d``.
    3. The angle between ``normal_j`` and ``d``.
    4. The angle between ``normal_i`` and ``normal_j``.

    Angles are computed with :func:`get_angle`, so they are in radians.
    """
    offset = pos_j - pos_i
    columns = [
        offset.norm(p=2, dim=1),
        get_angle(normal_i, offset),
        get_angle(normal_j, offset),
        get_angle(normal_i, normal_j),
    ]
    return torch.stack(columns, dim=1)

class PPFConv(MessagePassing):
    r"""The PPFNet operator from the `"PPFNet: Global Context Aware Local
    Features for Robust 3D Point Matching"
    <https://arxiv.org/abs/1802.02669>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in
        \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} \left( \mathbf{x}_j,
        \| \mathbf{d_{j,i}} \|, \angle(\mathbf{n}_i, \mathbf{d_{j,i}}),
        \angle(\mathbf{n}_j, \mathbf{d_{j,i}}), \angle(\mathbf{n}_i,
        \mathbf{n}_j) \right) \right)

    where :math:`\mathbf{d_{j,i}} = \mathbf{p}_j - \mathbf{p}_i` is the
    offset between the node positions, and :math:`\gamma_{\mathbf{\Theta}}`
    and :math:`h_{\mathbf{\Theta}}` denote neural networks, *i.e.* MLPs,
    which take in node features and
    :class:`torch_geometric.transforms.PointPairFeatures`.

    The :math:`\max` aggregation is the default. It can be overridden via the
    :obj:`aggr` keyword argument. The :math:`\cup \{ i \}` term only applies
    when :obj:`add_self_loops=True`.

    Args:
        local_nn (torch.nn.Module, optional): A neural network
            :math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x`
            concatenated with the four point pair features, i.e. inputs of
            shape :obj:`[-1, in_channels + 4]` (or :obj:`[-1, 4]` when
            :obj:`x` is :obj:`None`), to shape :obj:`[-1, out_channels]`,
            *e.g.*, defined by :class:`torch.nn.Sequential`.
            (default: :obj:`None`)
        global_nn (torch.nn.Module, optional): A neural network
            :math:`\gamma_{\mathbf{\Theta}}` that maps aggregated node features
            of shape :obj:`[-1, out_channels]` to shape
            :obj:`[-1, final_out_channels]`, *e.g.*, defined by
            :class:`torch.nn.Sequential`. (default: :obj:`None`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})` or
          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
          if bipartite,
          positions :math:`(|\mathcal{V}|, 3)` or
          :math:`((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))` if bipartite,
          point normals :math:`(|\mathcal{V}|, 3)` or
          :math:`((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))` if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)`
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
          :math:`(|\mathcal{V}_t|, F_{out})` if bipartite
    """
    def __init__(self, local_nn: Optional[Callable] = None,
                 global_nn: Optional[Callable] = None,
                 add_self_loops: bool = True, **kwargs):
        # Max aggregation matches the formula above; callers may still pass
        # their own `aggr` through **kwargs.
        kwargs.setdefault('aggr', 'max')
        super().__init__(**kwargs)

        self.local_nn = local_nn
        self.global_nn = global_nn
        self.add_self_loops = add_self_loops

        self.reset_parameters()

    def reset_parameters(self):
        """Reset the base class state and both (optional) sub-networks."""
        super().reset_parameters()
        reset(self.local_nn)
        reset(self.global_nn)

    def forward(
        self,
        x: Union[OptTensor, PairOptTensor],
        pos: Union[Tensor, PairTensor],
        normal: Union[Tensor, PairTensor],
        edge_index: Adj,
    ) -> Tensor:
        """Run one PPFConv layer.

        Args:
            x: Node features, a ``(source, target)`` pair of them, or
                ``None``. A single tensor is used as source features only,
                because :meth:`message` reads ``x_j`` and never ``x_i``.
            pos: Node positions, or a ``(source, target)`` pair.
            normal: Point normals, or a ``(source, target)`` pair.
            edge_index: Graph connectivity (edge index tensor or
                ``SparseTensor``).

        Returns:
            The aggregated node features. If ``global_nn`` is set, it has
            been applied to them.
        """
        if not isinstance(x, tuple):
            x = (x, None)

        # Normalize to (source, target) pairs so both the homogeneous and
        # bipartite cases go through the same propagate call.
        if isinstance(pos, Tensor):
            pos = (pos, pos)

        if isinstance(normal, Tensor):
            normal = (normal, normal)

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                # Drop existing self-loops first so each target node ends up
                # with exactly one self-loop after add_self_loops.
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index,
                                               num_nodes=pos[1].size(0))
            elif isinstance(edge_index, SparseTensor):
                edge_index = torch_sparse.set_diag(edge_index)

        # propagate_type: (x: PairOptTensor, pos: PairTensor,
        #                  normal: PairTensor)
        out = self.propagate(edge_index, x=x, pos=pos, normal=normal)

        if self.global_nn is not None:
            out = self.global_nn(out)

        return out

    def message(self, x_j: OptTensor, pos_i: Tensor, pos_j: Tensor,
                normal_i: Tensor, normal_j: Tensor) -> Tensor:
        """Build the per-edge message h(x_j, PPF(i, j)).

        The four point pair features are always included. Neighbor features
        ``x_j`` are prepended when they are given, and the result then goes
        through ``local_nn`` when that network is set.
        """
        msg = point_pair_features(pos_i, pos_j, normal_i, normal_j)
        if x_j is not None:
            msg = torch.cat([x_j, msg], dim=1)
        if self.local_nn is not None:
            msg = self.local_nn(msg)
        return msg

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(local_nn={self.local_nn}, '
                f'global_nn={self.global_nn})')