# Source code for torch_geometric.utils.negative_sampling

import random

import torch
import numpy as np
from torch_geometric.utils import degree, to_undirected

from .num_nodes import maybe_num_nodes


def sample(high: int, size: int, device=None) -> torch.Tensor:
    """Draw distinct integers from ``range(high)`` without replacement.

    Args:
        high (int): Exclusive upper bound of the population ``0..high-1``.
        size (int): Requested number of draws. It is capped at ``high``, so
            fewer values are returned when the population is too small.
        device (optional): Device on which the resulting tensor is created.

    Returns:
        A 1-D ``torch.long`` tensor holding ``min(high, size)`` distinct
        values in random order.

    Raises:
        ValueError: Propagated from :func:`random.sample` when the capped
            ``size`` is negative.
    """
    size = min(high, size)
    # The dtype is pinned explicitly: ``torch.tensor([])`` would otherwise be
    # inferred as float32, and an empty float tensor cannot be used to index
    # other tensors.
    return torch.tensor(random.sample(range(high), size), dtype=torch.long,
                        device=device)


def negative_sampling(edge_index, num_nodes=None, num_neg_samples=None,
                      method="sparse", force_undirected=False):
    r"""Samples random negative edges of a graph given by
    :attr:`edge_index`.

    Args:
        edge_index (LongTensor): The edge indices.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        num_neg_samples (int, optional): The (approximate) number of negative
            samples to return. If set to :obj:`None`, will try to return a
            negative edge for every positive edge. (default: :obj:`None`)
        method (string, optional): The method to use for negative sampling,
            *i.e.*, :obj:`"sparse"` or :obj:`"dense"`.
            This is a memory/runtime trade-off.
            :obj:`"sparse"` will work on any graph of any size, while
            :obj:`"dense"` can perform faster true-negative checks.
            (default: :obj:`"sparse"`)
        force_undirected (bool, optional): If set to :obj:`True`, sampled
            negative edges will be undirected. (default: :obj:`False`)

    Returns:
        A ``[2, num_sampled]`` LongTensor of negative edges. Fewer edges than
        requested may be returned, and the result is empty when no negative
        edge can be sampled (e.g. empty or complete graphs).

    :rtype: LongTensor
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    num_edges = edge_index.size(1)
    num_neg_samples = num_neg_samples or num_edges

    # Handle '|V|^2 - |E| < |E|'. Clamp at zero so that duplicated edges can
    # never yield a negative sample count (which would mis-slice below).
    size = num_nodes * num_nodes
    num_neg_samples = max(min(num_neg_samples, size - num_edges), 0)

    row, col = edge_index

    if force_undirected:
        num_neg_samples = num_neg_samples // 2

        # Upper triangle indices: N + ... + 1 = N (N + 1) / 2
        size = (num_nodes * (num_nodes + 1)) // 2

        # Remove edges in the lower triangle matrix.
        mask = row <= col
        row, col = row[mask], col[mask]

        # idx = N * i + j - i * (i+1) / 2
        idx = row * num_nodes + col - row * (row + 1) // 2
    else:
        idx = row * num_nodes + col

    # Nothing to sample from (empty graph) or nothing requested/left.
    if size == 0 or num_neg_samples == 0:
        return edge_index.new_empty((2, 0), dtype=torch.long)

    # Percentage of edges to oversample so that we are safe to only sample
    # once (in most cases). For very dense graphs the closed-form factor
    # 1 / (1 - 1.1 * density) turns infinite or negative, so fall back to
    # drawing the whole population instead of passing an invalid count on.
    headroom = 1 - 1.1 * (num_edges / size)
    if headroom > 0:
        alpha = 1 / headroom
        num_candidates = int(alpha * num_neg_samples)
    else:
        num_candidates = size

    if method == 'dense':
        mask = edge_index.new_ones(size, dtype=torch.bool)
        mask[idx] = False
        mask = mask.view(-1)

        perm = sample(size, num_candidates, device=edge_index.device)
        perm = perm[mask[perm]][:num_neg_samples]
    else:
        perm = sample(size, num_candidates)
        mask = torch.from_numpy(np.isin(perm, idx.to('cpu'))).to(torch.bool)
        perm = perm[~mask][:num_neg_samples].to(edge_index.device)

    if force_undirected:
        # row = floor((2 * N + 1 - sqrt((2 * N + 1)^2 - 8 * perm)) / 2)
        # The discriminant is built with integer arithmetic and its root is
        # taken in float64: float32 cannot represent values above 2^24
        # exactly, which would decode wrong rows/columns on larger graphs.
        disc = (2 * num_nodes + 1)**2 - 8 * perm
        root = torch.sqrt(disc.to(torch.double))
        row = torch.floor((2 * num_nodes + 1 - root) / 2).long()
        # row * (2 * N - row - 1) is always even, so this division is exact.
        col = perm - row * (2 * num_nodes - row - 1) // 2
        neg_edge_index = torch.stack([row, col], dim=0).long()
        neg_edge_index = to_undirected(neg_edge_index)
    else:
        row = perm // num_nodes
        col = perm % num_nodes
        neg_edge_index = torch.stack([row, col], dim=0).long()

    return neg_edge_index
def structured_negative_sampling(edge_index, num_nodes=None):
    r"""Samples a negative edge :obj:`(i,k)` for every positive edge
    :obj:`(i,j)` in the graph given by :attr:`edge_index`, and returns it as a
    tuple of the form :obj:`(i,j,k)`.

    Args:
        edge_index (LongTensor): The edge indices.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    Raises:
        ValueError: If some source node has an edge to every node of the
            graph (including itself), since no negative target exists for it.

    :rtype: (LongTensor, LongTensor, LongTensor)
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    i, j = edge_index.to('cpu')
    if i.numel() == 0:
        # No positive edges, hence nothing to pair a negative target with.
        empty = edge_index.new_empty((0, ), dtype=torch.long)
        return edge_index[0], edge_index[1], empty

    # Each edge (i, j) is encoded as the single key i * N + j.
    idx_1 = i * num_nodes + j

    # Rejection sampling below can only terminate if every source node has at
    # least one non-neighbor; otherwise it would loop forever.
    out_degree = torch.bincount(torch.unique(idx_1) // num_nodes,
                                minlength=num_nodes)
    if int(out_degree.max()) >= num_nodes:
        raise ValueError("structured negative sampling is infeasible: at "
                         "least one node is connected to every node")

    k = torch.randint(num_nodes, (i.size(0), ), dtype=torch.long)
    idx_2 = i * num_nodes + k

    mask = torch.from_numpy(np.isin(idx_2, idx_1)).to(torch.bool)
    rest = mask.nonzero(as_tuple=False).view(-1)
    # Re-draw targets only for the positions that collided with a real edge.
    while rest.numel() > 0:  # pragma: no cover
        tmp = torch.randint(num_nodes, (rest.numel(), ), dtype=torch.long)
        idx_2 = i[rest] * num_nodes + tmp
        mask = torch.from_numpy(np.isin(idx_2, idx_1)).to(torch.bool)
        k[rest] = tmp
        rest = rest[mask.nonzero(as_tuple=False).view(-1)]

    return edge_index[0], edge_index[1], k.to(edge_index.device)
