Source code for torch_geometric.nn.conv.ppf_conv

import torch
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops

from ..inits import reset

def point_pair_features(pos_i, pos_j, norm_i, norm_j):
    r"""Compute the four point pair features for every pair of points.

    For each row, with :obj:`d = pos_j - pos_i`, the features are (in this
    order): :math:`\| d \|`, :math:`\angle(n_i, d)`, :math:`\angle(n_j, d)`
    and :math:`\angle(n_i, n_j)`.

    Args:
        pos_i (Tensor): Positions of the first points, shape :obj:`[E, 3]`
            (the cross product is taken along dimension 1).
        pos_j (Tensor): Positions of the second points, shape :obj:`[E, 3]`.
        norm_i (Tensor): Normal vectors of the first points, shape
            :obj:`[E, 3]`.
        norm_j (Tensor): Normal vectors of the second points, shape
            :obj:`[E, 3]`.

    Returns:
        Tensor: The stacked features of shape :obj:`[E, 4]`.
    """
    def get_angle(v1, v2):
        # atan2(|v1 x v2|, v1 . v2) yields the unsigned angle between v1 and
        # v2 without requiring either vector to be normalized.
        return torch.atan2(
            torch.cross(v1, v2, dim=1).norm(p=2, dim=1), (v1 * v2).sum(dim=1))

    pseudo = pos_j - pos_i
    return torch.stack([
        pseudo.norm(p=2, dim=1),
        get_angle(norm_i, pseudo),
        get_angle(norm_j, pseudo),
        get_angle(norm_i, norm_j),
    ], dim=1)

class PPFConv(MessagePassing):
    r"""The PPFNet operator from the `"PPFNet: Global Context Aware Local
    Features for Robust 3D Point Matching"
    <https://arxiv.org/abs/1802.02669>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left(
        \max_{j \in \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} (
        \mathbf{x}_j, \| \mathbf{d_{j,i}} \|,
        \angle(\mathbf{n}_i, \mathbf{d_{j,i}}),
        \angle(\mathbf{n}_j, \mathbf{d_{j,i}}),
        \angle(\mathbf{n}_i, \mathbf{n}_j) ) \right)

    where :math:`\gamma_{\mathbf{\Theta}}` and :math:`h_{\mathbf{\Theta}}`
    denote neural networks, *i.e.* MLPs, which take in node features and
    :class:`torch_geometric.transforms.PointPairFeatures`.

    Args:
        local_nn (torch.nn.Module, optional): A neural network
            :math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x`
            concatenated with the four point pair features, of shape
            :obj:`[-1, in_channels + 4]`, to shape :obj:`[-1, out_channels]`,
            *e.g.*, defined by :class:`torch.nn.Sequential`.
            (default: :obj:`None`)
        global_nn (torch.nn.Module, optional): A neural network
            :math:`\gamma_{\mathbf{\Theta}}` that maps aggregated node
            features of shape :obj:`[-1, out_channels]` to shape
            :obj:`[-1, final_out_channels]`, *e.g.*, defined by
            :class:`torch.nn.Sequential`. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self, local_nn=None, global_nn=None, **kwargs):
        super(PPFConv, self).__init__(aggr='max', **kwargs)

        self.local_nn = local_nn
        self.global_nn = global_nn

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the parameters of both wrapped networks."""
        reset(self.local_nn)
        reset(self.global_nn)

    def forward(self, x, pos, norm, edge_index):
        r"""
        Args:
            x (Tensor): The node feature matrix. Allowed to be :obj:`None`.
            pos (Tensor or tuple): The node position matrix. Either given as
                tensor for use in general message passing or as tuple for use
                in message passing in bipartite graphs.
            norm (Tensor or tuple): The normal vectors of each node. Either
                given as tensor for use in general message passing or as tuple
                for use in message passing in bipartite graphs.
            edge_index (LongTensor): The edge indices.
        """
        if torch.is_tensor(pos):
            # Add self-loops for symmetric adjacencies: existing self-loops
            # are dropped first so that re-adding them does not duplicate
            # any. This is skipped when `pos` is a tuple (bipartite case).
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=pos.size(0))

        return self.propagate(edge_index, x=x, pos=pos, norm=norm)

    def message(self, x_j, pos_i, pos_j, norm_i, norm_j):
        """Build the per-edge message from point pair features.

        The four point pair features are computed for every edge, prefixed
        with the source node features :obj:`x_j` when those are given, and
        passed through :obj:`local_nn` if it is set.
        """
        msg = point_pair_features(pos_i, pos_j, norm_i, norm_j)
        if x_j is not None:
            msg = torch.cat([x_j, msg], dim=1)
        if self.local_nn is not None:
            msg = self.local_nn(msg)
        return msg

    def update(self, aggr_out):
        """Apply :obj:`global_nn`, if set, to the aggregated messages."""
        if self.global_nn is not None:
            aggr_out = self.global_nn(aggr_out)
        return aggr_out

    def __repr__(self):
        return '{}(local_nn={}, global_nn={})'.format(
            self.__class__.__name__, self.local_nn, self.global_nn)