# Source code for torch_geometric.utils.negative_sampling

import random
from typing import Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor

from torch_geometric.utils import coalesce, degree, remove_self_loops

from .num_nodes import maybe_num_nodes


def negative_sampling(edge_index: Tensor,
                      num_nodes: Optional[Union[int, Tuple[int, int]]] = None,
                      num_neg_samples: Optional[int] = None,
                      method: str = "sparse",
                      force_undirected: bool = False) -> Tensor:
    r"""Samples random negative edges of a graph given by :attr:`edge_index`.

    Candidate node pairs are encoded as integer slots, random slots are
    drawn, and slots that are already positive edges (or were drawn before)
    are discarded. At most three sampling rounds are performed, so fewer
    than :obj:`num_neg_samples` edges may be returned. If every candidate
    slot is already a positive edge, an empty :obj:`(2, 0)` tensor is
    returned.

    Args:
        edge_index (LongTensor): The edge indices.
        num_nodes (int or Tuple[int, int], optional): The number of nodes,
            *i.e.* :obj:`max_val + 1` of :attr:`edge_index`.
            If given as a tuple, then :obj:`edge_index` is interpreted as a
            bipartite graph with shape :obj:`(num_src_nodes, num_dst_nodes)`.
            (default: :obj:`None`)
        num_neg_samples (int, optional): The (approximate) number of negative
            samples to return.
            If set to :obj:`None`, will try to return a negative edge for
            every positive edge. (default: :obj:`None`)
        method (string, optional): The method to use for negative sampling,
            *i.e.*, :obj:`"sparse"` or :obj:`"dense"`.
            This is a memory/runtime trade-off.
            :obj:`"sparse"` will work on any graph of any size, while
            :obj:`"dense"` can perform faster true-negative checks.
            (default: :obj:`"sparse"`)
        force_undirected (bool, optional): If set to :obj:`True`, sampled
            negative edges will be undirected. Ignored for bipartite graphs.
            (default: :obj:`False`)

    :rtype: LongTensor
    """
    assert method in ['sparse', 'dense']

    bipartite = isinstance(num_nodes, (tuple, list))
    if bipartite:
        # (num_src_nodes, num_dst_nodes); undirected sampling does not apply.
        size = num_nodes
        force_undirected = False
    else:
        n = maybe_num_nodes(edge_index) if num_nodes is None else num_nodes
        size = (n, n)

    idx, population = edge_index_to_vector(edge_index, size, bipartite,
                                           force_undirected)
    if idx.numel() >= population:
        # No free slot left: every candidate pair is a positive edge.
        return edge_index.new_empty((2, 0))

    if num_neg_samples is None:
        num_neg_samples = edge_index.size(1)
    if force_undirected:
        # Every sampled pair is later emitted in both directions.
        num_neg_samples //= 2

    # Expected fraction of random slots that are true negatives; over-sample
    # by 10% on top of that so a single round usually suffices.
    prob = 1. - idx.numel() / population
    sample_size = int(1.1 * num_neg_samples / prob)

    chunks = []
    total = 0
    if method == 'dense':
        # Boolean table over the whole candidate space: `True` means the slot
        # is still a valid, not-yet-used negative.
        available = idx.new_ones(population, dtype=torch.bool)
        available[idx] = False
        for _ in range(3):  # At most three sampling rounds.
            draw = sample(population, sample_size, idx.device)
            draw = draw[available[draw]]
            chunks.append(draw)
            total += draw.numel()
            if total >= num_neg_samples:
                break
            # Slots from earlier rounds are already marked, so marking the
            # newest draw is enough to prevent duplicates.
            available[draw] = False
    else:  # 'sparse'
        # Membership is tested on the CPU via `np.isin`, so no buffer of
        # size `population` is allocated.
        known = idx.to('cpu')
        for _ in range(3):  # At most three sampling rounds.
            draw = sample(population, sample_size, device='cpu')
            reject = np.isin(draw, known)
            if chunks:
                # Also reject slots already collected in earlier rounds.
                reject |= np.isin(draw, torch.cat(chunks).to('cpu'))
            keep = ~torch.from_numpy(reject).to(torch.bool)
            draw = draw[keep].to(edge_index.device)
            chunks.append(draw)
            total += draw.numel()
            if total >= num_neg_samples:
                break

    # Slicing is a no-op when fewer than `num_neg_samples` were found.
    neg_idx = torch.cat(chunks)[:num_neg_samples]
    return vector_to_edge_index(neg_idx, size, bipartite, force_undirected)