def batched_negative_sampling(edge_index, batch, num_neg_samples=None,
                              method="sparse", force_undirected=False):
    r"""Samples random negative edges of multiple graphs given by
    :attr:`edge_index` and :attr:`batch`.

    The edges are cut into consecutive per-graph chunks, each chunk is
    shifted to local node indices, handed to :func:`negative_sampling`, and
    the sampled edges are shifted back before being concatenated. This
    presumes that the columns of :attr:`edge_index` are grouped by graph in
    ascending graph order.

    Args:
        edge_index (LongTensor): The edge indices.
        batch (LongTensor): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example.
        num_neg_samples (int, optional): The number of negative samples to
            return. If set to :obj:`None`, will try to return a negative edge
            for every positive edge. (default: :obj:`None`)
        method (string, optional): The method to use for negative sampling,
            *i.e.*, :obj:`"sparse"` or :obj:`"dense"`.
            This is a memory/runtime trade-off.
            :obj:`"sparse"` will work on any graph of any size, while
            :obj:`"dense"` can perform faster true-negative checks.
            (default: :obj:`"sparse"`)
        force_undirected (bool, optional): If set to :obj:`True`, sampled
            negative edges will be undirected. (default: :obj:`False`)

    :rtype: LongTensor
    """
    # Number of edges owned by each graph, looked up through the source node.
    edges_per_graph = degree(batch[edge_index[0]], dtype=torch.long).tolist()
    per_graph_edges = torch.split(edge_index, edges_per_graph, dim=1)

    # Node count of each graph and the index of each graph's first node.
    nodes_per_graph = degree(batch, dtype=torch.long)
    first_node = torch.cat(
        [batch.new_zeros(1), nodes_per_graph.cumsum(dim=0)[:-1]])

    sampled = [
        negative_sampling(graph_edges - offset, graph_size, num_neg_samples,
                          method, force_undirected) + offset
        for graph_edges, graph_size, offset in zip(
            per_graph_edges, nodes_per_graph.tolist(), first_node.tolist())
    ]

    return torch.cat(sampled, dim=1)