import random
from typing import Optional, Tuple, Union
import numpy as np
import torch
from torch import Tensor
from torch_geometric.utils import coalesce, degree, remove_self_loops
from .num_nodes import maybe_num_nodes
def negative_sampling(edge_index: Tensor,
                      num_nodes: Optional[Union[int, Tuple[int, int]]] = None,
                      num_neg_samples: Optional[int] = None,
                      method: str = "sparse",
                      force_undirected: bool = False) -> Tensor:
    r"""Draws random edges that do not occur in :attr:`edge_index`.

    Args:
        edge_index (LongTensor): The edge indices.
        num_nodes (int or Tuple[int, int], optional): The number of nodes,
            *i.e.* :obj:`max_val + 1` of :attr:`edge_index`.
            A tuple (or list) switches to bipartite mode, in which
            :obj:`edge_index` is read as a graph of shape
            :obj:`(num_src_nodes, num_dst_nodes)`.
            (default: :obj:`None`)
        num_neg_samples (int, optional): The (approximate) number of negative
            samples to return.
            With :obj:`None`, one negative edge per positive edge is
            requested. (default: :obj:`None`)
        method (string, optional): Either :obj:`"sparse"` or :obj:`"dense"`.
            :obj:`"dense"` allocates a boolean mask over every candidate
            edge, :obj:`"sparse"` filters candidates with :obj:`np.isin`
            on the CPU instead. (default: :obj:`"sparse"`)
        force_undirected (bool, optional): If set to :obj:`True`, sampled
            negative edges will be undirected. Ignored in bipartite mode.
            (default: :obj:`False`)

    :rtype: LongTensor
    """
    assert method in ['sparse', 'dense']

    bipartite = isinstance(num_nodes, (tuple, list))
    if num_nodes is None:
        size = maybe_num_nodes(edge_index)
    else:
        size = num_nodes
    if bipartite:
        # Undirected sampling is not applied to bipartite graphs.
        force_undirected = False
    else:
        size = (size, size)

    pos_idx, population = edge_index_to_vector(edge_index, size, bipartite,
                                               force_undirected)
    num_pos = pos_idx.numel()

    # Every candidate slot is counted as occupied: nothing to sample.
    if num_pos >= population:
        return edge_index.new_empty((2, 0))

    if num_neg_samples is None:
        num_neg_samples = edge_index.size(1)
    if force_undirected:
        # Each undirected sample is later emitted in both directions.
        num_neg_samples = num_neg_samples // 2

    # Chance that a uniformly drawn candidate is a negative edge.
    neg_prob = 1. - num_pos / population
    # Draw ~10% more than the expected requirement to absorb rejections.
    sample_size = int(1.1 * num_neg_samples / neg_prob)

    max_tries = 3  # Number of rounds used to collect negative indices.
    neg_idx = None

    if method == 'dense':
        # `is_free[i]` is True while candidate `i` is neither a positive edge
        # nor an already collected negative.
        is_free = pos_idx.new_ones(population, dtype=torch.bool)
        is_free[pos_idx] = False
        for _ in range(max_tries):
            cand = sample(population, sample_size, pos_idx.device)
            cand = cand[is_free[cand]]
            if neg_idx is None:
                neg_idx = cand
            else:
                neg_idx = torch.cat([neg_idx, cand])
            if neg_idx.numel() >= num_neg_samples:
                neg_idx = neg_idx[:num_neg_samples]
                break
            is_free[neg_idx] = False
    else:  # 'sparse'
        # Membership tests run through `np.isin`, hence everything on CPU.
        pos_idx = pos_idx.to('cpu')
        for _ in range(max_tries):
            cand = sample(population, sample_size, device='cpu')
            rejected = np.isin(cand, pos_idx)
            if neg_idx is not None:
                rejected |= np.isin(cand, neg_idx.to('cpu'))
            rejected = torch.from_numpy(rejected).to(torch.bool)
            cand = cand[~rejected].to(edge_index.device)
            if neg_idx is None:
                neg_idx = cand
            else:
                neg_idx = torch.cat([neg_idx, cand])
            if neg_idx.numel() >= num_neg_samples:
                neg_idx = neg_idx[:num_neg_samples]
                break

    return vector_to_edge_index(neg_idx, size, bipartite, force_undirected)
