# Source code for torch_geometric.sampler.neighbor_sampler

import copy
import math
import sys
import warnings
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
from torch import Tensor

import torch_geometric.typing
from torch_geometric.data import (
    Data,
    FeatureStore,
    GraphStore,
    HeteroData,
    remote_backend_utils,
)
from torch_geometric.data.graph_store import EdgeLayout
from torch_geometric.sampler import (
    BaseSampler,
    EdgeSamplerInput,
    HeteroSamplerOutput,
    NegativeSampling,
    NodeSamplerInput,
    SamplerOutput,
)
from torch_geometric.sampler.base import DataType, NumNeighbors, SubgraphType
from torch_geometric.sampler.utils import remap_keys, to_csc, to_hetero_csc
from torch_geometric.typing import EdgeType, NodeType, OptTensor

# All forms accepted for the `num_neighbors` argument of `NeighborSampler`:
# a ready-made `NumNeighbors` object, a plain list of ints, or a dictionary
# mapping each edge type to such a list.
NumNeighborsType = Union[
    NumNeighbors,
    List[int],
    Dict[EdgeType, List[int]],
]


class NeighborSampler(BaseSampler):
    r"""An implementation of an in-memory (heterogeneous) neighbor sampler used
    by :class:`~torch_geometric.loader.NeighborLoader`.

    Args:
        data (Data or HeteroData or (FeatureStore, GraphStore)): The graph to
            sample from. In-memory graphs are converted to CSC format on the
            CPU; for remote backends the CSC layout is requested from the
            :class:`GraphStore`.
        num_neighbors (NumNeighbors or List[int] or Dict[EdgeType, List[int]]):
            The number of neighbors to sample, wrapped into
            :class:`NumNeighbors` if it is not one already.
        subgraph_type (SubgraphType or str, optional): One of
            :obj:`"directional"`, :obj:`"bidirectional"` or :obj:`"induced"`.
            Outputs of :obj:`"bidirectional"` sampling are converted via
            :meth:`to_bidirectional`. (default: :obj:`"directional"`)
        replace (bool, optional): Forwarded to the sampling backend as its
            :obj:`replace` flag. (default: :obj:`False`)
        disjoint (bool, optional): Forwarded to the sampling backend as its
            :obj:`disjoint` flag. Temporal sampling always behaves as
            disjoint. (default: :obj:`False`)
        temporal_strategy (str, optional): Forwarded to :obj:`pyg-lib`.
            (default: :obj:`"uniform"`)
        time_attr (str, optional): Name of a node-level or edge-level time
            attribute that enables temporal sampling. (default: :obj:`None`)
        weight_attr (str, optional): Name of an edge-level weight attribute
            used for weighted sampling. Not supported for
            :obj:`(FeatureStore, GraphStore)` inputs. (default: :obj:`None`)
        is_sorted (bool, optional): Forwarded to the CSC conversion of
            in-memory graphs. (default: :obj:`False`)
        share_memory (bool, optional): Forwarded to the CSC conversion of
            in-memory graphs. (default: :obj:`False`)
        directed (bool, optional): Deprecated. Passing :obj:`False` is
            equivalent to :obj:`subgraph_type="induced"`.
            (default: :obj:`True`)
    """
    def __init__(
        self,
        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
        num_neighbors: NumNeighborsType,
        subgraph_type: Union[SubgraphType, str] = 'directional',
        replace: bool = False,
        disjoint: bool = False,
        temporal_strategy: str = 'uniform',
        time_attr: Optional[str] = None,
        weight_attr: Optional[str] = None,
        is_sorted: bool = False,
        share_memory: bool = False,
        # Deprecated:
        directed: bool = True,
    ):
        if not directed:
            subgraph_type = SubgraphType.induced
            warnings.warn(f"The usage of the 'directed' argument in "
                          f"'{self.__class__.__name__}' is deprecated. Use "
                          f"`subgraph_type='induced'` instead.")

        # Normalize strings such as 'induced' to their enum member up front:
        # `SubgraphType` members do not compare equal to plain strings, so
        # without this the check below would warn for `'induced'` too.
        subgraph_type = SubgraphType(subgraph_type)

        if (not torch_geometric.typing.WITH_PYG_LIB and sys.platform == 'linux'
                and subgraph_type != SubgraphType.induced):
            warnings.warn(f"Using '{self.__class__.__name__}' without a "
                          f"'pyg-lib' installation is deprecated and will be "
                          f"removed soon. Please install 'pyg-lib' for "
                          f"accelerated neighborhood sampling")

        self.data_type = DataType.from_data(data)

        if self.data_type == DataType.homogeneous:
            self.num_nodes = data.num_nodes

            self.node_time: Optional[Tensor] = None
            self.edge_time: Optional[Tensor] = None

            if time_attr is not None:
                if data.is_node_attr(time_attr):
                    self.node_time = data[time_attr]
                elif data.is_edge_attr(time_attr):
                    self.edge_time = data[time_attr]
                else:
                    raise ValueError(
                        f"The time attribute '{time_attr}' is neither a "
                        f"node-level or edge-level attribute")

            # Convert the graph data into CSC format for sampling:
            self.colptr, self.row, self.perm = to_csc(
                data, device='cpu', share_memory=share_memory,
                is_sorted=is_sorted, src_node_time=self.node_time,
                edge_time=self.edge_time)

            # Edge-level attributes must follow the CSC edge ordering:
            if self.edge_time is not None and self.perm is not None:
                self.edge_time = self.edge_time[self.perm]

            self.edge_weight: Optional[Tensor] = None
            if weight_attr is not None:
                self.edge_weight = data[weight_attr]
                if self.perm is not None:
                    self.edge_weight = self.edge_weight[self.perm]

        elif self.data_type == DataType.heterogeneous:
            self.node_types, self.edge_types = data.metadata()

            self.num_nodes = {k: data[k].num_nodes for k in self.node_types}

            self.node_time: Optional[Dict[NodeType, Tensor]] = None
            self.edge_time: Optional[Dict[EdgeType, Tensor]] = None

            if time_attr is not None:
                is_node_level_time = is_edge_level_time = False

                for store in data.node_stores:
                    if time_attr in store:
                        is_node_level_time = True
                for store in data.edge_stores:
                    if time_attr in store:
                        is_edge_level_time = True

                if is_node_level_time and is_edge_level_time:
                    raise ValueError(
                        f"The time attribute '{time_attr}' holds both "
                        f"node-level and edge-level information")

                if not is_node_level_time and not is_edge_level_time:
                    raise ValueError(
                        f"The time attribute '{time_attr}' is neither a "
                        f"node-level or edge-level attribute")

                if is_node_level_time:
                    self.node_time = data.collect(time_attr)
                else:
                    self.edge_time = data.collect(time_attr)

            # Conversion to/from C++ string type: Since C++ cannot take
            # dictionaries with tuples as key as input, edge type triplets need
            # to be converted into single strings.
            self.to_rel_type = {k: '__'.join(k) for k in self.edge_types}
            self.to_edge_type = {v: k for k, v in self.to_rel_type.items()}

            # Convert the graph data into CSC format for sampling:
            colptr_dict, row_dict, self.perm = to_hetero_csc(
                data, device='cpu', share_memory=share_memory,
                is_sorted=is_sorted, node_time_dict=self.node_time,
                edge_time_dict=self.edge_time)

            self.row_dict = remap_keys(row_dict, self.to_rel_type)
            self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)

            if self.edge_time is not None:
                for edge_type, edge_time in self.edge_time.items():
                    if self.perm.get(edge_type, None) is not None:
                        edge_time = edge_time[self.perm[edge_type]]
                        self.edge_time[edge_type] = edge_time
                self.edge_time = remap_keys(self.edge_time, self.to_rel_type)

            self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None
            if weight_attr is not None:
                self.edge_weight = data.collect(weight_attr)
                for edge_type, edge_weight in self.edge_weight.items():
                    if self.perm.get(edge_type, None) is not None:
                        edge_weight = edge_weight[self.perm[edge_type]]
                        self.edge_weight[edge_type] = edge_weight
                self.edge_weight = remap_keys(self.edge_weight,
                                              self.to_rel_type)

        else:  # self.data_type == DataType.remote
            feature_store, graph_store = data

            # Obtain graph metadata:
            attrs = list(feature_store.get_all_tensor_attrs())

            edge_attrs = graph_store.get_all_edge_attrs()
            self.edge_types = list({attr.edge_type for attr in edge_attrs})

            if weight_attr is not None:
                raise NotImplementedError(
                    f"'weight_attr' argument not yet supported within "
                    f"'{self.__class__.__name__}' for "
                    f"'(FeatureStore, GraphStore)' inputs")

            if time_attr is not None:
                # If the `time_attr` is present, we expect that `GraphStore`
                # holds all edges sorted by destination, and within local
                # neighborhoods, node indices should be sorted by time.
                # TODO (matthias, manan) Find an alternative way to ensure.
                for edge_attr in edge_attrs:
                    if edge_attr.layout == EdgeLayout.CSR:
                        raise ValueError(
                            "Temporal sampling requires that edges are stored "
                            "in either COO or CSC layout")
                    if not edge_attr.is_sorted:
                        raise ValueError(
                            "Temporal sampling requires that edges are "
                            "sorted by destination, and by source time "
                            "within local neighborhoods")

                # We obtain all features with `node_attr.name=time_attr`:
                time_attrs = [
                    copy.copy(attr) for attr in attrs
                    if attr.attr_name == time_attr
                ]

            if not self.is_hetero:
                self.node_types = [None]
                self.num_nodes = max(edge_attrs[0].size)
                self.edge_weight: Optional[Tensor] = None

                self.node_time: Optional[Tensor] = None
                self.edge_time: Optional[Tensor] = None

                if time_attr is not None:
                    if len(time_attrs) != 1:
                        raise ValueError("Temporal sampling specified but did "
                                         "not find any temporal data")
                    time_attrs[0].index = None  # Reset index for full data.
                    time_tensor = feature_store.get_tensor(time_attrs[0])
                    # Currently, we determine whether to use node-level or
                    # edge-level temporal sampling based on the attribute name.
                    if time_attr == 'time':
                        self.node_time = time_tensor
                    else:
                        self.edge_time = time_tensor

                self.row, self.colptr, self.perm = graph_store.csc()

            else:
                node_types = [
                    attr.group_name for attr in attrs
                    if isinstance(attr.group_name, str)
                ]
                self.node_types = list(set(node_types))
                self.num_nodes = {
                    node_type: remote_backend_utils.size(*data, node_type)
                    for node_type in self.node_types
                }
                self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None

                self.node_time: Optional[Dict[NodeType, Tensor]] = None
                self.edge_time: Optional[Dict[EdgeType, Tensor]] = None

                if time_attr is not None:
                    for attr in time_attrs:  # Reset index for full data.
                        attr.index = None

                    time_tensors = feature_store.multi_get_tensor(time_attrs)
                    time = {
                        attr.group_name: time_tensor
                        for attr, time_tensor in zip(time_attrs, time_tensors)
                    }

                    # String group names denote node types, tuples denote
                    # edge types:
                    group_names = [attr.group_name for attr in time_attrs]
                    if all(isinstance(g, str) for g in group_names):
                        self.node_time = time
                    elif all(isinstance(g, tuple) for g in group_names):
                        self.edge_time = time
                    else:
                        raise ValueError(
                            f"Found time attribute '{time_attr}' for both "
                            f"node-level and edge-level types")

                # Conversion to/from C++ string type: edge type triplets are
                # converted into single strings to be usable as C++ keys.
                self.to_rel_type = {k: '__'.join(k) for k in self.edge_types}
                self.to_edge_type = {v: k for k, v in self.to_rel_type.items()}
                # Convert the graph data into CSC format for sampling:
                row_dict, colptr_dict, self.perm = graph_store.csc()
                self.row_dict = remap_keys(row_dict, self.to_rel_type)
                self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)

                # Edge-level time is handed to the sampling backend next to
                # `row_dict`/`colptr_dict`, so it needs the same string keys
                # (as in the in-memory heterogeneous branch):
                if self.edge_time is not None:
                    self.edge_time = remap_keys(self.edge_time,
                                                self.to_rel_type)

        if (self.edge_time is not None
                and not torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE):
            raise ImportError("Edge-level temporal sampling requires a "
                              "more recent 'pyg-lib' installation")

        if (self.edge_weight is not None
                and not torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE):
            raise ImportError("Weighted neighbor sampling requires "
                              "'pyg-lib>=0.3.0'")

        self.num_neighbors = num_neighbors
        self.replace = replace
        self.subgraph_type = subgraph_type
        self.disjoint = disjoint
        self.temporal_strategy = temporal_strategy

    @property
    def num_neighbors(self) -> NumNeighbors:
        r"""The number of neighbors to sample, as a :class:`NumNeighbors`."""
        return self._num_neighbors

    @num_neighbors.setter
    def num_neighbors(self, num_neighbors: NumNeighborsType):
        if isinstance(num_neighbors, NumNeighbors):
            self._num_neighbors = num_neighbors
        else:
            self._num_neighbors = NumNeighbors(num_neighbors)

    @property
    def is_hetero(self) -> bool:
        r"""Whether the underlying graph is heterogeneous."""
        if self.data_type == DataType.homogeneous:
            return False
        if self.data_type == DataType.heterogeneous:
            return True

        # self.data_type == DataType.remote
        return self.edge_types != [None]

    @property
    def is_temporal(self) -> bool:
        r"""Whether node-level or edge-level time information is present."""
        return self.node_time is not None or self.edge_time is not None

    @property
    def disjoint(self) -> bool:
        r"""Whether sampling is disjoint; always :obj:`True` when temporal."""
        return self._disjoint or self.is_temporal

    @disjoint.setter
    def disjoint(self, disjoint: bool):
        self._disjoint = disjoint

    # Node-based sampling #####################################################

    def sample_from_nodes(
        self,
        inputs: NodeSamplerInput,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        r"""Samples a subgraph around the seed nodes in :obj:`inputs`."""
        out = node_sample(inputs, self._sample)
        if self.subgraph_type == SubgraphType.bidirectional:
            out = out.to_bidirectional()
        return out

    # Edge-based sampling #####################################################

    def sample_from_edges(
        self,
        inputs: EdgeSamplerInput,
        neg_sampling: Optional[NegativeSampling] = None,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        r"""Samples a subgraph around the endpoints of the edges in
        :obj:`inputs`, optionally adding negative examples.
        """
        out = edge_sample(inputs, self._sample, self.num_nodes, self.disjoint,
                          self.node_time, neg_sampling)
        if self.subgraph_type == SubgraphType.bidirectional:
            out = out.to_bidirectional()
        return out

    # Other Utilities #########################################################

    @property
    def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:
        r"""The edge permutation produced by the CSC conversion, if any."""
        return self.perm

    # Helper functions ########################################################

    def _sample(
        self,
        seed: Union[Tensor, Dict[NodeType, Tensor]],
        seed_time: Optional[Union[Tensor, Dict[NodeType, Tensor]]] = None,
        **kwargs,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        r"""Implements neighbor sampling by calling either :obj:`pyg-lib` (if
        installed) or :obj:`torch-sparse` (if installed) sampling routines.
        """
        if isinstance(seed, dict):  # Heterogeneous sampling:
            # TODO Support induced subgraph sampling in `pyg-lib`.
            if (torch_geometric.typing.WITH_PYG_LIB
                    and self.subgraph_type != SubgraphType.induced):
                # TODO (matthias) Ideally, `seed` inherits dtype from `colptr`
                colptrs = list(self.colptr_dict.values())
                dtype = colptrs[0].dtype if len(colptrs) > 0 else torch.int64
                seed = {k: v.to(dtype) for k, v in seed.items()}

                # The positional argument list depends on which optional
                # features the installed `pyg-lib` version supports:
                args = (
                    self.node_types,
                    self.edge_types,
                    self.colptr_dict,
                    self.row_dict,
                    seed,
                    self.num_neighbors.get_mapped_values(self.edge_types),
                    self.node_time,
                )
                if torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE:
                    args += (self.edge_time, )
                args += (seed_time, )
                if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE:
                    args += (self.edge_weight, )
                args += (
                    True,  # csc
                    self.replace,
                    self.subgraph_type != SubgraphType.induced,
                    self.disjoint,
                    self.temporal_strategy,
                    # TODO (matthias) `return_edge_id` if edge features present
                    True,  # return_edge_id
                )

                out = torch.ops.pyg.hetero_neighbor_sample(*args)
                row, col, node, edge, batch = out[:4] + (None, )

                # `pyg-lib>0.1.0` returns sampled number of nodes/edges:
                num_sampled_nodes = num_sampled_edges = None
                if len(out) >= 6:
                    num_sampled_nodes, num_sampled_edges = out[4:6]

                if self.disjoint:
                    # In disjoint mode, each node entry holds a (batch, node)
                    # pair; split them into separate dictionaries:
                    node = {k: v.t().contiguous() for k, v in node.items()}
                    batch = {k: v[0] for k, v in node.items()}
                    node = {k: v[1] for k, v in node.items()}

            elif torch_geometric.typing.WITH_TORCH_SPARSE:
                if self.disjoint:
                    if self.subgraph_type == SubgraphType.induced:
                        raise ValueError("'disjoint' sampling not supported "
                                         "for neighbor sampling with "
                                         "`subgraph_type='induced'`")
                    else:
                        raise ValueError("'disjoint' sampling not supported "
                                         "for neighbor sampling via "
                                         "'torch-sparse'. Please install "
                                         "'pyg-lib' for improved and "
                                         "optimized sampling routines.")

                out = torch.ops.torch_sparse.hetero_neighbor_sample(
                    self.node_types,
                    self.edge_types,
                    self.colptr_dict,
                    self.row_dict,
                    seed,  # seed_dict
                    self.num_neighbors.get_mapped_values(self.edge_types),
                    self.num_neighbors.num_hops,
                    self.replace,
                    self.subgraph_type != SubgraphType.induced,
                )
                node, row, col, edge, batch = out + (None, )
                num_sampled_nodes = num_sampled_edges = None

            else:
                raise ImportError(f"'{self.__class__.__name__}' requires "
                                  f"either 'pyg-lib' or 'torch-sparse'")

            # Map relation strings back to edge type triplets:
            if num_sampled_edges is not None:
                num_sampled_edges = remap_keys(
                    num_sampled_edges,
                    self.to_edge_type,
                )

            return HeteroSamplerOutput(
                node=node,
                row=remap_keys(row, self.to_edge_type),
                col=remap_keys(col, self.to_edge_type),
                edge=remap_keys(edge, self.to_edge_type),
                batch=batch,
                num_sampled_nodes=num_sampled_nodes,
                num_sampled_edges=num_sampled_edges,
            )

        else:  # Homogeneous sampling:
            # TODO Support induced subgraph sampling in `pyg-lib`.
            if (torch_geometric.typing.WITH_PYG_LIB
                    and self.subgraph_type != SubgraphType.induced):
                args = (
                    self.colptr,
                    self.row,
                    # TODO (matthias) `seed` should inherit dtype from `colptr`
                    seed.to(self.colptr.dtype),
                    self.num_neighbors.get_mapped_values(),
                    self.node_time,
                )
                if torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE:
                    args += (self.edge_time, )
                args += (seed_time, )
                if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE:
                    args += (self.edge_weight, )
                args += (
                    True,  # csc
                    self.replace,
                    self.subgraph_type != SubgraphType.induced,
                    self.disjoint,
                    self.temporal_strategy,
                    # TODO (matthias) `return_edge_id` if edge features present
                    True,  # return_edge_id
                )

                out = torch.ops.pyg.neighbor_sample(*args)
                row, col, node, edge, batch = out[:4] + (None, )

                # `pyg-lib>0.1.0` returns sampled number of nodes/edges:
                num_sampled_nodes = num_sampled_edges = None
                if len(out) >= 6:
                    num_sampled_nodes, num_sampled_edges = out[4:6]

                if self.disjoint:
                    batch, node = node.t().contiguous()

            elif torch_geometric.typing.WITH_TORCH_SPARSE:
                if self.disjoint:
                    raise ValueError("'disjoint' sampling not supported for "
                                     "neighbor sampling via 'torch-sparse'. "
                                     "Please install 'pyg-lib' for improved "
                                     "and optimized sampling routines.")

                out = torch.ops.torch_sparse.neighbor_sample(
                    self.colptr,
                    self.row,
                    seed,  # seed
                    self.num_neighbors.get_mapped_values(),
                    self.replace,
                    self.subgraph_type != SubgraphType.induced,
                )
                node, row, col, edge, batch = out + (None, )
                num_sampled_nodes = num_sampled_edges = None

            else:
                raise ImportError(f"'{self.__class__.__name__}' requires "
                                  f"either 'pyg-lib' or 'torch-sparse'")

            return SamplerOutput(
                node=node,
                row=row,
                col=col,
                edge=edge,
                batch=batch,
                num_sampled_nodes=num_sampled_nodes,
                num_sampled_edges=num_sampled_edges,
            )
# Sampling Utilities ##########################################################


def node_sample(
    inputs: NodeSamplerInput,
    sample_fn: Callable,
) -> Union[SamplerOutput, HeteroSamplerOutput]:
    r"""Performs sampling from a :class:`NodeSamplerInput`, leveraging a
    sampling function that accepts a seed and (optionally) a seed time as
    input. Returns the output of this sampling procedure.
    """
    if inputs.input_type is not None:  # Heterogeneous sampling:
        seed = {inputs.input_type: inputs.node}
        seed_time = None
        if inputs.time is not None:
            seed_time = {inputs.input_type: inputs.time}
    else:  # Homogeneous sampling:
        seed = inputs.node
        seed_time = inputs.time

    out = sample_fn(seed, seed_time)
    out.metadata = (inputs.input_id, inputs.time)

    return out


def edge_sample(
    inputs: EdgeSamplerInput,
    sample_fn: Callable,
    num_nodes: Union[int, Dict[NodeType, int]],
    disjoint: bool,
    node_time: Optional[Union[Tensor, Dict[str, Tensor]]] = None,
    neg_sampling: Optional[NegativeSampling] = None,
) -> Union[SamplerOutput, HeteroSamplerOutput]:
    r"""Performs sampling from an :class:`EdgeSamplerInput`, leveraging a
    sampling function of the same signature as :meth:`node_sample` (a seed
    and an optional seed time).

    Both endpoints of every input edge are used as seeds. If
    :obj:`neg_sampling` is given, negative examples are appended after the
    :obj:`num_pos` positive ones: negative source/destination pairs labelled
    :obj:`0` in the binary case, and additional negative destinations per
    source in the triplet case. The label information is stored in
    :obj:`out.metadata`.
    """
    input_id = inputs.input_id
    src = inputs.row
    dst = inputs.col
    edge_label = inputs.label
    edge_label_time = inputs.time
    input_type = inputs.input_type

    src_time = dst_time = edge_label_time
    assert edge_label_time is None or disjoint

    assert isinstance(num_nodes, (dict, int))
    if not isinstance(num_nodes, dict):
        num_src_nodes = num_dst_nodes = num_nodes
    else:
        num_src_nodes = num_nodes[input_type[0]]
        num_dst_nodes = num_nodes[input_type[-1]]

    num_pos = src.numel()
    num_neg = 0

    # Negative Sampling #######################################################

    if neg_sampling is not None:
        # When we are doing negative sampling, we append negative information
        # of nodes/edges to `src`, `dst`, `src_time`, `dst_time`.
        # Later on, we can easily reconstruct what belongs to positive and
        # negative examples by slicing via `num_pos`.
        num_neg = math.ceil(num_pos * neg_sampling.amount)

        if neg_sampling.is_binary():
            # In the "binary" case, we randomly sample negative pairs of nodes.
            if isinstance(node_time, dict):
                src_node_time = node_time.get(input_type[0])
            else:
                src_node_time = node_time

            src_neg = neg_sample(src, neg_sampling, num_src_nodes, src_time,
                                 src_node_time, endpoint='src')
            src = torch.cat([src, src_neg], dim=0)

            if isinstance(node_time, dict):
                dst_node_time = node_time.get(input_type[-1])
            else:
                dst_node_time = node_time

            dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes, dst_time,
                                 dst_node_time, endpoint='dst')
            dst = torch.cat([dst, dst_neg], dim=0)

            if edge_label is None:
                edge_label = torch.ones(num_pos)
            size = (num_neg, ) + edge_label.size()[1:]
            edge_neg_label = edge_label.new_zeros(size)
            edge_label = torch.cat([edge_label, edge_neg_label])

            if edge_label_time is not None:
                src_time = dst_time = edge_label_time.repeat(
                    1 + math.ceil(neg_sampling.amount))[:num_pos + num_neg]

        elif neg_sampling.is_triplet():
            # In the "triplet" case, we randomly sample negative destinations.
            if isinstance(node_time, dict):
                dst_node_time = node_time.get(input_type[-1])
            else:
                dst_node_time = node_time

            dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes, dst_time,
                                 dst_node_time, endpoint='dst')
            dst = torch.cat([dst, dst_neg], dim=0)

            assert edge_label is None

            if edge_label_time is not None:
                dst_time = edge_label_time.repeat(1 + neg_sampling.amount)

    # Heterogeneous Neighborhood Sampling #####################################

    if input_type is not None:
        seed_time_dict = None
        if input_type[0] != input_type[-1]:  # Two distinct node types:

            if not disjoint:
                src, inverse_src = src.unique(return_inverse=True)
                dst, inverse_dst = dst.unique(return_inverse=True)

            seed_dict = {input_type[0]: src, input_type[-1]: dst}

            if edge_label_time is not None:  # Always disjoint.
                seed_time_dict = {
                    input_type[0]: src_time,
                    input_type[-1]: dst_time,
                }

        else:  # Only a single node type: Merge both source and destination.

            seed = torch.cat([src, dst], dim=0)

            if not disjoint:
                seed, inverse_seed = seed.unique(return_inverse=True)

            seed_dict = {input_type[0]: seed}

            if edge_label_time is not None:  # Always disjoint.
                seed_time_dict = {
                    input_type[0]: torch.cat([src_time, dst_time], dim=0),
                }

        out = sample_fn(seed_dict, seed_time_dict)

        # Enhance `out` by label information ##################################
        if disjoint:
            for key, batch in out.batch.items():
                out.batch[key] = batch % num_pos

        if neg_sampling is None or neg_sampling.is_binary():
            if disjoint:
                if input_type[0] != input_type[-1]:
                    edge_label_index = torch.arange(num_pos + num_neg)
                    edge_label_index = edge_label_index.repeat(2).view(2, -1)
                else:
                    edge_label_index = torch.arange(2 * (num_pos + num_neg))
                    edge_label_index = edge_label_index.view(2, -1)
            else:
                if input_type[0] != input_type[-1]:
                    edge_label_index = torch.stack([
                        inverse_src,
                        inverse_dst,
                    ], dim=0)
                else:
                    edge_label_index = inverse_seed.view(2, -1)

            out.metadata = (input_id, edge_label_index, edge_label, src_time)

        elif neg_sampling.is_triplet():
            if disjoint:
                src_index = torch.arange(num_pos)
                if input_type[0] != input_type[-1]:
                    dst_pos_index = torch.arange(num_pos)
                    # `dst_neg_index` needs to be offset such that indices with
                    # offset `num_pos` belong to the same triplet:
                    dst_neg_index = torch.arange(
                        num_pos, seed_dict[input_type[-1]].numel())
                    dst_neg_index = dst_neg_index.view(-1, num_pos).t()
                else:
                    dst_pos_index = torch.arange(num_pos, 2 * num_pos)
                    dst_neg_index = torch.arange(
                        2 * num_pos, seed_dict[input_type[-1]].numel())
                    dst_neg_index = dst_neg_index.view(-1, num_pos).t()
            else:
                if input_type[0] != input_type[-1]:
                    src_index = inverse_src
                    dst_pos_index = inverse_dst[:num_pos]
                    dst_neg_index = inverse_dst[num_pos:]
                else:
                    src_index = inverse_seed[:num_pos]
                    dst_pos_index = inverse_seed[num_pos:2 * num_pos]
                    dst_neg_index = inverse_seed[2 * num_pos:]
            # Shared by both paths so that the output shape does not depend on
            # `disjoint` (a trailing dimension of size 1 is dropped):
            dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)

            out.metadata = (
                input_id,
                src_index,
                dst_pos_index,
                dst_neg_index,
                src_time,
            )

    # Homogeneous Neighborhood Sampling #######################################

    else:
        seed = torch.cat([src, dst], dim=0)
        seed_time = None

        if not disjoint:
            seed, inverse_seed = seed.unique(return_inverse=True)

        if edge_label_time is not None:  # Always disjoint.
            seed_time = torch.cat([src_time, dst_time])

        out = sample_fn(seed, seed_time)

        # Enhance `out` by label information ##################################
        if neg_sampling is None or neg_sampling.is_binary():
            if disjoint:
                out.batch = out.batch % num_pos
                edge_label_index = torch.arange(seed.numel()).view(2, -1)
            else:
                edge_label_index = inverse_seed.view(2, -1)

            out.metadata = (input_id, edge_label_index, edge_label, src_time)

        elif neg_sampling.is_triplet():
            if disjoint:
                out.batch = out.batch % num_pos
                src_index = torch.arange(num_pos)
                dst_pos_index = torch.arange(num_pos, 2 * num_pos)
                # `dst_neg_index` needs to be offset such that indices with
                # offset `num_pos` belong to the same triplet:
                dst_neg_index = torch.arange(2 * num_pos, seed.numel())
                dst_neg_index = dst_neg_index.view(-1, num_pos).t()
            else:
                src_index = inverse_seed[:num_pos]
                dst_pos_index = inverse_seed[num_pos:2 * num_pos]
                dst_neg_index = inverse_seed[2 * num_pos:]
            # Shared by both paths so that the output shape does not depend on
            # `disjoint` (a trailing dimension of size 1 is dropped):
            dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)

            out.metadata = (
                input_id,
                src_index,
                dst_pos_index,
                dst_neg_index,
                src_time,
            )

    return out


