# Source code for torch_geometric.transforms.random_link_split

from typing import Union, Tuple

import copy

import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops, negative_sampling
from torch_geometric.transforms import BaseTransform


class RandomLinkSplit(BaseTransform):
    r"""Performs an edge-level random split into training, validation and
    test sets. The split is performed such that the training split does not
    include edges in validation and test splits; and the validation split
    does not include edges in the test split.

    .. code-block::

        from torch_geometric.transforms import RandomLinkSplit

        transform = RandomLinkSplit(is_undirected=True)
        train_data, val_data, test_data = transform(data)

    Args:
        num_val (int or float, optional): The number of validation edges.
            If set to a floating-point value in :math:`[0, 1]`, it represents
            the ratio of edges to include in the validation set.
            (default: :obj:`0.1`)
        num_test (int or float, optional): The number of test edges.
            If set to a floating-point value in :math:`[0, 1]`, it represents
            the ratio of edges to include in the test set.
            (default: :obj:`0.2`)
        is_undirected (bool): If set to :obj:`True`, the graph is assumed to
            be undirected, and positive and negative samples will not leak
            (reverse) edge connectivity across different splits.
            (default: :obj:`False`)
        key (str, optional): The name of the attribute holding ground-truth
            labels. If :obj:`data[key]` does not exist, it will be
            automatically created and represents a binary classification task
            (:obj:`1` = edge, :obj:`0` = no edge). If :obj:`data[key]` exists,
            it has to be a categorical label from :obj:`0` to
            :obj:`num_classes - 1`. After negative sampling, label :obj:`0`
            represents negative edges, and labels :obj:`1` to
            :obj:`num_classes` represent the labels of positive edges.
            (default: :obj:`"edge_label"`)
        split_labels (bool, optional): If set to :obj:`True`, will split
            positive and negative labels and save them in distinct attributes
            :obj:`"pos_edge_label"` and :obj:`"neg_edge_label"`, respectively.
            (default: :obj:`False`)
        add_negative_train_samples (bool, optional): Whether to add negative
            training samples for link prediction. If the model already
            performs negative sampling, then the option should be set to
            :obj:`False`. Otherwise, the added negative samples will be the
            same across training iterations unless negative sampling is
            performed again. (default: :obj:`True`)
        neg_sampling_ratio (float, optional): The ratio of sampled negative
            edges to the number of positive edges. (default: :obj:`1.0`)
    """
    def __init__(
        self,
        num_val: Union[int, float] = 0.1,
        num_test: Union[int, float] = 0.2,
        is_undirected: bool = False,
        key: str = 'edge_label',
        split_labels: bool = False,
        add_negative_train_samples: bool = True,
        neg_sampling_ratio: float = 1.0,
    ):
        self.num_val = num_val
        self.num_test = num_test
        self.is_undirected = is_undirected
        self.key = key
        self.split_labels = split_labels
        self.add_negative_train_samples = add_negative_train_samples
        self.neg_sampling_ratio = neg_sampling_ratio

    def __call__(self, data: Data) -> Tuple[Data, Data, Data]:
        """Splits :obj:`data` into a :obj:`(train, val, test)` triple.

        Raises:
            ValueError: If :obj:`num_val + num_test` leaves no edges for
                training.
        """
        perm = torch.randperm(data.num_edges, device=data.edge_index.device)
        if self.is_undirected:
            # Keep only one direction of every undirected edge so that a
            # reverse edge can never leak into a different split.
            perm = perm[data.edge_index[0] <= data.edge_index[1]]

        # Resolve float ratios into absolute edge counts:
        num_val, num_test = self.num_val, self.num_test
        if isinstance(num_val, float):
            num_val = int(num_val * perm.numel())
        if isinstance(num_test, float):
            num_test = int(num_test * perm.numel())

        num_train = perm.numel() - num_val - num_test
        if num_train <= 0:
            raise ValueError("Insufficient number of edges for training.")

        train_edges = perm[:num_train]
        val_edges = perm[num_train:num_train + num_val]
        test_edges = perm[num_train + num_val:]
        train_val_edges = perm[:num_train + num_val]

        # Create data splits. The validation graph only carries training
        # edges (no leakage); the test graph carries train + val edges.
        train_data = self._split_data(data, train_edges)
        val_data = self._split_data(data, train_edges)
        test_data = self._split_data(data, train_val_edges)

        # Create negative samples, laid out as [val | test | train] inside
        # `neg_edge_index` (train negatives occupy the tail):
        num_neg_train = 0
        if self.add_negative_train_samples:
            num_neg_train = int(num_train * self.neg_sampling_ratio)
        num_neg_val = int(num_val * self.neg_sampling_ratio)
        num_neg_test = int(num_test * self.neg_sampling_ratio)

        num_neg = num_neg_train + num_neg_val + num_neg_test
        # Self-loops are added before sampling so that they are never
        # proposed as negative edges.
        neg_edge_index = negative_sampling(add_self_loops(data.edge_index)[0],
                                           num_nodes=data.num_nodes,
                                           num_neg_samples=num_neg,
                                           method='sparse')

        # Create labels:
        self._create_label(
            data,
            train_edges,
            neg_edge_index[:, num_neg_val + num_neg_test:],
            out=train_data,
        )
        self._create_label(
            data,
            val_edges,
            neg_edge_index[:, :num_neg_val],
            out=val_data,
        )
        self._create_label(
            data,
            test_edges,
            neg_edge_index[:, num_neg_val:num_neg_val + num_neg_test],
            out=test_data,
        )

        return train_data, val_data, test_data

    def _split(self, edge_index: Tensor, index: Tensor) -> Tensor:
        # Restrict `edge_index` to the selected columns; for undirected
        # graphs, re-add the reverse direction of every kept edge.
        edge_index = edge_index[:, index]

        if self.is_undirected:
            edge_index = torch.cat([edge_index, edge_index.flip([0])], dim=-1)

        return edge_index

    def _split_data(self, data: Data, index: Tensor) -> Data:
        # Shallow-copy `data` and restrict `edge_index` plus every edge-level
        # attribute to the selected edges. Edge-level attributes are detected
        # heuristically by their first dimension matching `num_edges`, which
        # can misfire when e.g. num_nodes == num_edges — TODO confirm this is
        # acceptable for the expected inputs.
        num_edges = data.num_edges
        data = copy.copy(data)
        data.edge_index = self._split(data.edge_index, index)

        for key, value in data.items():
            if isinstance(value, Tensor) and value.size(0) == num_edges:
                value = value[index]
                if self.is_undirected:
                    # Duplicate attributes to match the appended reverse edges.
                    value = torch.cat([value, value], dim=0)
                data[key] = value

        return data

    def _create_label(self, data: Data, index: Tensor,
                      neg_edge_index: Tensor, out: Data) -> Data:
        # Attach (positive and negative) supervision edges and labels for the
        # edges selected by `index` to the split `out`.
        edge_index = data.edge_index[:, index]

        if hasattr(data, self.key):
            edge_label = data[self.key]
            assert edge_label.dtype == torch.long and edge_label.dim() == 1
            edge_label = edge_label[index]
            # Shift categorical labels by one only when negative samples
            # exist, so that label `0` is reserved for negative edges.
            # (The previous unconditional in-place `add_(1)` corrupted labels
            # whenever no negatives were sampled.)
            if neg_edge_index.numel() > 0:
                edge_label = edge_label + 1
            # BUGFIX: remove the stale label tensor from the *output* split.
            # The original code did `delattr(data, self.key)`, which mutated
            # the shared input on the first (train) call and made the val and
            # test splits silently fall back to all-ones labels.
            if hasattr(out, self.key):
                delattr(out, self.key)
        else:
            # No ground-truth labels present: binary task, positives are `1`.
            edge_label = torch.ones(index.numel(), device=index.device)

        if neg_edge_index.numel() > 0:
            # Negative edges always receive label `0`.
            neg_edge_label = edge_label.new_zeros(neg_edge_index.size(1))

        if self.split_labels:
            # Store positives and negatives under distinct attributes:
            out[f'pos_{self.key}'] = edge_label
            out[f'pos_{self.key}_index'] = edge_index
            if neg_edge_index.numel() > 0:
                out[f'neg_{self.key}'] = neg_edge_label
                out[f'neg_{self.key}_index'] = neg_edge_index
        else:
            # Concatenate positives and negatives into a single attribute:
            if neg_edge_index.numel() > 0:
                edge_label = torch.cat([edge_label, neg_edge_label], dim=0)
                edge_index = torch.cat([edge_index, neg_edge_index], dim=-1)
            out[self.key] = edge_label
            out[f'{self.key}_index'] = edge_index

        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(num_val={self.num_val}, '
                f'num_test={self.num_test})')