torch_geometric.nn.pool.KNNIndex

class KNNIndex(index_factory: Optional[str] = None, emb: Optional[Tensor] = None, reserve: Optional[int] = None)

Bases: object

A base class to perform fast \(k\)-nearest neighbor search (\(k\)-NN) via the faiss library.

Please ensure that faiss is installed by running

pip install faiss-cpu
# or
pip install faiss-gpu

depending on whether you plan to use GPU processing for \(k\)-NN search.
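Before constructing a KNNIndex, it can help to confirm that faiss is importable. The following is a minimal sketch that uses only the Python standard library. The helper name faiss_available is hypothetical and is not part of PyTorch Geometric or faiss. Because both wheels install a module named faiss, this check cannot tell a CPU build from a GPU build.

```python
import importlib.util


def faiss_available() -> bool:
    """Return True if a module named `faiss` can be found on the import path.

    Hypothetical helper: both `faiss-cpu` and `faiss-gpu` install a module
    called `faiss`, so this does not distinguish the two builds.
    """
    return importlib.util.find_spec("faiss") is not None


if __name__ == "__main__":
    if faiss_available():
        print("faiss found: KNNIndex can be constructed")
    else:
        print("faiss missing: run `pip install faiss-cpu` or `pip install faiss-gpu`")
```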

Parameters:
  • index_factory (str, optional) – The name of the index factory to use, e.g., "IndexFlatL2" or "IndexFlatIP". See the faiss documentation on the index factory for more information. (default: None)

  • emb (torch.Tensor, optional) – The data points to add. (default: None)

  • reserve (int, optional) – The number of elements to reserve memory for before re-allocating (GPU-only). (default: None)
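The choice of index determines how neighbors are ranked. In faiss, IndexFlatL2 and IndexFlatIP are exhaustive ("flat") indices that compare a query against every stored point, by L2 distance or by inner product respectively. The plain-Python sketch below illustrates only that ranking difference and is not the faiss implementation. The function flat_scores and its metric strings "L2" and "IP" are hypothetical.

```python
from typing import List, Sequence


def flat_scores(
    query: Sequence[float], points: Sequence[Sequence[float]], metric: str
) -> List[float]:
    """Score every stored point against one query (exhaustive, "flat" search)."""
    if metric == "L2":
        # Squared Euclidean distance: smaller means nearer.
        return [sum((q - p) ** 2 for q, p in zip(query, point)) for point in points]
    if metric == "IP":
        # Inner product: larger means more similar.
        return [sum(q * p for q, p in zip(query, point)) for point in points]
    raise ValueError(f"unknown metric: {metric!r}")


points = [[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]]
query = [1.0, 0.0]

l2 = flat_scores(query, points, "L2")  # [0.0, 4.0, 5.0]
ip = flat_scores(query, points, "IP")  # [1.0, 3.0, 0.0]

# Best match under each metric: minimum distance vs. maximum similarity.
nearest_l2 = min(range(len(points)), key=lambda i: l2[i])  # 0
nearest_ip = max(range(len(points)), key=lambda i: ip[i])  # 1
```

The same query selects point 0 under L2 but point 1 under inner product. For unnormalized vectors, the two index types can therefore disagree on the nearest neighbor. For unit-length vectors they agree, since the squared L2 distance then equals 2 minus twice the inner product.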

property numel: int

The number of data points to search in.

add(emb: Tensor)

Adds new data points to the KNNIndex to search in.

Parameters:
  • emb (torch.Tensor) – The data points to add.
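The numel/add/get_emb contract can be pictured with a small stand-in. The class FlatStore below is hypothetical and only mimics the documented behavior with Python lists. Each call to add appends points in insertion order, so the i-th point added has index i. numel counts the stored points, and get_emb returns them. The real KNNIndex keeps its points inside a faiss index instead.

```python
from typing import List, Sequence


class FlatStore:
    """Hypothetical stand-in mimicking the add / numel / get_emb contract."""

    def __init__(self) -> None:
        self._rows: List[List[float]] = []

    @property
    def numel(self) -> int:
        # Number of data points currently available to search in.
        return len(self._rows)

    def add(self, emb: Sequence[Sequence[float]]) -> None:
        # Validate the whole batch first so a bad row leaves the store unchanged.
        rows = [[float(x) for x in row] for row in emb]
        if rows:
            dim = len(self._rows[0]) if self._rows else len(rows[0])
            if any(len(row) != dim for row in rows):
                raise ValueError("all data points must have the same dimension")
        # Append: earlier points keep their indices, new ones follow them.
        self._rows.extend(rows)

    def get_emb(self) -> List[List[float]]:
        # Return a copy so callers cannot mutate the stored points.
        return [list(row) for row in self._rows]


store = FlatStore()
print(store.numel)  # 0
store.add([[0.0, 1.0], [2.0, 3.0]])
store.add([[4.0, 5.0]])  # a second call appends rather than replaces
print(store.numel)  # 3
```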

search(emb: Tensor, k: int, exclude_links: Optional[Tensor] = None) → KNNOutput

Searches for the \(k\) nearest neighbors of the given data points. Returns the distance/similarity scores of the nearest neighbors together with their indices.

Parameters:
  • emb (torch.Tensor) – The query data points whose nearest neighbors should be found.

  • k (int) – The number of nearest neighbors to return.

  • exclude_links (torch.Tensor, optional) – The links to exclude from the search. Needs to be a COO tensor of shape [2, num_links], where exclude_links[0] refers to indices in emb, and exclude_links[1] refers to the data points in the KNNIndex. (default: None)
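To make the search semantics concrete, including the COO layout of exclude_links, here is an exhaustive plain-Python sketch under squared L2 distance. It is a stand-in built on assumptions, not the KNNIndex implementation. knn_search is a hypothetical name, and exclude_links is passed as a pair of index lists that mirror the two rows of the [2, num_links] tensor. When exclusions leave fewer than k candidates, this sketch simply returns shorter lists. The behavior of KNNIndex in that case is not specified above.

```python
from typing import List, Optional, Sequence, Set, Tuple


def knn_search(
    queries: Sequence[Sequence[float]],
    points: Sequence[Sequence[float]],
    k: int,
    exclude_links: Optional[Tuple[Sequence[int], Sequence[int]]] = None,
) -> Tuple[List[List[float]], List[List[int]]]:
    """Exhaustive k-NN under squared L2 distance, skipping excluded pairs.

    exclude_links = (query_ids, point_ids) in COO form: the j-th pair means
    "never return point point_ids[j] as a neighbor of query query_ids[j]".
    """
    excluded: Set[Tuple[int, int]] = set()
    if exclude_links is not None:
        excluded = set(zip(exclude_links[0], exclude_links[1]))

    all_scores: List[List[float]] = []
    all_indices: List[List[int]] = []
    for qi, q in enumerate(queries):
        # Score every stored point that is not excluded for this query.
        candidates = [
            (sum((a - b) ** 2 for a, b in zip(q, p)), pi)
            for pi, p in enumerate(points)
            if (qi, pi) not in excluded
        ]
        candidates.sort()  # ascending distance, ties broken by point index
        top = candidates[:k]
        all_scores.append([score for score, _ in top])
        all_indices.append([pi for _, pi in top])
    return all_scores, all_indices


points = [[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]]
queries = [[0.0, 0.0], [1.0, 1.0]]

# Exclude the pair (query 0, point 0), e.g. a self-match.
scores, indices = knn_search(queries, points, k=2, exclude_links=([0], [0]))
print(indices)  # [[1, 2], [1, 0]]
print(scores)   # [[1.0, 50.0], [1.0, 2.0]]
```

Excluding the (0, 0) pair shows how a query can be kept from matching a point it is already linked to. Here that is the identical point, which would otherwise be returned first at distance 0.0.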

get_emb() → Tensor

Returns the data points stored in the KNNIndex.