Source code for torch_geometric.nn.pool.knn

import warnings
from typing import NamedTuple, Optional

import torch
from torch import Tensor

from torch_geometric.utils import cumsum, degree, to_dense_batch


class KNNOutput(NamedTuple):
    r"""The result of a :math:`k`-nearest neighbor search, as returned by
    :meth:`KNNIndex.search`.
    """
    # Distance/similarity score of each retrieved neighbor. Slots that
    # `KNNIndex.search` pads after excluding links hold `-inf`.
    score: Tensor
    # Index of each retrieved neighbor within the searched `KNNIndex`. Slots
    # that `KNNIndex.search` pads after excluding links hold `-1`.
    index: Tensor


class KNNIndex:
    r"""A base class to perform fast :math:`k`-nearest neighbor search
    (:math:`k`-NN) via the :obj:`faiss` library.

    Please ensure that :obj:`faiss` is installed by running

    .. code-block:: bash

        pip install faiss-cpu
        # or
        pip install faiss-gpu

    depending on whether you plan to use GPU-processing for :math:`k`-NN
    search.

    Args:
        index_factory (str): The name of the index factory to use, *e.g.*,
            :obj:`"IndexFlatL2"` or :obj:`"IndexFlatIP"`. See `here
            <https://github.com/facebookresearch/faiss/wiki/
            The-index-factory>`_ for more information.
        emb (torch.Tensor, optional): The data points to add.
            (default: :obj:`None`)
    """
    def __init__(self, index_factory: str, emb: Optional[Tensor] = None):
        # NOTE: This installs a process-wide warning filter on every
        # instantiation.
        warnings.filterwarnings('ignore', '.*TypedStorage is deprecated.*')

        # Imported lazily so that `faiss` is only required once an index is
        # actually instantiated:
        import faiss

        # Number of data points added so far:
        self.numel = 0
        self.index_factory = index_factory
        # The underlying index is created lazily on the first `add(...)`
        # call, since the feature dimensionality is unknown until then:
        self.index: Optional[faiss.Index] = None

        if emb is not None:
            self.add(emb)

    def _create_index(self, channels: int):
        r"""Creates the :obj:`faiss` index for :obj:`channels`-dimensional
        data points from :obj:`self.index_factory`. Sub-classes may override
        this method to build a specific index type directly.
        """
        import faiss
        return faiss.index_factory(channels, self.index_factory)

    def add(self, emb: Tensor) -> None:
        r"""Adds new data points to the :class:`KNNIndex` to search in.

        Args:
            emb (torch.Tensor): The data points to add.

        Raises:
            ValueError: If :obj:`emb` is not two-dimensional.
        """
        import faiss
        import faiss.contrib.torch_utils

        if emb.dim() != 2:
            raise ValueError(f"'emb' needs to be two-dimensional "
                             f"(got {emb.dim()} dimensions)")

        if self.index is None:
            self.index = self._create_index(emb.size(1))

            # The device of the first batch decides where the index lives:
            if emb.device != torch.device('cpu'):
                self.index = faiss.index_cpu_to_gpu(
                    faiss.StandardGpuResources(),
                    emb.device.index,
                    self.index,
                )

        self.numel += emb.size(0)
        self.index.add(emb.detach())

    def search(
        self,
        emb: Tensor,
        k: int,
        exclude_links: Optional[Tensor] = None,
    ) -> KNNOutput:
        r"""Search for the :math:`k` nearest neighbors of the given data
        points. Returns the distance/similarity score of the nearest neighbors
        and their indices.

        At most :obj:`min(self.numel, 2048)` neighbors are requested from
        :obj:`faiss` per query. If :obj:`k` (or :obj:`k` extended by the
        number of links to exclude) exceeds 2048, a warning is emitted and
        the request is capped.

        Args:
            emb (torch.Tensor): The data points to search neighbors for.
            k (int): The number of nearest neighbors to return.
            exclude_links (torch.Tensor): The links to exclude from searching.
                Needs to be a COO tensor of shape :obj:`[2, num_links]`, where
                :obj:`exclude_links[0]` refers to indices in :obj:`emb`, and
                :obj:`exclude_links[1]` refers to the data points in the
                :class:`KNNIndex`. (default: :obj:`None`)

        Returns:
            KNNOutput: The scores and indices of the retrieved neighbors.
            If :obj:`exclude_links` is given, both tensors have :obj:`k`
            columns, and queries with fewer than :obj:`k` remaining neighbors
            are padded with a score of :obj:`-inf` and an index of :obj:`-1`.
            Otherwise, the result of the underlying index search is returned
            as is.

        Raises:
            RuntimeError: If no data points have been added yet.
            ValueError: If :obj:`emb` is not two-dimensional.
        """
        if self.index is None:
            raise RuntimeError(f"'{self.__class__.__name__}' is not yet "
                               "initialized. Please call `add(...)` first.")

        if emb.dim() != 2:
            raise ValueError(f"'emb' needs to be two-dimensional "
                             f"(got {emb.dim()} dimensions)")

        query_k = k

        if exclude_links is not None:
            # Over-fetch by the largest number of excluded links of any query
            # so that `k` candidates can remain after filtering. The guard
            # is applied to the un-reduced degree vector, since reducing an
            # empty tensor via `max()` is an error:
            deg = degree(exclude_links[0], num_nodes=emb.size(0))
            query_k = k + (int(deg.max()) if deg.numel() > 0 else 0)

        # There cannot be more neighbors than stored data points:
        query_k = min(query_k, self.numel)

        if k > 2048:  # `faiss` supports up-to `k=2048`:
            warnings.warn(f"Capping 'k' to faiss' upper limit of 2048 "
                          f"(got {k}). This may cause some relevant items to "
                          f"not be retrieved.")
        elif query_k > 2048:
            warnings.warn(f"Capping 'k' to faiss' upper limit of 2048 "
                          f"(got {k} which got extended to {query_k} due to "
                          f"the exclusion of existing links). This may cause "
                          f"some relevant items to not be retrieved.")

        # Apply the cap in *both* of the cases above. Previously, only the
        # second branch capped, so `k > 2048` was forwarded to `faiss` as is:
        query_k = min(query_k, 2048)

        score, index = self.index.search(emb.detach(), query_k)

        if exclude_links is not None:
            # Drop indices to exclude by converting to flat vector, i.e. each
            # (query, data point) pair maps to `numel * query + data_point`:
            flat_exclude = self.numel * exclude_links[0] + exclude_links[1]

            offset = torch.arange(
                start=0,
                end=self.numel * index.size(0),
                step=self.numel,
                device=index.device,
            ).view(-1, 1)
            flat_index = (index + offset).view(-1)

            notin = torch.isin(flat_index, flat_exclude).logical_not_()

            score = score.view(-1)[notin]
            index = index.view(-1)[notin]

            # Only maintain top-k scores:
            count = notin.view(-1, query_k).sum(dim=1)  # Survivors per query.
            cum_count = cumsum(count)

            # `batch` assigns each surviving entry to its query, and
            # `batch_arange` holds its rank within that query:
            batch = torch.arange(count.numel(), device=count.device)
            batch = batch.repeat_interleave(count, output_size=cum_count[-1])

            batch_arange = torch.arange(count.sum(), device=count.device)
            batch_arange = batch_arange - cum_count[batch]

            mask = batch_arange < k
            score = score[mask]
            index = index[mask]

            if count.min() < k:  # Fill with dummy scores:
                batch = batch[mask]
                score, _ = to_dense_batch(
                    score,
                    batch,
                    fill_value=float('-inf'),
                    max_num_nodes=k,
                    batch_size=emb.size(0),
                )
                index, _ = to_dense_batch(
                    index,
                    batch,
                    fill_value=-1,
                    max_num_nodes=k,
                    batch_size=emb.size(0),
                )

            score = score.view(-1, k)
            index = index.view(-1, k)

        return KNNOutput(score, index)

    def get_emb(self) -> Tensor:
        r"""Returns the data points stored in the :class:`KNNIndex`.

        Raises:
            RuntimeError: If no data points have been added yet.
        """
        if self.index is None:
            raise RuntimeError(f"'{self.__class__.__name__}' is not yet "
                               "initialized. Please call `add(...)` first.")

        # Reconstruct all `self.numel` entries, starting from the first one:
        return self.index.reconstruct_n(0, self.numel)
