# Source code for torch_geometric.transforms.add_train_val_test_mask

from typing import Union

import torch

class AddTrainValTestMask:
    r"""Adds a node-level random split via :obj:`train_mask`,
    :obj:`val_mask` and :obj:`test_mask` attributes to the :obj:`data`
    object.

    Args:
        split (string): The type of dataset split (:obj:`"train_rest"`,
            :obj:`"test_rest"`, :obj:`"random"`).
            If set to :obj:`"train_rest"`, all nodes except those in the
            validation and test sets will be used for training (as in the
            `"FastGCN: Fast Learning with Graph Convolutional Networks via
            Importance Sampling" <https://arxiv.org/abs/1801.10247>`_ paper).
            If set to :obj:`"test_rest"`, all nodes except those in the
            training and validation sets will be used for test (as in the
            `"Pitfalls of Graph Neural Network Evaluation"
            <https://arxiv.org/abs/1811.05868>`_ paper).
            If set to :obj:`"random"`, train, validation, and test sets will
            be randomly generated, according to :obj:`num_train_per_class`,
            :obj:`num_val` and :obj:`num_test` (as in the `"Semi-supervised
            Classification with Graph Convolutional Networks"
            <https://arxiv.org/abs/1609.02907>`_ paper).
        num_splits (int, optional): The number of splits to add. If bigger
            than :obj:`1`, the shape of masks will be
            :obj:`[num_nodes, num_splits]`, and :obj:`[num_nodes]` otherwise.
            (default: :obj:`1`)
        num_train_per_class (int, optional): The number of training samples
            per class in case of :obj:`"test_rest"` and :obj:`"random"`
            split. (default: :obj:`20`)
        num_val (int or float, optional): The number of validation samples.
            If float, it represents the ratio of samples to include in the
            validation set. (default: :obj:`500`)
        num_test (int or float, optional): The number of test samples in case
            of :obj:`"train_rest"` and :obj:`"random"` split. If float, it
            represents the ratio of samples to include in the test set.
            (default: :obj:`1000`)

    Raises:
        ValueError: If :obj:`split` is not one of the supported split types,
            or if :obj:`num_splits` is smaller than :obj:`1`.
    """

    _SPLITS = ('train_rest', 'test_rest', 'random')

    def __init__(
        self,
        split: str,
        num_splits: int = 1,
        num_train_per_class: int = 20,
        num_val: Union[int, float] = 500,
        num_test: Union[int, float] = 1000,
    ):
        # Explicit raises instead of `assert`: assertions are stripped under
        # `python -O`, which would let an unknown split silently produce an
        # all-False test mask.
        if split not in self._SPLITS:
            raise ValueError(
                f"'split' must be one of {list(self._SPLITS)} "
                f"(got {split!r})")
        if num_splits < 1:
            raise ValueError(
                f"'num_splits' must be at least 1 (got {num_splits})")

        self.split = split
        self.num_splits = num_splits
        self.num_train_per_class = num_train_per_class
        self.num_val = num_val
        self.num_test = num_test

    def __call__(self, data):
        """Sample :obj:`num_splits` independent splits and attach the masks.

        Sets :obj:`data.train_mask`, :obj:`data.val_mask` and
        :obj:`data.test_mask` in place and returns the same :obj:`data`
        object.
        """
        train_masks, val_masks, test_masks = [], [], []
        for _ in range(self.num_splits):
            train_mask, val_mask, test_mask = self.__sample_split__(data)
            train_masks.append(train_mask)
            val_masks.append(val_mask)
            test_masks.append(test_mask)

        # Stacking yields [num_nodes, num_splits]; the squeeze collapses the
        # trailing dimension back to [num_nodes] when there is one split.
        data.train_mask = torch.stack(train_masks, dim=-1).squeeze(-1)
        data.val_mask = torch.stack(val_masks, dim=-1).squeeze(-1)
        data.test_mask = torch.stack(test_masks, dim=-1).squeeze(-1)
        return data

    @staticmethod
    def _resolve_count(count: Union[int, float], num_nodes: int) -> int:
        """Turn an absolute count or a float ratio into a node count."""
        if isinstance(count, float):
            return round(num_nodes * count)
        return count

    def __sample_split__(self, data):
        """Draw one random split.

        Reads :obj:`data.num_nodes` and, for the :obj:`"test_rest"` and
        :obj:`"random"` splits, :obj:`data.y`.

        Returns:
            A tuple :obj:`(train_mask, val_mask, test_mask)` of boolean
            tensors, each of shape :obj:`[num_nodes]`.
        """
        num_nodes = data.num_nodes
        train_mask = torch.zeros(num_nodes, dtype=torch.bool)
        val_mask = torch.zeros(num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(num_nodes, dtype=torch.bool)

        num_val = self._resolve_count(self.num_val, num_nodes)
        num_test = self._resolve_count(self.num_test, num_nodes)

        if self.split == 'train_rest':
            # One permutation of all nodes: validation first, then test,
            # everything left over is training.
            perm = torch.randperm(num_nodes)
            val_mask[perm[:num_val]] = True
            test_mask[perm[num_val:num_val + num_test]] = True
            train_mask[perm[num_val + num_test:]] = True
            return train_mask, val_mask, test_mask

        # 'test_rest' / 'random': pick up to `num_train_per_class` random
        # training nodes for every label value in 0..max(y).
        num_classes = int(data.y.max().item()) + 1
        for c in range(num_classes):
            idx = (data.y == c).nonzero(as_tuple=False).view(-1)
            idx = idx[torch.randperm(idx.size(0))]
            idx = idx[:self.num_train_per_class]
            train_mask[idx] = True

        # Validation and test nodes are drawn from the shuffled non-training
        # nodes only, so the three masks never overlap.
        remaining = (~train_mask).nonzero(as_tuple=False).view(-1)
        remaining = remaining[torch.randperm(remaining.size(0))]

        val_mask[remaining[:num_val]] = True
        if self.split == 'test_rest':
            test_mask[remaining[num_val:]] = True
        elif self.split == 'random':
            test_mask[remaining[num_val:num_val + num_test]] = True

        return train_mask, val_mask, test_mask

    def __repr__(self):
        return f'{self.__class__.__name__}(split={self.split})'