torch_geometric.nn.pool.KNNIndex

class KNNIndex(index_factory: str, emb: Optional[Tensor] = None)

Bases: object

A base class to perform fast \(k\)-nearest neighbor search (\(k\)-NN) via the faiss library.

Please ensure that faiss is installed by running

pip install faiss-cpu
# or
pip install faiss-gpu

depending on whether you plan to use GPU processing for \(k\)-NN search.

Parameters:
  • index_factory (str) – The name of the index factory to use, e.g., "IndexFlatL2" or "IndexFlatIP". See the faiss documentation on the index factory for more information.

  • emb (torch.Tensor, optional) – The data points to add. (default: None)
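The two factory names above select different scoring rules. The library-free sketch below illustrates the difference without faiss: a flat L2 index ranks stored points by squared Euclidean distance (smaller is closer), while a flat inner-product index ranks them by dot product (larger is more similar). The helpers `squared_l2`, `inner_product`, and `rank` are hypothetical names used only for this illustration.

```python
from typing import List, Sequence


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    # Squared Euclidean distance: lower means closer.
    return sum((x - y) ** 2 for x, y in zip(a, b))


def inner_product(a: Sequence[float], b: Sequence[float]) -> float:
    # Dot product: higher means more similar.
    return sum(x * y for x, y in zip(a, b))


def rank(data: List[List[float]], query: List[float], metric: str) -> List[int]:
    """Return indices of `data` ordered from best to worst match for `query`."""
    if metric == "l2":
        return sorted(range(len(data)), key=lambda i: squared_l2(data[i], query))
    if metric == "ip":
        return sorted(range(len(data)), key=lambda i: -inner_product(data[i], query))
    raise ValueError(f"unknown metric: {metric!r}")


data = [[1.0, 0.0], [3.0, 0.0], [0.0, 1.0]]
query = [1.0, 0.0]
print(rank(data, query, "l2"))  # [0, 2, 1]
print(rank(data, query, "ip"))  # [1, 0, 2]
```

The inner-product ranking favors the longer vector [3.0, 0.0] over the exact match [1.0, 0.0]; this is why embeddings are commonly L2-normalized before inner-product search when cosine similarity is the goal.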

add(emb: Tensor)

Adds new data points to the KNNIndex to search in.

Parameters:
  • emb (torch.Tensor) – The data points to add.
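Conceptually, `add` appends to whatever is already stored, so a point added later receives the next free row index (faiss flat indices assign sequential ids). The following library-free sketch mimics that bookkeeping together with `get_emb`. The class name `FlatStore` is hypothetical, it stores plain Python lists rather than tensors, and its dimension check is a choice made for this sketch rather than a description of KNNIndex's internals.

```python
from typing import List, Optional, Sequence


class FlatStore:
    """Hypothetical stand-in that mimics the add/get_emb bookkeeping of a flat index."""

    def __init__(self, emb: Optional[Sequence[Sequence[float]]] = None) -> None:
        self._rows: List[List[float]] = []
        self._dim: Optional[int] = None
        if emb is not None:
            self.add(emb)

    def add(self, emb: Sequence[Sequence[float]]) -> None:
        # Validate the whole batch first so a bad row leaves the store unchanged.
        dim = self._dim
        for row in emb:
            if dim is None:
                dim = len(row)
            elif len(row) != dim:
                raise ValueError(f"expected dimension {dim}, got {len(row)}")
        self._dim = dim
        # New rows are appended, so their indices continue from the current size.
        self._rows.extend([float(x) for x in row] for row in emb)

    def get_emb(self) -> List[List[float]]:
        # Return copies so callers cannot mutate the stored data.
        return [list(row) for row in self._rows]

    def __len__(self) -> int:
        return len(self._rows)


store = FlatStore([[0.0, 0.0], [1.0, 1.0]])
store.add([[2.0, 2.0]])
print(len(store))          # 3
print(store.get_emb()[2])  # [2.0, 2.0]
```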

search(emb: Tensor, k: int) → KNNOutput

Searches for the \(k\) nearest neighbors of the given data points. Returns the distance/similarity scores of the nearest neighbors and their indices.

Parameters:
  • emb (torch.Tensor) – The query data points to find the nearest neighbors of.

  • k (int) – The number of nearest neighbors to return.
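For intuition about what `search` computes, here is an exhaustive (brute-force) \(k\)-NN sketch under squared L2 distance, which is what a flat L2 index evaluates exactly. It uses plain Python lists instead of tensors. `KNNResult` and `knn_search` are hypothetical names; pairing scores with indices row by row is meant to mirror the KNNOutput return value described above, not to reproduce its exact definition. Rejecting a `k` larger than the number of stored points is a choice made for this sketch.

```python
from typing import List, NamedTuple


class KNNResult(NamedTuple):
    # Hypothetical stand-in for KNNOutput: one row of k entries per query.
    score: List[List[float]]
    index: List[List[int]]


def knn_search(data: List[List[float]], queries: List[List[float]], k: int) -> KNNResult:
    """Exhaustive k-NN search under squared L2 distance (smaller is closer)."""
    if not 0 < k <= len(data):
        raise ValueError("k must be between 1 and the number of stored points")
    scores: List[List[float]] = []
    indices: List[List[int]] = []
    for q in queries:
        # Compare the query against every stored point; no approximation.
        dists = [sum((x - y) ** 2 for x, y in zip(p, q)) for p in data]
        # Keep the k stored points with the smallest distances, best first.
        top = sorted(range(len(data)), key=dists.__getitem__)[:k]
        scores.append([dists[i] for i in top])
        indices.append(top)
    return KNNResult(score=scores, index=indices)


data = [[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]]
result = knn_search(data, queries=[[1.0, 1.0]], k=2)
print(result.index)  # [[1, 0]]
print(result.score)  # [[1.0, 2.0]]
```

For the query [1.0, 1.0], the squared distances to the three stored points are 2.0, 1.0, and 32.0, so the two nearest neighbors are index 1 followed by index 0.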

get_emb() → Tensor

Returns the data points stored in the KNNIndex.