def batched_negative_sampling(
    edge_index: Tensor,
    batch: Union[Tensor, Tuple[Tensor, Tensor]],
    num_neg_samples: Optional[int] = None,
    method: str = "sparse",
    force_undirected: bool = False,
) -> Tensor:
    r"""Runs :meth:`negative_sampling` separately on every graph of a batch
    described by :attr:`edge_index` and :attr:`batch`, and concatenates the
    results.

    Args:
        edge_index (LongTensor): The edge indices.
        batch (LongTensor or Tuple[LongTensor, LongTensor]): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example.
            If given as a tuple, then :obj:`edge_index` is interpreted as a
            bipartite graph connecting two different node types.
        num_neg_samples (int, optional): The number of negative samples to
            request for each graph. With :obj:`None`, one negative edge per
            positive edge is requested. (default: :obj:`None`)
        method (string, optional): Either :obj:`"sparse"` or :obj:`"dense"`;
            forwarded to :meth:`negative_sampling`.
            (default: :obj:`"sparse"`)
        force_undirected (bool, optional): If set to :obj:`True`, sampled
            negative edges will be undirected. (default: :obj:`False`)

    :rtype: LongTensor
    """
    bipartite = not isinstance(batch, Tensor)
    if bipartite:
        src_batch, dst_batch = batch[0], batch[1]
    else:
        src_batch = dst_batch = batch

    # Number of edges per graph, used to cut `edge_index` into per-graph
    # column chunks.
    edges_per_graph = degree(src_batch[edge_index[0]], dtype=torch.long)
    per_graph_edges = torch.split(edge_index, edges_per_graph.tolist(), dim=1)

    # Node count per graph and the index of each graph's first node.
    num_src = degree(src_batch, dtype=torch.long)
    cum_src = torch.cat([src_batch.new_zeros(1), num_src.cumsum(0)[:-1]])

    if bipartite:
        num_dst = degree(dst_batch, dtype=torch.long)
        cum_dst = torch.cat([dst_batch.new_zeros(1), num_dst.cumsum(0)[:-1]])
        num_nodes = torch.stack([num_src, num_dst], dim=1).tolist()
        # Trailing singleton dim lets `offsets[i]` broadcast over the columns
        # of a `[2, E]` edge index (row 0: source, row 1: destination).
        offsets = torch.stack([cum_src, cum_dst], dim=1).unsqueeze(-1)
    else:
        num_nodes = num_src.tolist()
        offsets = cum_src

    neg_edge_indices = []
    for graph_id, pos_edges in enumerate(per_graph_edges):
        offset = offsets[graph_id]
        # Shift to graph-local node ids, sample, then shift back.
        neg_edges = negative_sampling(pos_edges - offset, num_nodes[graph_id],
                                      num_neg_samples, method,
                                      force_undirected)
        neg_edges += offset
        neg_edge_indices.append(neg_edges)

    return torch.cat(neg_edge_indices, dim=1)
