torch_geometric.nn.pool.nearest

nearest(x: Tensor, y: Tensor, batch_x: Optional[Tensor] = None, batch_y: Optional[Tensor] = None) → Tensor

Finds for each element in x its nearest point in y and returns the index of that point in y.

import torch
from torch_geometric.nn import nearest

# All points of x and y belong to a single example (batch index 0).
x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
batch_x = torch.tensor([0, 0, 0, 0])
y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])
batch_y = torch.tensor([0, 0])
cluster = nearest(x, y, batch_x, batch_y)
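To make the expected result of the call above concrete, here is a minimal dependency-free sketch of the computation, assuming the behaviour of the underlying torch_cluster routine: each row of x is assigned the index of its closest row in y under Euclidean distance. The helper name nearest_reference is hypothetical and is not part of torch_geometric. Tie-breaking (first minimum wins) is also an assumption of this sketch.

```python
import math

def nearest_reference(x, y):
    """Brute force (hypothetical helper): for each point in x,
    return the index of its closest point in y (Euclidean distance)."""
    # min() returns the first index on ties; the library may differ here.
    return [min(range(len(y)), key=lambda j: math.dist(p, y[j])) for p in x]

# Same coordinates as the example above, as plain Python lists.
x = [[-1, -1], [-1, 1], [1, -1], [1, 1]]
y = [[-1, 0], [1, 0]]

cluster = nearest_reference(x, y)
print(cluster)  # [0, 0, 1, 1]
```

The result has one entry per row of x: the two points with first coordinate -1 map to y[0], and the two with first coordinate 1 map to y[1].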

Parameters
  • x (torch.Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • y (torch.Tensor) – Node feature matrix \(\mathbf{Y} \in \mathbb{R}^{M \times F}\).

  • batch_x (torch.Tensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • batch_y (torch.Tensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M\), which assigns each node to a specific example. (default: None)
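The effect of the two batch vectors can be sketched as follows: a point in x is only compared against points in y that carry the same batch index. This is an illustrative pure-Python sketch, not the library's implementation. The helper name nearest_batched_reference is hypothetical, and the returned values are assumed to be indices into the full y. Each example is assumed to contain at least one point of y.

```python
import math

def nearest_batched_reference(x, y, batch_x, batch_y):
    """Hypothetical helper: nearest point in y for each point in x,
    restricted to points of y that belong to the same example."""
    out = []
    for p, b in zip(x, batch_x):
        # Only points of y with a matching batch index are candidates.
        candidates = [j for j, by in enumerate(batch_y) if by == b]
        out.append(min(candidates, key=lambda j: math.dist(p, y[j])))
    return out

x = [[-1, -1], [-1, 1], [1, -1], [1, 1]]
y = [[1, 0], [-1, 0]]

# Two examples: x[0], x[1], y[0] form example 0; x[2], x[3], y[1] form example 1.
batched = nearest_batched_reference(x, y, [0, 0, 1, 1], [0, 1])

# One example: every point of y is a candidate for every point of x.
unbatched = nearest_batched_reference(x, y, [0, 0, 0, 0], [0, 0])

print(batched)    # [0, 0, 1, 1]
print(unbatched)  # [1, 1, 0, 0]
```

With separate examples, x[0] is matched to y[0] even though y[1] is geometrically closer, because y[1] belongs to a different example.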

Return type

torch.Tensor