class L2KNNIndex(KNNIndex):
    r"""A :class:`KNNIndex` that runs fast :math:`k`-nearest neighbor search
    (:math:`k`-NN) under the :math:`L_2` metric, backed by the :obj:`faiss`
    library.

    Args:
        emb (torch.Tensor, optional): The data points to add.
            (default: :obj:`None`)
    """
    def __init__(self, emb: Optional[Tensor] = None):
        # No index factory string is passed on, as `_create_index` below
        # builds the index without consulting it:
        super().__init__(None, emb)

    def _create_index(self, channels: int):
        r"""Creates a flat :math:`L_2` index for :obj:`channels`-dimensional
        data points.
        """
        from faiss import IndexFlatL2

        return IndexFlatL2(channels)
class MIPSKNNIndex(KNNIndex):
    r"""A :class:`KNNIndex` that runs fast :math:`k`-nearest neighbor search
    (:math:`k`-NN) by maximum inner product, backed by the :obj:`faiss`
    library.

    Args:
        emb (torch.Tensor, optional): The data points to add.
            (default: :obj:`None`)
    """
    def __init__(self, emb: Optional[Tensor] = None):
        # No index factory string is passed on, as `_create_index` below
        # builds the index without consulting it:
        super().__init__(None, emb)

    def _create_index(self, channels: int):
        r"""Creates a flat inner-product index for :obj:`channels`-dimensional
        data points.
        """
        from faiss import IndexFlatIP

        return IndexFlatIP(channels)