def structured_negative_sampling(edge_index, num_nodes: Optional[int] = None,
                                 contains_neg_self_loops: bool = True):
    r"""For every positive edge :obj:`(i,j)` of the graph given by
    :attr:`edge_index`, draws a node :obj:`k` such that :obj:`(i,k)` is a
    negative edge, and returns the result as a tuple :obj:`(i,j,k)`.

    Args:
        edge_index (LongTensor): The edge indices.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        contains_neg_self_loops (bool, optional): If set to
            :obj:`False`, sampled negative edges will not contain self loops.
            (default: :obj:`True`)

    :rtype: (LongTensor, LongTensor, LongTensor)
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    src, dst = edge_index.cpu()

    # Flattened `(src, dst)` pairs that must not be produced as negatives.
    forbidden = src * num_nodes + dst
    if not contains_neg_self_loops:
        # `n * (num_nodes + 1)` is the flattened index of self-loop `(n, n)`.
        self_loops = torch.arange(num_nodes) * (num_nodes + 1)
        forbidden = torch.cat([forbidden, self_loops], dim=0)

    def hits_forbidden(flat_idx):
        # Boolean tensor flagging the entries of `flat_idx` that collide.
        return torch.from_numpy(np.isin(flat_idx, forbidden)).to(torch.bool)

    rand = torch.randint(num_nodes, (src.size(0), ), dtype=torch.long)
    collided = hits_forbidden(src * num_nodes + rand)
    pending = collided.nonzero(as_tuple=False).view(-1)

    # Redraw only the colliding positions until none is left. The loop keeps
    # running for as long as any position still collides.
    while pending.numel() > 0:  # pragma: no cover
        redraw = torch.randint(num_nodes, (pending.size(0), ),
                               dtype=torch.long)
        rand[pending] = redraw
        collided = hits_forbidden(src[pending] * num_nodes + redraw)
        pending = pending[collided]

    return edge_index[0], edge_index[1], rand.to(edge_index.device)
def structured_negative_sampling_feasible(
        edge_index: Tensor, num_nodes: Optional[int] = None,
        contains_neg_self_loops: bool = True) -> bool:
    r"""Returns :obj:`True` if
    :meth:`~torch_geometric.utils.structured_negative_sampling` is feasible
    on the graph given by :obj:`edge_index`.
    :obj:`~torch_geometric.utils.structured_negative_sampling` is infeasible
    if at least one node is connected to all other nodes.

    Args:
        edge_index (LongTensor): The edge indices.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        contains_neg_self_loops (bool, optional): If set to
            :obj:`False`, sampled negative edges will not contain self loops.
            (default: :obj:`True`)

    :rtype: bool
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    edge_index = coalesce(edge_index, num_nodes=num_nodes)

    if contains_neg_self_loops:
        # Any of the `num_nodes` nodes may serve as a negative target.
        max_num_neighbors = num_nodes
    else:
        # A node can no longer be paired with itself: one target fewer.
        edge_index, _ = remove_self_loops(edge_index)
        max_num_neighbors = num_nodes - 1

    out_degree = degree(edge_index[0], num_nodes)

    # Feasible only if every source node has at least one free target left.
    return bool((out_degree < max_num_neighbors).all())
###############################################################################
def sample(population: int, k: int, device=None) -> Tensor:
    """Draw up to ``k`` distinct integers from ``range(population)``.

    Args:
        population: Size of the index range to draw from.
        k: Number of indices requested.
        device: Device on which the returned tensor is created.

    Returns:
        A 1-D ``torch.long`` tensor. If ``population <= k`` every index
        ``0 .. population - 1`` is returned in ascending order; otherwise
        ``k`` distinct indices are returned in random order.
    """
    if population <= k:
        return torch.arange(population, device=device)

    picks = random.sample(range(population), k)
    # An explicit dtype is required: for `k == 0` the list is empty and
    # `torch.tensor([])` would otherwise default to a floating-point tensor,
    # which cannot be used for indexing by the callers.
    return torch.tensor(picks, dtype=torch.long, device=device)
