Source code for torch_geometric.transforms.random_node_split

from typing import Union, Tuple

import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform


class RandomNodeSplit(BaseTransform):
    r"""Randomly partitions the nodes of a graph into training, validation
    and test sets, and stores the result as boolean :obj:`train_mask`,
    :obj:`val_mask` and :obj:`test_mask` attributes on the :obj:`data`
    object.

    Args:
        split (string): The splitting strategy (:obj:`"train_rest"`,
            :obj:`"test_rest"`, :obj:`"random"`).
            With :obj:`"train_rest"`, :obj:`num_val` validation nodes and
            :obj:`num_test` test nodes are drawn at random and every
            remaining node is used for training (as in the `"FastGCN: Fast
            Learning with Graph Convolutional Networks via Importance
            Sampling" <https://arxiv.org/abs/1801.10247>`_ paper).
            With :obj:`"test_rest"`, up to :obj:`num_train_per_class`
            training nodes are drawn per class, :obj:`num_val` validation
            nodes are drawn from the nodes that are left, and every
            remaining node is used for testing (as in the `"Pitfalls of
            Graph Neural Network Evaluation"
            <https://arxiv.org/abs/1811.05868>`_ paper).
            With :obj:`"random"`, training nodes are drawn per class as for
            :obj:`"test_rest"`, and :obj:`num_val` validation nodes and
            :obj:`num_test` test nodes are drawn from the nodes that are
            left (as in the `"Semi-supervised Classification with Graph
            Convolutional Networks" <https://arxiv.org/abs/1609.02907>`_
            paper).
            (default: :obj:`"train_rest"`)
        num_splits (int, optional): How many independent splits to
            generate. The per-split masks are stacked along the last
            dimension, giving masks of shape :obj:`[num_nodes, num_splits]`;
            for a single split that dimension is squeezed away and the
            masks have shape :obj:`[num_nodes]`. (default: :obj:`1`)
        num_train_per_class (int, optional): The number of training nodes
            to draw per class. Only used by the :obj:`"test_rest"` and
            :obj:`"random"` splits. (default: :obj:`20`)
        num_val (int or float, optional): The number of validation nodes.
            A float is interpreted as a fraction of the number of nodes
            (rounded to the nearest integer). (default: :obj:`500`)
        num_test (int or float, optional): The number of test nodes. Only
            used by the :obj:`"train_rest"` and :obj:`"random"` splits.
            A float is interpreted as a fraction of the number of nodes
            (rounded to the nearest integer). (default: :obj:`1000`)
    """
    def __init__(
        self,
        split: str = "train_rest",
        num_splits: int = 1,
        num_train_per_class: int = 20,
        num_val: Union[int, float] = 500,
        num_test: Union[int, float] = 1000,
    ):
        assert split in ['train_rest', 'test_rest', 'random']
        self.split = split
        self.num_splits = num_splits
        self.num_train_per_class = num_train_per_class
        self.num_val = num_val
        self.num_test = num_test

    def __call__(self, data: Data) -> Data:
        """Draws :obj:`num_splits` splits and attaches the stacked masks to
        :obj:`data`, which is modified in place and returned."""
        # Draw every split first so that all sampling happens before any
        # attribute of `data` is written.
        samples = [self._sample_split(data) for _ in range(self.num_splits)]

        for slot, attr in enumerate(('train_mask', 'val_mask', 'test_mask')):
            per_split = [sample[slot] for sample in samples]
            # Stack as [num_nodes, num_splits]; squeeze(-1) only drops the
            # trailing dimension when there is exactly one split.
            setattr(data, attr, torch.stack(per_split, dim=-1).squeeze(-1))

        return data

    def _sample_split(self, data: Data) -> Tuple[Tensor, Tensor, Tensor]:
        """Samples a single split and returns the boolean
        :obj:`(train_mask, val_mask, test_mask)` triple, each of shape
        :obj:`[num_nodes]`."""
        total = data.num_nodes

        def as_count(amount: Union[int, float]) -> int:
            # Floats are ratios of the node count; ints are absolute sizes.
            if isinstance(amount, float):
                return round(total * amount)
            return amount

        val_count = as_count(self.num_val)
        test_count = as_count(self.num_test)

        train_mask, val_mask, test_mask = (
            torch.zeros(total, dtype=torch.bool) for _ in range(3))

        if self.split == 'train_rest':
            # A single permutation of all nodes is cut into three
            # consecutive pieces: validation, test, then training.
            order = torch.randperm(total)
            test_end = val_count + test_count
            val_mask[order[:val_count]] = True
            test_mask[order[val_count:test_end]] = True
            train_mask[order[test_end:]] = True
            return train_mask, val_mask, test_mask

        # Class-balanced training set: shuffle the members of each class
        # and keep at most `num_train_per_class` of them.
        labels = data.y
        num_classes = int(labels.max().item()) + 1
        for label in range(num_classes):
            members = (labels == label).nonzero(as_tuple=False).view(-1)
            shuffled = members[torch.randperm(members.size(0))]
            train_mask[shuffled[:self.num_train_per_class]] = True

        # Validation and test nodes come from whatever was not selected
        # for training, in random order.
        leftover = (~train_mask).nonzero(as_tuple=False).view(-1)
        leftover = leftover[torch.randperm(leftover.size(0))]

        val_mask[leftover[:val_count]] = True
        if self.split == 'test_rest':
            test_mask[leftover[val_count:]] = True
        elif self.split == 'random':
            test_mask[leftover[val_count:val_count + test_count]] = True

        return train_mask, val_mask, test_mask

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(split={self.split})'