torch_geometric.transforms.RandomNodeSplit

class RandomNodeSplit(split: str = 'train_rest', num_splits: int = 1, num_train_per_class: int = 20, num_val: Union[int, float] = 500, num_test: Union[int, float] = 1000, key: Optional[str] = 'y')

Bases: BaseTransform

Performs a node-level random split by adding train_mask, val_mask and test_mask attributes to the Data or HeteroData object (functional name: random_node_split).

Parameters:
  • split (str, optional) – The type of dataset split ("train_rest", "test_rest", "random"). If set to "train_rest", all nodes except those in the validation and test sets will be used for training (as in the “FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling” paper). If set to "test_rest", all nodes except those in the training and validation sets will be used for testing (as in the “Pitfalls of Graph Neural Network Evaluation” paper). If set to "random", the training, validation, and test sets will be randomly generated according to num_train_per_class, num_val and num_test (as in the “Semi-supervised Classification with Graph Convolutional Networks” paper). (default: "train_rest")

  • num_splits (int, optional) – The number of splits to add. If greater than 1, the masks will have shape [num_nodes, num_splits]; otherwise, they will have shape [num_nodes]. (default: 1)

  • num_train_per_class (int, optional) – The number of training samples per class. Only used by the "test_rest" and "random" splits. (default: 20)

  • num_val (int or float, optional) – The number of validation samples. If float, it represents the ratio of samples to include in the validation set. (default: 500)

  • num_test (int or float, optional) – The number of test samples. Only used by the "train_rest" and "random" splits. If float, it represents the ratio of samples to include in the test set. (default: 1000)

  • key (str, optional) – The name of the attribute holding ground-truth labels. By default, node-level splits are only added to node-level storages in which key is present. (default: "y")