Source code for torch_geometric.nn.pool.knn

import warnings
from typing import NamedTuple, Optional

import torch
from torch import Tensor


class KNNOutput(NamedTuple):
    """Result container returned by :meth:`KNNIndex.search`.

    Bundles the two values produced by a :math:`k`-NN query so callers can
    access them by name or unpack them positionally as ``(score, index)``.
    """
    # Distance/similarity score of each returned nearest neighbor.
    score: Tensor
    # Index of each returned nearest neighbor among the stored data points.
    index: Tensor


class KNNIndex:
    r"""A base class to perform fast :math:`k`-nearest neighbor search
    (:math:`k`-NN) via the :obj:`faiss` library.

    Please ensure that :obj:`faiss` is installed by running

    .. code-block:: bash

        pip install faiss-cpu
        # or
        pip install faiss-gpu

    depending on whether you plan to use GPU-processing for :math:`k`-NN
    search.

    Args:
        index_factory (str, optional): The name of the index factory to use,
            *e.g.*, :obj:`"IndexFlatL2"` or :obj:`"IndexFlatIP"`. See `here
            <https://github.com/facebookresearch/faiss/wiki/
            The-index-factory>`_ for more information.
            May be :obj:`None` for subclasses that override
            :meth:`_create_index` and therefore never consult it.
        emb (torch.Tensor, optional): The data points to add.
            (default: :obj:`None`)
    """
    def __init__(
        self,
        index_factory: Optional[str],
        emb: Optional[Tensor] = None,
    ):
        warnings.filterwarnings('ignore', '.*TypedStorage is deprecated.*')

        # Imported eagerly so that a missing `faiss` installation surfaces at
        # construction time rather than on the first `add(...)` call.
        import faiss

        # Number of data points successfully added so far.
        self.numel = 0
        self.index_factory = index_factory
        # Created lazily on the first `add(...)`, once the feature
        # dimensionality is known.
        self.index: Optional[faiss.Index] = None

        if emb is not None:
            self.add(emb)

    def _create_index(self, channels: int):
        r"""Builds the underlying :obj:`faiss` index.

        Subclasses may override this to construct a specific index type.

        Args:
            channels (int): The feature dimensionality of the data points.
        """
        import faiss
        return faiss.index_factory(channels, self.index_factory)

    def add(self, emb: Tensor):
        r"""Adds new data points to the :class:`KNNIndex` to search in.

        Args:
            emb (torch.Tensor): The data points to add.

        Raises:
            ValueError: If :obj:`emb` is not two-dimensional.
        """
        import faiss
        # Imported for its side effects only (the name is never referenced);
        # presumably this is what lets faiss indices consume torch tensors.
        import faiss.contrib.torch_utils  # noqa: F401

        if emb.dim() != 2:
            raise ValueError(f"'emb' needs to be two-dimensional "
                             f"(got {emb.dim()} dimensions)")

        if self.index is None:
            index = self._create_index(emb.size(1))

            if emb.device != torch.device('cpu'):
                index = faiss.index_cpu_to_gpu(
                    faiss.StandardGpuResources(),
                    emb.device.index,
                    index,
                )

            # Publish the index only once it is fully set up, so a failed GPU
            # transfer does not leave a half-initialized CPU index behind.
            self.index = index

        self.index.add(emb.detach())
        # Count the new points only after the index accepted them; otherwise a
        # failing `add` would leave `numel` larger than what is stored and
        # break `get_emb()`.
        self.numel += emb.size(0)

    def search(self, emb: Tensor, k: int) -> KNNOutput:
        r"""Search for the :math:`k` nearest neighbors of the given data
        points. Returns the distance/similarity score of the nearest neighbors
        and their indices.

        Args:
            emb (torch.Tensor): The data points to search the nearest
                neighbors for.
            k (int): The number of nearest neighbors to return.

        Raises:
            RuntimeError: If no data points have been added yet.
            ValueError: If :obj:`emb` is not two-dimensional.
        """
        if self.index is None:
            raise RuntimeError(f"'{self.__class__.__name__}' is not yet "
                               "initialized. Please call `add(...)` first.")

        if emb.dim() != 2:
            raise ValueError(f"'emb' needs to be two-dimensional "
                             f"(got {emb.dim()} dimensions)")

        return KNNOutput(*self.index.search(emb.detach(), k))

    def get_emb(self) -> Tensor:
        r"""Returns the data points stored in the :class:`KNNIndex`.

        Raises:
            RuntimeError: If no data points have been added yet.
        """
        if self.index is None:
            raise RuntimeError(f"'{self.__class__.__name__}' is not yet "
                               "initialized. Please call `add(...)` first.")

        # Reconstruct every stored point, i.e. entries `0 .. numel - 1`.
        return self.index.reconstruct_n(0, self.numel)
class L2KNNIndex(KNNIndex):
    r"""A :class:`KNNIndex` that runs :math:`k`-nearest neighbor search
    (:math:`k`-NN) under the :math:`L_2` metric, backed by the :obj:`faiss`
    library.

    Args:
        emb (torch.Tensor, optional): The data points to add.
            (default: :obj:`None`)
    """
    def __init__(self, emb: Optional[Tensor] = None):
        # No factory string is needed: `_create_index` below constructs the
        # index directly instead of going through the index factory.
        super().__init__(None, emb=emb)

    def _create_index(self, channels: int):
        """Builds a flat :math:`L_2` index over ``channels`` features."""
        from faiss import IndexFlatL2

        return IndexFlatL2(channels)
class MIPSKNNIndex(KNNIndex):
    r"""A :class:`KNNIndex` that runs :math:`k`-nearest neighbor search
    (:math:`k`-NN) by maximum inner product, backed by the :obj:`faiss`
    library.

    Args:
        emb (torch.Tensor, optional): The data points to add.
            (default: :obj:`None`)
    """
    def __init__(self, emb: Optional[Tensor] = None):
        # No factory string is needed: `_create_index` below constructs the
        # index directly instead of going through the index factory.
        super().__init__(None, emb=emb)

    def _create_index(self, channels: int):
        """Builds a flat inner-product index over ``channels`` features."""
        from faiss import IndexFlatIP

        return IndexFlatIP(channels)