torch_geometric.nn.pool.KNNIndex
- class KNNIndex(index_factory: Optional[str] = None, emb: Optional[Tensor] = None, reserve: Optional[int] = None)[source]
Bases: object
A base class to perform fast \(k\)-nearest neighbor (\(k\)-NN) search via the faiss library.
Please ensure that faiss is installed by running
    pip install faiss-cpu # or pip install faiss-gpu
depending on whether you plan to use GPU processing for \(k\)-NN search.
- Parameters:
index_factory (str, optional) – The name of the index factory to use, e.g., "IndexFlatL2" or "IndexFlatIP". See the faiss documentation on index factories for more information.
emb (torch.Tensor, optional) – The data points to add. (default: None)
reserve (int, optional) – The number of elements to reserve memory for before re-allocating (GPU-only). (default: None)
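The two example index names above differ in how a match is scored, which affects how the scores returned by a search should be read. The following is a minimal pure-Python sketch of that difference, not the faiss implementation. It assumes that an L2 index ranks by squared Euclidean distance (smaller is nearer) and an inner-product index ranks by dot product (larger is more similar). The helper names squared_l2, inner_product and rank are hypothetical.

```python
from typing import List, Sequence


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def inner_product(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product between two points."""
    return sum(x * y for x, y in zip(a, b))


def rank(query: Sequence[float], points: List[Sequence[float]],
         metric: str) -> List[int]:
    """Return the indices of `points`, best match first."""
    ids = range(len(points))
    if metric == "l2":
        # Distance: the smallest value is the nearest neighbor.
        return sorted(ids, key=lambda i: squared_l2(query, points[i]))
    if metric == "ip":
        # Similarity: the largest value is the best match.
        return sorted(ids, key=lambda i: -inner_product(query, points[i]))
    raise ValueError(f"Unknown metric: {metric}")


points = [[1.0, 0.0], [3.0, 0.0], [0.0, 1.0]]
query = [1.0, 0.0]

l2_order = rank(query, points, "l2")
ip_order = rank(query, points, "ip")
```

The same query ranks the points differently under the two metrics, so the choice of index factory should match how the embeddings were trained or normalized.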
- add(emb: Tensor)[source]
Adds new data points to the KNNIndex to search in.
- Parameters:
emb (torch.Tensor) – The data points to add.
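A minimal usage sketch of building an index incrementally and then querying it, based on the signatures documented on this page. It requires torch, torch_geometric and faiss to be installed. The factory string "Flat" (faiss's exact, L2-based index), the tensor sizes and the function name build_and_query are assumptions chosen for illustration.

```python
def build_and_query():
    # Imports live inside the function so that defining it does not
    # require torch, torch_geometric, or faiss to be installed.
    import torch
    from torch_geometric.nn.pool import KNNIndex

    torch.manual_seed(0)

    # Create the index from an initial batch of 100 points with 16 features.
    # "Flat" is an assumed faiss factory string for an exact L2 index.
    knn = KNNIndex(index_factory="Flat", emb=torch.randn(100, 16))

    # Add 50 more points; all points must share the same feature dimension.
    knn.add(torch.randn(50, 16))

    # Query the 3 nearest neighbors of 5 new points. The result unpacks
    # into the scores and the indices of the neighbors.
    query = torch.randn(5, 16)
    score, index = knn.search(query, k=3)
    return score, index
```

Calling build_and_query() returns the neighbor scores together with the positions of the neighbors among the data points added so far.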
- search(emb: Tensor, k: int, exclude_links: Optional[Tensor] = None) → KNNOutput[source]
Search for the \(k\) nearest neighbors of the given data points. Returns the distance/similarity score of the nearest neighbors and their indices.
- Parameters:
emb (torch.Tensor) – The query data points to search nearest neighbors for.
k (int) – The number of nearest neighbors to return.
exclude_links (torch.Tensor) – The links to exclude from searching. Needs to be a COO tensor of shape [2, num_links], where exclude_links[0] refers to indices in emb, and exclude_links[1] refers to the data points in the KNNIndex. (default: None)
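The meaning of exclude_links can be illustrated with a brute-force reference in plain Python. This is a conceptual sketch of the semantics described above, not how the library or faiss implements the search. It assumes an L2-style index (smaller score is nearer), represents tensors as nested lists, and uses the hypothetical helper name brute_force_search.

```python
from typing import List, Optional, Sequence, Tuple


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def brute_force_search(
    index_points: List[Sequence[float]],
    queries: List[Sequence[float]],
    k: int,
    exclude_links: Optional[List[List[int]]] = None,
) -> Tuple[List[List[float]], List[List[int]]]:
    """Return (scores, indices) of up to k nearest neighbors per query."""
    excluded = set()
    if exclude_links is not None:
        # Row 0 holds query positions, row 1 holds positions in the index.
        excluded = set(zip(exclude_links[0], exclude_links[1]))

    all_scores, all_indices = [], []
    for q_id, q in enumerate(queries):
        # Score every indexed point except the excluded (query, point) pairs.
        candidates = [
            (squared_l2(q, p), p_id)
            for p_id, p in enumerate(index_points)
            if (q_id, p_id) not in excluded
        ]
        top = sorted(candidates)[:k]
        all_scores.append([s for s, _ in top])
        all_indices.append([i for _, i in top])
    return all_scores, all_indices


index_points = [[0.0], [1.0], [2.0], [5.0]]
queries = [[0.1], [1.9]]

# Exclude the pairs (query 0, point 0) and (query 1, point 2).
exclude_links = [[0, 1], [0, 2]]

_, plain = brute_force_search(index_points, queries, k=2)
_, filtered = brute_force_search(index_points, queries, k=2,
                                 exclude_links=exclude_links)
```

Without exclusions each query finds its closest point first. With exclude_links, those pairs are skipped and the next-nearest points take their place. A typical use is to keep already-known edges out of the results when retrieving candidate links.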