# Source code for torch_geometric.transforms.random_link_split

from copy import copy
from typing import List, Optional, Union

import torch
from torch import Tensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.data.storage import EdgeStorage
from torch_geometric.transforms import BaseTransform
from torch_geometric.typing import EdgeType
from torch_geometric.utils import negative_sampling


@functional_transform('random_link_split')
class RandomLinkSplit(BaseTransform):
    r"""Performs an edge-level random split into training, validation and test
    sets of a :class:`~torch_geometric.data.Data` or a
    :class:`~torch_geometric.data.HeteroData` object
    (functional name: :obj:`random_link_split`).
    The split is performed such that the training split does not include
    edges in validation and test splits; and the validation split does not
    include edges in the test split.

    .. code-block::

        from torch_geometric.transforms import RandomLinkSplit

        transform = RandomLinkSplit(is_undirected=True)
        train_data, val_data, test_data = transform(data)

    Args:
        num_val (int or float, optional): The number of validation edges.
            If set to a floating-point value in :math:`[0, 1]`, it represents
            the ratio of edges to include in the validation set.
            (default: :obj:`0.1`)
        num_test (int or float, optional): The number of test edges.
            If set to a floating-point value in :math:`[0, 1]`, it represents
            the ratio of edges to include in the test set.
            (default: :obj:`0.2`)
        is_undirected (bool): If set to :obj:`True`, the graph is assumed to
            be undirected, and positive and negative samples will not leak
            (reverse) edge connectivity across different splits.
            This option is ignored for bipartite edge stores and for edge
            types that have an explicit reverse edge type.
            (default: :obj:`False`)
        key (str, optional): The name of the attribute holding
            ground-truth labels.
            If :obj:`data[key]` does not exist, it will be automatically
            created and represents a binary classification task
            (:obj:`1` = edge, :obj:`0` = no edge).
            If :obj:`data[key]` exists, it has to be a categorical label from
            :obj:`0` to :obj:`num_classes - 1`.
            After negative sampling, label :obj:`0` represents negative edges,
            and labels :obj:`1` to :obj:`num_classes` represent the labels of
            positive edges. (default: :obj:`"edge_label"`)
        split_labels (bool, optional): If set to :obj:`True`, will split
            positive and negative labels and save them in distinct attributes
            :obj:`"pos_edge_label"` and :obj:`"neg_edge_label"`, respectively.
            (default: :obj:`False`)
        add_negative_train_samples (bool, optional): Whether to add negative
            training samples for link prediction.
            If the model already performs negative sampling, then the option
            should be set to :obj:`False`.
            Otherwise, the added negative samples will be the same across
            training iterations unless negative sampling is performed again.
            (default: :obj:`True`)
        neg_sampling_ratio (float, optional): The ratio of sampled negative
            edges to the number of positive edges. (default: :obj:`1.0`)
        disjoint_train_ratio (int or float, optional): If set to a value
            greater than :obj:`0.0`, training edges will not be shared for
            message passing and supervision. Instead,
            :obj:`disjoint_train_ratio` edges are used as ground-truth labels
            for supervision during training. (default: :obj:`0.0`)
        edge_types (Tuple[EdgeType] or List[EdgeType], optional): The edge
            types used for performing edge-level splitting in case of
            operating on :class:`~torch_geometric.data.HeteroData` objects.
            (default: :obj:`None`)
        rev_edge_types (Tuple[EdgeType] or List[Tuple[EdgeType]], optional):
            The reverse edge types of :obj:`edge_types` in case of operating
            on :class:`~torch_geometric.data.HeteroData` objects.
            This will ensure that edges of the reverse direction will be
            split accordingly to prevent any data leakage.
            Can be :obj:`None` in case no reverse connection exists.
            (default: :obj:`None`)
    """
    def __init__(
        self,
        num_val: Union[int, float] = 0.1,
        num_test: Union[int, float] = 0.2,
        is_undirected: bool = False,
        key: str = 'edge_label',
        split_labels: bool = False,
        add_negative_train_samples: bool = True,
        neg_sampling_ratio: float = 1.0,
        disjoint_train_ratio: Union[int, float] = 0.0,
        edge_types: Optional[Union[EdgeType, List[EdgeType]]] = None,
        rev_edge_types: Optional[Union[EdgeType, List[EdgeType]]] = None,
    ):
        # With a list of edge types, the reverse types must be a parallel list
        # (entries may be `None` when no reverse connection exists):
        if isinstance(edge_types, list):
            if rev_edge_types is None:
                rev_edge_types = [None] * len(edge_types)

            assert isinstance(rev_edge_types, list)
            assert len(edge_types) == len(rev_edge_types)

        self.num_val = num_val
        self.num_test = num_test
        self.is_undirected = is_undirected
        self.key = key
        self.split_labels = split_labels
        self.add_negative_train_samples = add_negative_train_samples
        self.neg_sampling_ratio = neg_sampling_ratio
        self.disjoint_train_ratio = disjoint_train_ratio
        self.edge_types = edge_types
        self.rev_edge_types = rev_edge_types

    def __call__(self, data: Union[Data, HeteroData]):
        """Splits the edges of ``data`` and returns the tuple
        ``(train_data, val_data, test_data)``.

        Raises:
            ValueError: If ``data`` is a :class:`HeteroData` object and no
                ``edge_types`` were given, or if too few edges remain for
                training after removing validation/test (and disjoint
                supervision) edges.
        """
        edge_types = self.edge_types
        rev_edge_types = self.rev_edge_types

        # One shallow copy per split; the stores taken from each copy are
        # rewritten below by `_split` and `_create_label`.
        train_data, val_data, test_data = copy(data), copy(data), copy(data)

        if isinstance(data, HeteroData):
            if edge_types is None:
                raise ValueError(
                    "The 'RandomLinkSplit' transform expects 'edge_types' to "
                    "be specified when operating on 'HeteroData' objects")

            if not isinstance(edge_types, list):
                edge_types = [edge_types]
                rev_edge_types = [rev_edge_types]

            stores = [data[edge_type] for edge_type in edge_types]
            train_stores = [train_data[edge_type] for edge_type in edge_types]
            val_stores = [val_data[edge_type] for edge_type in edge_types]
            test_stores = [test_data[edge_type] for edge_type in edge_types]
        else:
            rev_edge_types = [None]
            stores = [data._store]
            train_stores = [train_data._store]
            val_stores = [val_data._store]
            test_stores = [test_data._store]

        for store, train_store, val_store, test_store, rev_edge_type in zip(
                stores, train_stores, val_stores, test_stores,
                rev_edge_types):

            # Undirected handling only applies to non-bipartite stores that
            # have no explicit reverse edge type:
            is_undirected = self.is_undirected
            is_undirected &= not store.is_bipartite()
            is_undirected &= rev_edge_type is None

            edge_index = store.edge_index
            if is_undirected:
                # Sample only the (src <= dst) direction of each edge;
                # `_split` re-adds the reverse direction afterwards.
                mask = edge_index[0] <= edge_index[1]
                perm = mask.nonzero(as_tuple=False).view(-1)
                perm = perm[torch.randperm(perm.size(0), device=perm.device)]
            else:
                device = edge_index.device
                perm = torch.randperm(edge_index.size(1), device=device)

            # Float values are ratios of the candidate edges; ints are counts.
            num_val = self.num_val
            if isinstance(num_val, float):
                num_val = int(num_val * perm.numel())
            num_test = self.num_test
            if isinstance(num_test, float):
                num_test = int(num_test * perm.numel())

            num_train = perm.numel() - num_val - num_test
            if num_train <= 0:
                raise ValueError("Insufficient number of edges for training")

            train_edges = perm[:num_train]
            val_edges = perm[num_train:num_train + num_val]
            test_edges = perm[num_train + num_val:]
            train_val_edges = perm[:num_train + num_val]

            num_disjoint = self.disjoint_train_ratio
            if isinstance(num_disjoint, float):
                num_disjoint = int(num_disjoint * train_edges.numel())
            if num_train - num_disjoint <= 0:
                raise ValueError("Insufficient number of edges for training")

            # Create data splits (message-passing edges). Validation sees all
            # training edges; test sees training + validation edges. The first
            # `num_disjoint` training edges are held out for supervision only:
            self._split(train_store, train_edges[num_disjoint:], is_undirected,
                        rev_edge_type)
            self._split(val_store, train_edges, is_undirected, rev_edge_type)
            self._split(test_store, train_val_edges, is_undirected,
                        rev_edge_type)

            # Create negative samples:
            num_neg_train = 0
            if self.add_negative_train_samples:
                if num_disjoint > 0:
                    num_neg_train = int(num_disjoint * self.neg_sampling_ratio)
                else:
                    num_neg_train = int(num_train * self.neg_sampling_ratio)
            num_neg_val = int(num_val * self.neg_sampling_ratio)
            num_neg_test = int(num_test * self.neg_sampling_ratio)

            num_neg = num_neg_train + num_neg_val + num_neg_test

            size = store.size()
            if store._key is None or store._key[0] == store._key[-1]:
                size = size[0]
            # Negatives are sampled once against the full edge set and then
            # partitioned into [val | test | train] slices below.
            neg_edge_index = negative_sampling(edge_index, size,
                                               num_neg_samples=num_neg,
                                               method='sparse')

            # Create labels:
            if num_disjoint > 0:
                train_edges = train_edges[:num_disjoint]
            self._create_label(
                store,
                train_edges,
                neg_edge_index[:, num_neg_val + num_neg_test:],
                out=train_store,
            )
            self._create_label(
                store,
                val_edges,
                neg_edge_index[:, :num_neg_val],
                out=val_store,
            )
            self._create_label(
                store,
                test_edges,
                neg_edge_index[:, num_neg_val:num_neg_val + num_neg_test],
                out=test_store,
            )

        return train_data, val_data, test_data

    def _split(self, store: EdgeStorage, index: Tensor, is_undirected: bool,
               rev_edge_type: EdgeType):
        """Restricts ``store`` in place to the edges selected by ``index``.

        Every edge attribute and ``edge_index`` are indexed by ``index``. If
        ``is_undirected`` is set, each selected edge is appended again in
        reverse direction (attributes are duplicated accordingly). If
        ``rev_edge_type`` is given, the reverse store in the same parent is
        synchronized with the result. Returns ``store``.
        """
        for key, value in store.items():
            if key == 'edge_index':
                continue

            if store.is_edge_attr(key):
                value = value[index]
                if is_undirected:
                    value = torch.cat([value, value], dim=0)
                store[key] = value

        edge_index = store.edge_index[:, index]
        if is_undirected:
            edge_index = torch.cat([edge_index, edge_index.flip([0])], dim=-1)
        store.edge_index = edge_index

        if rev_edge_type is not None:
            rev_store = store._parent()[rev_edge_type]
            # Iterate over a snapshot: keys are deleted from `rev_store`
            # inside the loop, which must not happen on a live keys view.
            for key in list(rev_store.keys()):
                if key not in store:
                    del rev_store[key]  # We delete all outdated attributes.
                elif key == 'edge_index':
                    rev_store.edge_index = store.edge_index.flip([0])
                else:
                    rev_store[key] = store[key]

        return store

    def _create_label(self, store: EdgeStorage, index: Tensor,
                      neg_edge_index: Tensor, out: EdgeStorage):
        """Writes supervision edges and labels into ``out``.

        Positive edges are the columns ``index`` of ``store.edge_index``;
        negative edges are given by ``neg_edge_index`` and get label ``0``.
        Depending on ``split_labels``, results are stored either as
        ``{key}``/``{key}_index`` or as ``pos_{key}``/``pos_{key}_index`` and
        ``neg_{key}``/``neg_{key}_index``. Returns ``out``.
        """
        edge_index = store.edge_index[:, index]
        has_neg = neg_edge_index.numel() > 0

        if hasattr(store, self.key):
            edge_label = store[self.key]
            if has_neg:
                # Positive labels are shifted by one so that label 0 is free
                # for negatives. This requires categorical labels with one
                # entry per edge of the full (unsplit) store, so validate the
                # full tensor *before* selecting the subset `index`.
                assert edge_label.dtype == torch.long
                assert edge_label.size(0) == store.edge_index.size(1)
            edge_label = edge_label[index]
            # Increment labels by one. Note that there is no need to increment
            # in case no negative edges are added.
            if has_neg:
                edge_label = edge_label + 1
            # Drop the per-edge label on the output split; it is replaced by
            # the attributes written below.
            if hasattr(out, self.key):
                delattr(out, self.key)
        else:
            edge_label = torch.ones(index.numel(), device=index.device)

        if has_neg:
            neg_edge_label = edge_label.new_zeros((neg_edge_index.size(1), ) +
                                                  edge_label.size()[1:])

        if self.split_labels:
            out[f'pos_{self.key}'] = edge_label
            out[f'pos_{self.key}_index'] = edge_index
            if has_neg:
                out[f'neg_{self.key}'] = neg_edge_label
                out[f'neg_{self.key}_index'] = neg_edge_index

        else:
            if has_neg:
                edge_label = torch.cat([edge_label, neg_edge_label], dim=0)
                edge_index = torch.cat([edge_index, neg_edge_index], dim=-1)
            out[self.key] = edge_label
            out[f'{self.key}_index'] = edge_index

        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(num_val={self.num_val}, '
                f'num_test={self.num_test})')