def neg_sample(
    seed: Tensor,
    neg_sampling: NegativeSampling,
    num_nodes: int,
    seed_time: Optional[Tensor],
    node_time: Optional[Tensor],
    endpoint: Literal['src', 'dst'],
) -> Tensor:
    r"""Samples :obj:`ceil(seed.numel() * neg_sampling.amount)` negative node
    indices via :meth:`neg_sampling.sample`.

    Without :obj:`node_time`, the result of a single
    :meth:`neg_sampling.sample` call is returned. Otherwise, candidates are
    meant to satisfy :obj:`node_time <= seed_time` of the seed they are paired
    with. Invalid candidates are re-drawn up to five times, and any that are
    still invalid fall back to the node with the smallest timestamp.
    """
    num_neg = math.ceil(seed.numel() * neg_sampling.amount)

    # TODO: Do not sample false negatives.
    if node_time is None:
        return neg_sampling.sample(num_neg, endpoint, num_nodes)

    # If we are in a temporal-sampling scenario, we need to respect the
    # timestamp of the given nodes we can use as negative examples.
    # That is, we can only sample nodes for which `node_time <= seed_time`.
    # For now, we use a greedy algorithm which randomly samples negative
    # nodes and discard any which do not respect the temporal constraint.
    # We iteratively repeat this process until we have sampled a valid node for
    # each seed.
    # TODO See if this greedy algorithm here can be improved.
    assert seed_time is not None
    num_samples = math.ceil(neg_sampling.amount)
    # Row `i` holds the `i`-th negative candidate for every seed:
    seed_time = seed_time.view(1, -1).expand(num_samples, -1)

    out = neg_sampling.sample(num_samples * seed.numel(), endpoint, num_nodes)
    out = out.view(num_samples, seed.numel())
    mask = node_time[out] > seed_time  # holds all invalid samples.
    neg_sampling_complete = False
    for _ in range(5):  # pragma: no cover
        num_invalid = int(mask.sum())
        if num_invalid == 0:
            neg_sampling_complete = True
            break

        # Greedily search for alternative negatives. Re-drawn candidates are
        # invalid under the same strict `>` condition as above, so nodes with
        # `node_time == seed_time` are accepted consistently:
        out[mask] = tmp = neg_sampling.sample(num_invalid, endpoint, num_nodes)
        mask[mask.clone()] = node_time[tmp] > seed_time[mask]

    if not neg_sampling_complete:  # pragma: no cover
        # Not much options left. In that case, we set remaining negatives
        # to the node with minimum timestamp.
        out[mask] = node_time.argmin()

    return out.view(-1)[:num_neg]