torch_geometric.nn.pool.fps

fps(x: Tensor, batch: Optional[Tensor] = None, ratio: float = 0.5, random_start: bool = True) → Tensor

A sampling algorithm from the “PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space” paper. It iteratively selects the point farthest from the set of points sampled so far, and returns the indices of the sampled nodes.
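To make the iteration concrete, here is a minimal pure-PyTorch sketch of farthest point sampling for a single point cloud. It is not the library's implementation. The function name `fps_reference`, the use of squared Euclidean distance, and rounding the sample count up with `math.ceil` are all assumptions made for illustration. Each step keeps, for every point, its distance to the nearest already-selected point, then picks the point where that distance is largest.

```python
import math

import torch


def fps_reference(x: torch.Tensor, ratio: float = 0.5, random_start: bool = True) -> torch.Tensor:
    """Hypothetical farthest point sampling sketch for a single point cloud."""
    n = x.size(0)
    k = math.ceil(ratio * n)  # assumed: number of samples is rounded up
    # Index of the first sampled point: random, or the first point.
    start = int(torch.randint(n, (1,))) if random_start else 0
    selected = [start]
    # Squared distance from every point to its nearest selected point.
    dist = ((x - x[start]) ** 2).sum(dim=1)
    for _ in range(k - 1):
        nxt = int(torch.argmax(dist))  # point farthest from the selected set
        selected.append(nxt)
        # Update each point's distance to the (now larger) selected set.
        dist = torch.minimum(dist, ((x - x[nxt]) ** 2).sum(dim=1))
    return torch.tensor(selected, dtype=torch.long)


x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
print(fps_reference(x, ratio=0.5, random_start=False))  # tensor([0, 3])
```

Starting from the corner (-1, -1), the farthest remaining point is the opposite corner (1, 1), so the sketch returns indices 0 and 3.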

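import torch
from torch_geometric.nn import fps

# Four 2D points that all belong to a single example (batch index 0).
x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
batch = torch.tensor([0, 0, 0, 0])
# Sample half of the points; `index` holds the indices of the 2 sampled nodes.
index = fps(x, batch, ratio=0.5)
sampled = x[index]  # coordinates of the sampled points
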
Parameters
  • x (torch.Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • batch (torch.Tensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • ratio (float, optional) – Sampling ratio, i.e., the fraction of nodes to sample from each example. (default: 0.5)

  • random_start (bool, optional) – If set to True, the starting node of each example is chosen at random. If set to False, the first node of each example in \(\mathbf{X}\) is used as the starting node. (default: True)
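The following sketch illustrates how `batch`, `ratio`, and `random_start` might interact: sampling runs independently inside each example, and the returned indices refer to rows of the full \(\mathbf{X}\). This is a hypothetical pure-PyTorch illustration, not the library's code. The name `fps_batched_reference`, the per-example rounding up of the sample count, and the ordering of the output by example are assumptions.

```python
import math
from typing import Optional

import torch


def fps_batched_reference(
    x: torch.Tensor,
    batch: Optional[torch.Tensor] = None,
    ratio: float = 0.5,
    random_start: bool = True,
) -> torch.Tensor:
    """Hypothetical sketch: farthest point sampling run separately per example."""
    if batch is None:
        batch = torch.zeros(x.size(0), dtype=torch.long)  # treat all nodes as one example
    out = []
    for b in torch.unique(batch):  # sorted example ids
        idx = (batch == b).nonzero(as_tuple=True)[0]  # global indices of example b
        pts = x[idx]
        n = pts.size(0)
        k = math.ceil(ratio * n)  # assumed: per-example count rounded up
        # Local start index: random, or the first node of this example.
        start = int(torch.randint(n, (1,))) if random_start else 0
        selected = [start]
        dist = ((pts - pts[start]) ** 2).sum(dim=1)
        for _ in range(k - 1):
            nxt = int(torch.argmax(dist))  # farthest from this example's selected set
            selected.append(nxt)
            dist = torch.minimum(dist, ((pts - pts[nxt]) ** 2).sum(dim=1))
        out.append(idx[torch.tensor(selected)])  # map local indices back to rows of x
    return torch.cat(out)


# Two 1D examples with three nodes each.
x = torch.tensor([[0.0], [1.0], [5.0], [0.0], [2.0], [3.0]])
batch = torch.tensor([0, 0, 0, 1, 1, 1])
print(fps_batched_reference(x, batch, ratio=0.5, random_start=False))  # tensor([0, 2, 3, 5])
```

With `random_start=False`, each example starts from its own first node (rows 0 and 3), and the second sample is the farthest node within the same example (rows 2 and 5).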

Return type

torch.Tensor