Source code for torch_geometric.transforms.random_node_split

from typing import Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.data.storage import NodeStorage
from torch_geometric.transforms import BaseTransform


@functional_transform('random_node_split')
class RandomNodeSplit(BaseTransform):
    r"""Performs a node-level random split by adding :obj:`train_mask`,
    :obj:`val_mask` and :obj:`test_mask` attributes to the
    :class:`~torch_geometric.data.Data` or
    :class:`~torch_geometric.data.HeteroData` object
    (functional name: :obj:`random_node_split`).

    Args:
        split (str, optional): The type of dataset split
            (:obj:`"train_rest"`, :obj:`"test_rest"`, :obj:`"random"`).
            If set to :obj:`"train_rest"`, all nodes except those in the
            validation and test sets will be used for training (as in the
            `"FastGCN: Fast Learning with Graph Convolutional Networks via
            Importance Sampling" <https://arxiv.org/abs/1801.10247>`_ paper).
            If set to :obj:`"test_rest"`, all nodes except those in the
            training and validation sets will be used for testing (as in the
            `"Pitfalls of Graph Neural Network Evaluation"
            <https://arxiv.org/abs/1811.05868>`_ paper).
            If set to :obj:`"random"`, train, validation, and test sets will
            be randomly generated, according to :obj:`num_train_per_class`,
            :obj:`num_val` and :obj:`num_test` (as in the `"Semi-supervised
            Classification with Graph Convolutional Networks"
            <https://arxiv.org/abs/1609.02907>`_ paper).
            (default: :obj:`"train_rest"`)
        num_splits (int, optional): The number of splits to add. If bigger
            than :obj:`1`, the shape of masks will be
            :obj:`[num_nodes, num_splits]`, and :obj:`[num_nodes]` otherwise.
            (default: :obj:`1`)
        num_train_per_class (int, optional): The number of training samples
            per class in case of :obj:`"test_rest"` and :obj:`"random"`
            split. (default: :obj:`20`)
        num_val (int or float, optional): The number of validation samples.
            If float, it represents the ratio of samples to include in the
            validation set. (default: :obj:`500`)
        num_test (int or float, optional): The number of test samples in
            case of :obj:`"train_rest"` and :obj:`"random"` split. If float,
            it represents the ratio of samples to include in the test set.
            (default: :obj:`1000`)
        key (str, optional): The name of the attribute holding ground-truth
            labels. By default, will only add node-level splits for
            node-level storages in which :obj:`key` is present.
            (default: :obj:`"y"`)
""" def __init__( self, split: str = "train_rest", num_splits: int = 1, num_train_per_class: int = 20, num_val: Union[int, float] = 500, num_test: Union[int, float] = 1000, key: Optional[str] = "y", ) -> None: assert split in ['train_rest', 'test_rest', 'random'] self.split = split self.num_splits = num_splits self.num_train_per_class = num_train_per_class self.num_val = num_val self.num_test = num_test self.key = key def forward( self, data: Union[Data, HeteroData], ) -> Union[Data, HeteroData]: for store in data.node_stores: if self.key is not None and not hasattr(store, self.key): continue train_masks, val_masks, test_masks = zip( *[self._split(store) for _ in range(self.num_splits)]) store.train_mask = torch.stack(train_masks, dim=-1).squeeze(-1) store.val_mask = torch.stack(val_masks, dim=-1).squeeze(-1) store.test_mask = torch.stack(test_masks, dim=-1).squeeze(-1) return data def _split(self, store: NodeStorage) -> Tuple[Tensor, Tensor, Tensor]: num_nodes = store.num_nodes assert num_nodes is not None train_mask = torch.zeros(num_nodes, dtype=torch.bool) val_mask = torch.zeros(num_nodes, dtype=torch.bool) test_mask = torch.zeros(num_nodes, dtype=torch.bool) if isinstance(self.num_val, float): num_val = round(num_nodes * self.num_val) else: num_val = self.num_val if isinstance(self.num_test, float): num_test = round(num_nodes * self.num_test) else: num_test = self.num_test if self.split == 'train_rest': perm = torch.randperm(num_nodes) val_mask[perm[:num_val]] = True test_mask[perm[num_val:num_val + num_test]] = True train_mask[perm[num_val + num_test:]] = True else: assert self.key is not None y = getattr(store, self.key) num_classes = int(y.max().item()) + 1 for c in range(num_classes): idx = (y == c).nonzero(as_tuple=False).view(-1) idx = idx[torch.randperm(idx.size(0))] idx = idx[:self.num_train_per_class] train_mask[idx] = True remaining = (~train_mask).nonzero(as_tuple=False).view(-1) remaining = remaining[torch.randperm(remaining.size(0))] val_mask[remaining[:num_val]] = True if self.split == 'test_rest': test_mask[remaining[num_val:]] = True elif self.split == 'random': test_mask[remaining[num_val:num_val + num_test]] = True return train_mask, val_mask, test_mask def __repr__(self) -> str: return f'{self.__class__.__name__}(split={self.split})'