def batched_negative_sampling(
    edge_index: Tensor,
    batch: Union[Tensor, Tuple[Tensor, Tensor]],
    num_neg_samples: Optional[int] = None,
    method: str = "sparse",
    force_undirected: bool = False,
) -> Tensor:
    r"""Samples random negative edges of multiple graphs given by
    :attr:`edge_index` and :attr:`batch`.

    The edges are split into per-example chunks by counting how many edges
    each example owns (keyed by the batch id of the source node). Each chunk
    is shifted to local node ids, sampled with :func:`negative_sampling`, and
    shifted back. Because the split is contiguous, :obj:`edge_index` is
    expected to list the edges of example 0 first, then example 1, and so
    on. :obj:`num_neg_samples` is applied per example.

    Args:
        edge_index (LongTensor): The edge indices.
        batch (LongTensor or Tuple[LongTensor, LongTensor]): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example.
            If given as a tuple, then :obj:`edge_index` is interpreted as a
            bipartite graph connecting two different node types.
        num_neg_samples (int, optional): The number of negative samples to
            return. If set to :obj:`None`, will try to return a negative edge
            for every positive edge. (default: :obj:`None`)
        method (string, optional): The method to use for negative sampling,
            *i.e.*, :obj:`"sparse"` or :obj:`"dense"`.
            This is a memory/runtime trade-off.
            :obj:`"sparse"` will work on any graph of any size, while
            :obj:`"dense"` can perform faster true-negative checks.
            (default: :obj:`"sparse"`)
        force_undirected (bool, optional): If set to :obj:`True`, sampled
            negative edges will be undirected. (default: :obj:`False`)

    :rtype: LongTensor
    """
    bipartite = not isinstance(batch, Tensor)
    if bipartite:
        src_batch, dst_batch = batch[0], batch[1]
    else:
        src_batch = dst_batch = batch

    def _count_and_start(assignment: Tensor) -> Tuple[Tensor, Tensor]:
        # Nodes per example, and the global id of each example's first node.
        counts = degree(assignment, dtype=torch.long)
        starts = torch.cat([assignment.new_zeros(1), counts.cumsum(0)[:-1]])
        return counts, starts

    # Number of edges per example, keyed by the source node's example id.
    edges_per_graph = degree(src_batch[edge_index[0]], dtype=torch.long)
    sub_graphs = torch.split(edge_index, edges_per_graph.tolist(), dim=1)

    num_src, start_src = _count_and_start(src_batch)
    if bipartite:
        num_dst, start_dst = _count_and_start(dst_batch)
        # Per example: [num_src_nodes, num_dst_nodes] (a list => bipartite).
        sizes = torch.stack([num_src, num_dst], dim=1).tolist()
        # Shape (B, 2, 1): each row broadcasts over a (2, E) edge index.
        shifts = torch.stack([start_src, start_dst], dim=1).unsqueeze(-1)
    else:
        sizes = num_src.tolist()
        shifts = start_src

    negatives = []
    for i, sub in enumerate(sub_graphs):
        shift = shifts[i]
        local_neg = negative_sampling(sub - shift, sizes[i], num_neg_samples,
                                      method, force_undirected)
        local_neg += shift  # Back to global node ids.
        negatives.append(local_neg)
    return torch.cat(negatives, dim=1)
def structured_negative_sampling(edge_index, num_nodes: Optional[int] = None,
                                 contains_neg_self_loops: bool = True):
    r"""Samples a negative edge :obj:`(i,k)` for every positive edge
    :obj:`(i,j)` in the graph given by :attr:`edge_index`, and returns it as a
    tuple of the form :obj:`(i,j,k)`.

    Sampling and the rejection checks run on the CPU; the sampled :obj:`k`
    is moved back to :obj:`edge_index`'s device. Rejected positions are
    resampled until none remain. If some node is adjacent to every candidate
    target, this loop does not terminate. Check this beforehand with
    :meth:`~torch_geometric.utils.structured_negative_sampling_feasible`.

    Args:
        edge_index (LongTensor): The edge indices.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        contains_neg_self_loops (bool, optional): If set to
            :obj:`False`, sampled negative edges will not contain self loops.
            (default: :obj:`True`)

    :rtype: (LongTensor, LongTensor, LongTensor)
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    src, dst = edge_index.cpu()
    # Encode each pair (i, j) as the single id i * N + j.
    forbidden = src * num_nodes + dst
    if not contains_neg_self_loops:
        # Diagonal ids: i * N + i == i * (N + 1).
        self_loops = torch.arange(num_nodes) * (num_nodes + 1)
        forbidden = torch.cat([forbidden, self_loops], dim=0)

    def collides(flat: Tensor) -> Tensor:
        # True where the encoded pair is a positive (or forbidden) edge.
        return torch.from_numpy(np.isin(flat, forbidden)).to(torch.bool)

    rand = torch.randint(num_nodes, (src.size(0), ), dtype=torch.long)
    retry = collides(src * num_nodes + rand).nonzero(as_tuple=False).view(-1)
    while retry.numel() > 0:  # pragma: no cover
        fresh = torch.randint(num_nodes, (retry.size(0), ), dtype=torch.long)
        rand[retry] = fresh
        # Keep only the positions whose new draw still collides.
        retry = retry[collides(src[retry] * num_nodes + fresh)]

    return edge_index[0], edge_index[1], rand.to(edge_index.device)
def structured_negative_sampling_feasible(
        edge_index: Tensor, num_nodes: Optional[int] = None,
        contains_neg_self_loops: bool = True) -> bool:
    r"""Returns :obj:`True` if
    :meth:`~torch_geometric.utils.structured_negative_sampling` is feasible
    on the graph given by :obj:`edge_index`.
    :obj:`~torch_geometric.utils.structured_negative_sampling` is infeasible
    if at least one node is connected to all other nodes.

    Args:
        edge_index (LongTensor): The edge indices.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        contains_neg_self_loops (bool, optional): If set to
            :obj:`False`, sampled negative edges will not contain self loops.
            (default: :obj:`True`)

    :rtype: bool
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    # Presumably merges duplicate edges so they are not double-counted in
    # the degree below; confirm against `coalesce`'s documentation.
    edge_index = coalesce(edge_index, num_nodes=num_nodes)

    # How many distinct targets a source node may have before no negative
    # target is left for it.
    candidates = num_nodes
    if not contains_neg_self_loops:
        edge_index, _ = remove_self_loops(edge_index)
        candidates = num_nodes - 1  # The node itself is not a valid target.

    out_deg = degree(edge_index[0], num_nodes)
    # Feasible iff no node is saturated.
    return bool((out_deg < candidates).all())
