from typing import Optional
from math import ceil
import torch
from torch import Tensor
from torch.nn import Sequential as S, Linear as L, BatchNorm1d as BN
from torch.nn import ELU, Conv1d
from torch_geometric.nn import Reshape
from ..inits import reset
# `torch-cluster` is an optional dependency: fall back to `None` here so that
# importing this module still succeeds, and let `XConv.__init__` raise a
# descriptive `ImportError` when the operator is actually instantiated.
try:
    from torch_cluster import knn_graph
except ImportError:
    knn_graph = None
class XConv(torch.nn.Module):
    r"""The convolutional operator on :math:`\mathcal{X}`-transformed points
    from the `"PointCNN: Convolution On X-Transformed Points"
    <https://arxiv.org/abs/1801.07791>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathrm{Conv}\left(\mathbf{K},
        \gamma_{\mathbf{\Theta}}(\mathbf{P}_i - \mathbf{p}_i) \times
        \left( h_\mathbf{\Theta}(\mathbf{P}_i - \mathbf{p}_i) \, \Vert \,
        \mathbf{x}_i \right) \right),

    where :math:`\mathbf{K}` and :math:`\mathbf{P}_i` denote the trainable
    filter and neighboring point positions of :math:`\mathbf{x}_i`,
    respectively.
    :math:`\gamma_{\mathbf{\Theta}}` and :math:`h_{\mathbf{\Theta}}` describe
    neural networks, *i.e.* MLPs, where :math:`h_{\mathbf{\Theta}}`
    individually lifts each point into a higher-dimensional space, and
    :math:`\gamma_{\mathbf{\Theta}}` computes the :math:`\mathcal{X}`-
    transformation matrix based on *all* points in a neighborhood.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        dim (int): Point cloud dimensionality.
        kernel_size (int): Size of the convolving kernel, *i.e.* number of
            neighbors including self-loops.
        hidden_channels (int, optional): Output size of
            :math:`h_{\mathbf{\Theta}}`, *i.e.* dimensionality of lifted
            points. If set to :obj:`None`, will be automatically set to
            :obj:`in_channels // 4`. Must be positive, so it has to be given
            explicitly whenever :obj:`in_channels < 4`.
            (default: :obj:`None`)
        dilation (int, optional): The factor by which the neighborhood is
            extended, from which :obj:`kernel_size` neighbors are then
            uniformly sampled. Can be interpreted as the dilation rate of
            classical convolutional operators. (default: :obj:`1`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        num_workers (int): Number of workers to use for k-NN computation.
            Has no effect in case :obj:`batch` is not :obj:`None`, or the input
            lies on the GPU. (default: :obj:`1`)

    Raises:
        ImportError: If :obj:`torch-cluster` is not installed.
        AssertionError: If the (possibly derived) :obj:`hidden_channels` is
            not positive.
    """
    def __init__(self, in_channels: int, out_channels: int, dim: int,
                 kernel_size: int, hidden_channels: Optional[int] = None,
                 dilation: int = 1, bias: bool = True, num_workers: int = 1):
        super().__init__()

        if knn_graph is None:
            raise ImportError('`XConv` requires `torch-cluster`.')

        self.in_channels = in_channels
        if hidden_channels is None:
            hidden_channels = in_channels // 4
        assert hidden_channels > 0, (
            '`hidden_channels` must be positive; pass it explicitly when '
            '`in_channels` < 4')
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.dim = dim
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.num_workers = num_workers

        C_in, C_delta, C_out = in_channels, hidden_channels, out_channels
        D, K = dim, kernel_size

        # h_Theta: lifts every relative neighbor position individually from
        # `dim` to `C_delta` channels, then regroups the result per
        # neighborhood of K points.
        self.mlp1 = S(
            L(dim, C_delta),
            ELU(),
            BN(C_delta),
            L(C_delta, C_delta),
            ELU(),
            BN(C_delta),
            Reshape(-1, K, C_delta),
        )

        # gamma_Theta: maps the flattened K * D relative positions of a whole
        # neighborhood to a K x K X-transformation matrix. The grouped
        # `Conv1d` layers use a kernel as wide as the input (K), so each one
        # collapses the last dimension and yields K**2 values that are
        # reshaped back to K x K.
        self.mlp2 = S(
            L(D * K, K**2),
            ELU(),
            BN(K**2),
            Reshape(-1, K, K),
            Conv1d(K, K**2, K, groups=K),
            ELU(),
            BN(K**2),
            Reshape(-1, K, K),
            Conv1d(K, K**2, K, groups=K),
            BN(K**2),
            Reshape(-1, K, K),
        )

        # The final convolution sees the lifted positions concatenated with
        # the input features, hence C_in + C_delta input channels.
        C_in = C_in + C_delta
        # Smallest per-channel multiplier such that the grouped (depthwise)
        # convolution produces at least C_out channels before the linear
        # projection down to exactly C_out.
        depth_multiplier = int(ceil(C_out / C_in))
        self.conv = S(
            Conv1d(C_in, C_in * depth_multiplier, K, groups=C_in),
            Reshape(-1, C_in * depth_multiplier),
            L(C_in * depth_multiplier, C_out, bias=bias),
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initializes the parameters of all three sub-networks by passing
        each of them to :obj:`reset`."""
        reset(self.mlp1)
        reset(self.mlp2)
        reset(self.conv)

    def forward(self, x: Tensor, pos: Tensor,
                batch: Optional[Tensor] = None) -> Tensor:
        r"""Applies the X-convolution to a point cloud.

        Args:
            x (Tensor): Input point features; a one-dimensional tensor is
                treated as a single feature per point. May be :obj:`None`, in
                which case only the lifted relative positions are convolved
                (the channel counts then only match if the layer was built
                with :obj:`in_channels=0`).
            pos (Tensor): Point positions; a one-dimensional tensor is
                treated as one coordinate per point.
            batch (Tensor, optional): Passed through unchanged to
                :obj:`knn_graph`. (default: :obj:`None`)

        Returns:
            Tensor: The convolved point features, expected to hold
            :obj:`out_channels` values per point.
        """
        pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
        (N, D), K = pos.size(), self.kernel_size

        # Query K * dilation nearest neighbors (self-loops included) per
        # point. NOTE(review): the reshapes below rely on `knn_graph`
        # returning the edges grouped by center point with exactly
        # K * dilation entries each -- confirm against `torch-cluster`.
        edge_index = knn_graph(pos, K * self.dilation, batch, loop=True,
                               flow='target_to_source',
                               num_workers=self.num_workers)

        if self.dilation > 1:
            # Keep every `dilation`-th edge so that K neighbors per point
            # remain.
            edge_index = edge_index[:, ::self.dilation]

        row, col = edge_index[0], edge_index[1]

        # Relative neighbor positions (P_i - p_i in the class docstring);
        # presumably `row` indexes the center point and `col` its neighbor.
        pos = pos[col] - pos[row]

        # Lifted positions, regrouped per neighborhood by the trailing
        # `Reshape(-1, K, C_delta)` of `mlp1`.
        x_star = self.mlp1(pos)
        if x is not None:
            x = x.unsqueeze(-1) if x.dim() == 1 else x
            x = x[col].view(N, K, self.in_channels)
            x_star = torch.cat([x_star, x], dim=-1)
        # Channels first, neighbors last: [N, C, K].
        x_star = x_star.transpose(1, 2).contiguous()

        # One K x K X-transformation matrix per neighborhood.
        transform_matrix = self.mlp2(pos.view(N, K * D))

        # Batched matrix product: mixes the K neighbor slots of every channel.
        x_transformed = torch.matmul(x_star, transform_matrix)

        out = self.conv(x_transformed)
        return out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)