def edge_index_to_vector(
    edge_index: Tensor,
    size: Tuple[int, int],
    bipartite: bool,
    force_undirected: bool = False,
) -> Tuple[Tensor, int]:
    """Flatten each edge ``(row, col)`` into a single integer index.

    Three encodings are used, depending on the mode:

    * bipartite: ``row * size[1] + col`` over all ``size[0] * size[1]`` pairs;
    * undirected: only edges with ``row < col`` are kept and numbered
      consecutively over the strict upper triangle;
    * directed: self-loops are dropped and the remaining off-diagonal pairs
      are numbered consecutively, ``num_nodes - 1`` per row.

    Args:
        edge_index: Edge indices; unpacked into a row and a column vector.
        size: ``(num_src_nodes, num_dst_nodes)``; both entries must match
            unless ``bipartite`` is set.
        bipartite: Whether to use the bipartite encoding.
        force_undirected: Whether to use the undirected encoding (only
            consulted when ``bipartite`` is false).

    Returns:
        A tuple ``(idx, population)`` holding the flattened indices of the
        retained edges and the total number of distinct indices the chosen
        encoding can produce.
    """
    src, dst = edge_index

    if bipartite:  # No need to account for self-loops.
        num_dst = size[1]
        flat = (src * num_dst).add_(dst)
        return flat, size[0] * num_dst

    assert size[0] == size[1]
    num_nodes = size[0]

    if force_undirected:
        # Work on the strict upper triangle only.
        upper = src < dst
        src, dst = src[upper], dst[upper]
        # Number of non-upper-triangular slots in rows `0..r`, i.e. how far
        # the dense index `r * num_nodes + c` has to be shifted down.
        shift = torch.arange(1, num_nodes, device=src.device).cumsum(0)[src]
        flat = src.mul_(num_nodes).add_(dst).sub_(shift)
        return flat, (num_nodes * (num_nodes + 1)) // 2 - num_nodes

    # Self-loops are never candidates for negative edges, so drop them and
    # close the gap each row has on the diagonal.
    off_diagonal = src != dst
    src, dst = src[off_diagonal], dst[off_diagonal]
    dst[src < dst] -= 1
    flat = src.mul_(num_nodes - 1).add_(dst)
    return flat, num_nodes * num_nodes - num_nodes
def vector_to_edge_index(idx: Tensor, size: Tuple[int, int], bipartite: bool,
                         force_undirected: bool = False) -> Tensor:
    """Decode flattened edge indices back into a two-row edge index.

    This is the inverse of the encodings used by ``edge_index_to_vector``.

    Args:
        idx: Flattened edge indices.
        size: ``(num_src_nodes, num_dst_nodes)``; both entries must match
            unless ``bipartite`` is set.
        bipartite: Whether ``idx`` uses the bipartite encoding.
        force_undirected: Whether ``idx`` uses the undirected encoding (only
            consulted when ``bipartite`` is false).

    Returns:
        A tensor stacking the row and column vectors along dim 0. In the
        undirected mode every decoded edge is emitted in both directions, so
        the result has twice as many columns as ``idx`` has entries.
    """
    if bipartite:  # No need to account for self-loops.
        num_dst = size[1]
        src = idx.div(num_dst, rounding_mode='floor')
        dst = idx % num_dst
        return torch.stack([src, dst], dim=0)

    assert size[0] == size[1]
    num_nodes = size[0]

    if force_undirected:
        shift = torch.arange(1, num_nodes, device=idx.device).cumsum(0)
        row_end = torch.arange(num_nodes, num_nodes * num_nodes, num_nodes,
                               device=idx.device)
        # `row_end - shift` holds, per row, the first flattened index that
        # belongs to the next row; bucketizing therefore yields the row.
        src = torch.bucketize(idx, row_end.sub_(shift), right=True)
        dst = shift[src].add_(idx) % num_nodes
        forward = torch.cat([src, dst])
        backward = torch.cat([dst, src])
        return torch.stack([forward, backward], 0)

    src = idx.div(num_nodes - 1, rounding_mode='floor')
    dst = idx % (num_nodes - 1)
    # Re-open the diagonal gap that was closed when encoding.
    dst[src <= dst] += 1
    return torch.stack([src, dst], dim=0)