import torch
import scipy.spatial
if torch.cuda.is_available():
import torch_cluster.knn_cuda
def knn(x, y, k, batch_x=None, batch_y=None, cosine=False):
    r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
    :obj:`x`.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        y (Tensor): Node feature matrix
            :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
        k (int): The number of neighbors.
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
            node to a specific example. (default: :obj:`None`)
        cosine (boolean, optional): If :obj:`True`, will use the cosine
            distance instead of euclidean distance to find nearest neighbors.
            Only forwarded to the CUDA implementation; on CPU it raises
            :class:`NotImplementedError`. (default: :obj:`False`)

    Returns:
        On the CPU path, a :math:`2 \times E` tensor whose first row holds
        indices into :obj:`y` and whose second row holds the index of a
        matched point in :obj:`x`. An element of :obj:`y` contributes fewer
        than :obj:`k` columns when its example has fewer than :obj:`k`
        points in :obj:`x`.

    :rtype: :class:`LongTensor`

    .. testsetup::

        import torch
        from torch_cluster import knn

    .. testcode::

        >>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        >>> batch_x = torch.tensor([0, 0, 0, 0])
        >>> y = torch.Tensor([[-1, 0], [1, 0]])
        >>> batch_y = torch.tensor([0, 0])
        >>> assign_index = knn(x, y, 2, batch_x, batch_y)
    """
    # Without batch vectors every point belongs to example 0.
    if batch_x is None:
        batch_x = x.new_zeros(x.size(0), dtype=torch.long)
    if batch_y is None:
        batch_y = y.new_zeros(y.size(0), dtype=torch.long)

    # Treat 1-D inputs as N points with a single feature.
    x = x.view(-1, 1) if x.dim() == 1 else x
    y = y.view(-1, 1) if y.dim() == 1 else y

    assert x.dim() == 2 and batch_x.dim() == 1
    assert y.dim() == 2 and batch_y.dim() == 1
    assert x.size(1) == y.size(1)
    assert x.size(0) == batch_x.size(0)
    assert y.size(0) == batch_y.size(0)

    if x.is_cuda:
        return torch_cluster.knn_cuda.knn(x, y, k, batch_x, batch_y, cosine)

    if cosine:
        raise NotImplementedError('Cosine distance not implemented for CPU')

    # Nothing to match: `min()`/`max()` below would fail on empty tensors.
    if x.size(0) == 0 or y.size(0) == 0:
        return torch.empty((2, 0), dtype=torch.long)

    # Rescale x and y jointly into [0, 1].
    min_xy = min(x.min().item(), y.min().item())
    x, y = x - min_xy, y - min_xy
    max_xy = max(x.max().item(), y.max().item())
    # A zero range means all coordinates coincide (already 0 after the
    # shift); dividing would turn every value into NaN.
    if max_xy > 0:
        x, y = x / max_xy, y / max_xy

    # Concat batch/features to ensure no cross-links between examples exist.
    # Points of different examples end up at least 2 * F apart in the extra
    # column, where F is the number of features.
    batch_gap = 2 * x.size(1)
    x = torch.cat([x, batch_gap * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
    y = torch.cat([y, batch_gap * batch_y.view(-1, 1).to(y.dtype)], dim=-1)

    # `x` now has F + 1 columns, so the search radius is F + 1: larger than
    # sqrt(F), the widest distance inside one rescaled example, and not
    # exceeding the 2 * F gap between examples.
    tree = scipy.spatial.cKDTree(x.detach().numpy())
    dist, col = tree.query(y.detach().numpy(), k=k,
                           distance_upper_bound=x.size(1))

    dist = torch.from_numpy(dist).to(x.dtype)
    col = torch.from_numpy(col).to(torch.long)
    row = torch.arange(col.size(0), dtype=torch.long).view(-1, 1).repeat(1, k)

    # Neighbors that were not found are reported with an infinite distance.
    mask = ~torch.isinf(dist).view(-1)
    row, col = row.view(-1)[mask], col.view(-1)[mask]

    return torch.stack([row, col], dim=0)
def knn_graph(x, k, batch=None, loop=False, flow='source_to_target',
              cosine=False):
    r"""Builds the edge list that links every point in :obj:`x` to its
    :obj:`k` nearest points.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        k (int): The number of neighbors.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        loop (bool, optional): If :obj:`True`, the graph will contain
            self-loops. (default: :obj:`False`)
        flow (string, optional): The flow direction when using in combination
            with message passing (:obj:`"source_to_target"` or
            :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
        cosine (boolean, optional): If :obj:`True`, will use the cosine
            distance instead of euclidean distance to find nearest neighbors.
            (default: :obj:`False`)

    Returns:
        A tensor with two rows stacked from the index pairs produced by
        :func:`knn`. With :obj:`"source_to_target"` the first row is the
        neighbor and the second row the query node; with
        :obj:`"target_to_source"` the order is reversed.

    :rtype: :class:`LongTensor`

    .. testsetup::

        import torch
        from torch_cluster import knn_graph

    .. testcode::

        >>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        >>> batch = torch.tensor([0, 0, 0, 0])
        >>> edge_index = knn_graph(x, k=2, batch=batch, loop=False)
    """
    assert flow in ['source_to_target', 'target_to_source']

    # When self-loops are unwanted, ask for one extra neighbor so that the
    # self match can be filtered out afterwards.
    if loop:
        num_neighbors = k
    else:
        num_neighbors = k + 1

    query_idx, neighbor_idx = knn(x, x, num_neighbors, batch, batch,
                                  cosine=cosine)

    if flow == 'source_to_target':
        first, second = neighbor_idx, query_idx
    else:
        first, second = query_idx, neighbor_idx

    if not loop:
        keep = first != second
        first, second = first[keep], second[keep]

    return torch.stack([first, second], dim=0)