torch_geometric.nn.pool.knn_graph
- knn_graph(x: Tensor, k: int, batch: Optional[Tensor] = None, loop: bool = False, flow: str = 'source_to_target', cosine: bool = False, num_workers: int = 1, batch_size: Optional[int] = None) → Tensor
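Computes graph edges to the nearest k points.

```python
import torch
from torch_geometric.nn import knn_graph

# Four 2-D points at the corners of a square.
x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
# All four nodes belong to the same example (index 0).
batch = torch.tensor([0, 0, 0, 0])
# Connect every node to its 2 nearest neighbors, without self-loops.
edge_index = knn_graph(x, k=2, batch=batch, loop=False)
```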
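To make the example concrete, here is a brute-force, pure-Python sketch of the edges a k-NN graph over the same four points contains. The helper `knn_graph_reference` is hypothetical and is not the PyG implementation. It assumes that for `flow="source_to_target"` row 0 of `edge_index` holds the neighbor (source) and row 1 the center node (target), with the rows swapped for `"target_to_source"`. It also breaks distance ties by node index, so the edge ordering may differ from what the real `knn_graph` returns.

```python
import math
from typing import List, Optional, Sequence


def knn_graph_reference(
    x: Sequence[Sequence[float]],
    k: int,
    batch: Optional[Sequence[int]] = None,
    loop: bool = False,
    flow: str = "source_to_target",
) -> List[List[int]]:
    """Brute-force k-NN graph (hypothetical helper, not the PyG implementation)."""
    if flow not in ("source_to_target", "target_to_source"):
        raise ValueError(f"unknown flow: {flow!r}")
    n = len(x)
    if batch is None:
        batch = [0] * n  # every node belongs to a single example
    sources: List[int] = []
    targets: List[int] = []
    for i in range(n):
        # Candidates: nodes of the same example, including i itself only if loop=True.
        candidates = [j for j in range(n) if batch[j] == batch[i] and (loop or j != i)]
        # Rank by Euclidean distance; ties are broken by node index (an assumption).
        candidates.sort(key=lambda j: (math.dist(x[i], x[j]), j))
        for j in candidates[:k]:
            sources.append(j)  # neighbor
            targets.append(i)  # center node
    if flow == "source_to_target":
        return [sources, targets]
    return [targets, sources]


x = [[-1, -1], [-1, 1], [1, -1], [1, 1]]
print(knn_graph_reference(x, k=2))
# [[1, 2, 0, 3, 0, 3, 1, 2], [0, 0, 1, 1, 2, 2, 3, 3]]
```

Each corner of the square is joined to its two adjacent corners (distance 2); the diagonal corner, at distance 2√2, is never selected.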
- Parameters
x (torch.Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).
k (int) – The number of neighbors.
batch (torch.Tensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)
loop (bool, optional) – If True, the graph will contain self-loops. (default: False)
flow (str, optional) – The flow direction when used in combination with message passing ("source_to_target" or "target_to_source"). (default: "source_to_target")
cosine (bool, optional) – If True, uses the cosine distance instead of the Euclidean distance to find nearest neighbors. (default: False)
num_workers (int, optional) – Number of workers to use for computation. Has no effect if batch is not None or the input lies on the GPU. (default: 1)
batch_size (int, optional) – The number of examples \(B\). Automatically calculated if not given. (default: None)
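The cosine and batch parameters change which points count as "nearest". The pure-Python sketch below illustrates both effects. The helper `nearest_neighbors` is hypothetical and not part of PyG; it assumes a cosine distance of 1 - cos(theta) between feature vectors and only considers nodes that share the center node's batch index.

```python
import math
from typing import List, Optional, Sequence

Point = Sequence[float]


def euclidean_distance(a: Point, b: Point) -> float:
    return math.dist(a, b)


def cosine_distance(a: Point, b: Point) -> float:
    # Assumed definition: 1 - cos(angle between a and b); a and b must be non-zero.
    dot = sum(p * q for p, q in zip(a, b))
    return 1.0 - dot / (math.hypot(*a) * math.hypot(*b))


def nearest_neighbors(
    x: Sequence[Point],
    i: int,
    k: int,
    batch: Optional[Sequence[int]] = None,
    cosine: bool = False,
) -> List[int]:
    """Return the k nearest neighbors of node i, self excluded (hypothetical helper)."""
    dist = cosine_distance if cosine else euclidean_distance
    if batch is None:
        batch = [0] * len(x)
    # Only nodes from the same example are eligible neighbors.
    candidates = [j for j in range(len(x)) if j != i and batch[j] == batch[i]]
    candidates.sort(key=lambda j: (dist(x[i], x[j]), j))
    return candidates[:k]


x = [[1.0, 0.0], [10.0, 1.0], [0.0, 2.0]]
print(nearest_neighbors(x, 0, k=1))               # [2]: closest point in space
print(nearest_neighbors(x, 0, k=1, cosine=True))  # [1]: most similar direction

y = [[0.0, 0.0], [1.0, 0.0], [0.1, 0.0], [5.0, 0.0]]
# Node 2 is closer to node 0, but it belongs to a different example.
print(nearest_neighbors(y, 0, k=1, batch=[0, 0, 1, 1]))  # [1]
```

Because cosine distance ignores vector length, node 1 at (10, 1) is chosen under cosine=True even though it is the farthest point from node 0 in Euclidean terms.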
- Return type