###############################################################################


def sample(population: int, k: int, device=None) -> Tensor:
    """Draw ``k`` distinct integers from ``range(population)``.

    If ``k`` is at least ``population``, all values ``0..population-1`` are
    returned in ascending order. Randomness comes from Python's ``random``
    module (not from ``torch``'s generator). The result is always an int64
    tensor, so it can be used for indexing even when it is empty.
    """
    if population <= k:
        return torch.arange(population, device=device)
    # `dtype` must be explicit: `torch.tensor([])` (k == 0) would otherwise
    # default to float32, which breaks boolean-mask indexing in the dense
    # path and leaks a floating-point dtype into the returned edge indices.
    return torch.tensor(random.sample(range(population), k),
                        dtype=torch.long, device=device)


def edge_index_to_vector(
    edge_index: Tensor,
    size: Tuple[int, int],
    bipartite: bool,
    force_undirected: bool = False,
) -> Tuple[Tensor, int]:
    """Encode every edge as a unique slot id in a dense candidate space.

    Returns ``(idx, population)``: ``population`` is the number of candidate
    slots and ``idx`` holds the slot of every kept edge.

    * ``bipartite``: slot ``row * size[1] + col`` over all
      ``size[0] * size[1]`` pairs.
    * ``force_undirected``: only pairs with ``row < col`` are kept and mapped
      onto the strict upper triangle (``N * (N - 1) / 2`` slots).
    * otherwise: self-loops are dropped and ``(row, col)`` is mapped onto the
      ``N * (N - 1)`` off-diagonal slots.

    ``edge_index`` itself is not modified. ``vector_to_edge_index`` performs
    the inverse mapping.
    """
    row, col = edge_index

    if bipartite:  # No need to account for self-loops.
        idx = (row * size[1]).add_(col)
        population = size[0] * size[1]
        return idx, population

    elif force_undirected:
        assert size[0] == size[1]
        num_nodes = size[0]

        # We only operate on the upper triangular matrix:
        mask = row < col
        row, col = row[mask], col[mask]
        # offset[r] == (r + 1) * (r + 2) / 2 == number of slots on or below
        # the diagonal in rows 0..r; subtracting it packs rows contiguously.
        offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row]
        idx = row.mul_(num_nodes).add_(col).sub_(offset)
        population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes
        return idx, population

    else:
        assert size[0] == size[1]
        num_nodes = size[0]

        # We remove self-loops as we do not want to take them into account
        # when sampling negative values.
        mask = row != col
        row, col = row[mask], col[mask]
        # Skip the diagonal: columns right of it shift left by one.
        col[row < col] -= 1
        idx = row.mul_(num_nodes - 1).add_(col)
        population = num_nodes * num_nodes - num_nodes
        return idx, population


def vector_to_edge_index(idx: Tensor, size: Tuple[int, int], bipartite: bool,
                         force_undirected: bool = False) -> Tensor:
    """Decode slot ids produced by ``edge_index_to_vector`` into edges.

    Returns a ``(2, E)`` tensor. With ``force_undirected`` each decoded pair
    is emitted in both directions, so ``E == 2 * idx.numel()``.
    """
    if bipartite:  # No need to account for self-loops.
        row = idx.div(size[1], rounding_mode='floor')
        col = idx % size[1]
        return torch.stack([row, col], dim=0)

    elif force_undirected:
        assert size[0] == size[1]
        num_nodes = size[0]

        offset = torch.arange(1, num_nodes, device=idx.device).cumsum(0)
        end = torch.arange(num_nodes, num_nodes * num_nodes, num_nodes,
                           device=idx.device)
        # `end - offset` is the first slot id of rows 1..N-1; bucketize finds
        # the row each slot falls into.
        row = torch.bucketize(idx, end.sub_(offset), right=True)
        col = offset[row].add_(idx) % num_nodes
        return torch.stack([torch.cat([row, col]), torch.cat([col, row])], 0)

    else:
        assert size[0] == size[1]
        num_nodes = size[0]

        row = idx.div(num_nodes - 1, rounding_mode='floor')
        col = idx % (num_nodes - 1)
        # Re-insert the skipped diagonal column.
        col[row <= col] += 1
        return torch.stack([row, col], dim=0)