torch_geometric.transforms.RandomNodeSplit
class RandomNodeSplit(split: str = 'train_rest', num_splits: int = 1, num_train_per_class: int = 20, num_val: Union[int, float] = 500, num_test: Union[int, float] = 1000, key: Optional[str] = 'y')
Bases: BaseTransform
Performs a node-level random split by adding train_mask, val_mask and test_mask attributes to the Data or HeteroData object (functional name: random_node_split).

Parameters:
split (str, optional) – The type of dataset split ("train_rest", "test_rest", "random"). If set to "train_rest", all nodes except those in the validation and test sets will be used for training (as in the “FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling” paper). If set to "test_rest", all nodes except those in the training and validation sets will be used for testing (as in the “Pitfalls of Graph Neural Network Evaluation” paper). If set to "random", the training, validation, and test sets will be randomly generated according to num_train_per_class, num_val and num_test (as in the “Semi-supervised Classification with Graph Convolutional Networks” paper). (default: "train_rest")

num_splits (int, optional) – The number of splits to add. If greater than 1, the masks will have shape [num_nodes, num_splits]; otherwise, they will have shape [num_nodes]. (default: 1)

num_train_per_class (int, optional) – The number of training samples per class for the "test_rest" and "random" splits. (default: 20)

num_val (int or float, optional) – The number of validation samples. If a float, it represents the ratio of samples to include in the validation set. (default: 500)

num_test (int or float, optional) – The number of test samples for the "train_rest" and "random" splits. If a float, it represents the ratio of samples to include in the test set. (default: 1000)

key (str, optional) – The name of the attribute holding ground-truth labels. By default, node-level splits are only added to node-level storages in which key is present. (default: "y")
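The "train_rest" strategy and the int-or-float handling of num_val and num_test can be illustrated with a small plain-Python sketch. This is a hypothetical re-implementation for explanation only, not the library's code: the helper names resolve_count and train_rest_split are invented here, masks are plain lists of booleans instead of tensors, and it assumes a float ratio is rounded to the nearest whole node count.

```python
import random


def resolve_count(value, num_nodes):
    """Turn an int count or a float ratio into an absolute node count.

    Assumption: a float is a fraction of num_nodes, rounded to the nearest int.
    """
    if isinstance(value, float):
        return round(num_nodes * value)
    return value


def train_rest_split(num_nodes, num_val, num_test, seed=None):
    """Hypothetical sketch of the "train_rest" strategy."""
    rng = random.Random(seed)
    n_val = resolve_count(num_val, num_nodes)
    n_test = resolve_count(num_test, num_nodes)

    perm = list(range(num_nodes))
    rng.shuffle(perm)  # random permutation of node indices

    train_mask = [False] * num_nodes
    val_mask = [False] * num_nodes
    test_mask = [False] * num_nodes
    for i in perm[:n_val]:
        val_mask[i] = True
    for i in perm[n_val:n_val + n_test]:
        test_mask[i] = True
    for i in perm[n_val + n_test:]:  # every remaining node is used for training
        train_mask[i] = True
    return train_mask, val_mask, test_mask


train, val, test = train_rest_split(100, num_val=0.1, num_test=20, seed=0)
print(sum(train), sum(val), sum(test))  # 70 10 20
```

Because training takes whatever is left, the three masks in this sketch are disjoint and together cover every node.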
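The per-class strategies ("test_rest" and "random") differ from "train_rest" in that training nodes are drawn per class using the labels stored under key. A minimal plain-Python sketch of that idea follows. It is hypothetical and not PyG's implementation: the function name per_class_split is invented, labels are a plain list of ints, and, for brevity, num_val and num_test must be integers here.

```python
import random
from collections import defaultdict


def per_class_split(labels, split, num_train_per_class=20, num_val=500,
                    num_test=1000, seed=None):
    """Hypothetical sketch of the "test_rest" and "random" strategies.

    labels holds one integer class label per node (the attribute named by
    `key`). For brevity, num_val and num_test must be ints here.
    """
    if split not in ("test_rest", "random"):
        raise ValueError(f"unsupported split: {split!r}")
    rng = random.Random(seed)
    num_nodes = len(labels)

    # Group node indices by class and draw up to num_train_per_class from each.
    by_class = defaultdict(list)
    for node, label in enumerate(labels):
        by_class[label].append(node)
    train = set()
    for nodes in by_class.values():
        rng.shuffle(nodes)
        train.update(nodes[:num_train_per_class])

    # Shuffle the remaining nodes; the first num_val become validation nodes.
    rest = [n for n in range(num_nodes) if n not in train]
    rng.shuffle(rest)
    val = set(rest[:num_val])
    if split == "test_rest":
        test = set(rest[num_val:])  # every leftover node is a test node
    else:
        # "random": only num_test leftover nodes are used; the rest stay unassigned.
        test = set(rest[num_val:num_val + num_test])

    def to_mask(chosen):
        return [n in chosen for n in range(num_nodes)]

    return to_mask(train), to_mask(val), to_mask(test)


labels = [0] * 50 + [1] * 50  # two classes, 50 nodes each
train, val, test = per_class_split(labels, "random", num_train_per_class=5,
                                   num_val=10, num_test=30, seed=1)
print(sum(train), sum(val), sum(test))  # 10 10 30
```

In this sketch num_test is ignored under "test_rest", which matches the parameter description: that split sends every node not used for training or validation to the test set.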
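The effect of num_splits on mask shape ([num_nodes, num_splits] versus [num_nodes]) can be pictured with nested lists. The sketch below is a hypothetical illustration, not library code: the helpers random_val_mask and stack_splits are invented, only a validation mask is drawn, and it assumes each split is sampled independently before the results are stacked column by column.

```python
import random


def random_val_mask(num_nodes, num_val, rng):
    """Hypothetical single split: mark num_val random nodes as validation."""
    chosen = set(rng.sample(range(num_nodes), num_val))
    return [n in chosen for n in range(num_nodes)]


def stack_splits(num_nodes, num_val, num_splits=1, seed=None):
    """Sketch of how num_splits changes the mask layout.

    Assumption: each split is drawn independently. With num_splits > 1 the
    result is indexed as mask[node][split], i.e. shape [num_nodes, num_splits];
    otherwise it is a flat list of shape [num_nodes].
    """
    rng = random.Random(seed)
    splits = [random_val_mask(num_nodes, num_val, rng) for _ in range(num_splits)]
    if num_splits == 1:
        return splits[0]
    # Transpose: one row per node, one column per split.
    return [list(row) for row in zip(*splits)]


mask = stack_splits(num_nodes=6, num_val=2, num_splits=3, seed=0)
print(len(mask), len(mask[0]))  # 6 3
```

Column j of the stacked result is one complete split, so a training loop can iterate over the columns to evaluate a model on several random splits.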
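For HeteroData, the key parameter determines which node types receive masks. The sketch below models node-level storages as a plain dictionary that maps each node type to its attributes. It is a hypothetical stand-in for illustration: storages_to_split and the example node types are invented, and only the default behavior described above is modeled, where a storage is split when it contains key.

```python
def storages_to_split(node_stores, key="y"):
    """Return the node types that would receive masks in this sketch.

    node_stores: hypothetical mapping from node type to a dict of attributes,
    standing in for the node-level storages of a HeteroData object.
    """
    return [name for name, attrs in node_stores.items() if key in attrs]


stores = {
    "paper": {"x": [[0.1, 0.2]], "y": [0]},  # labeled node type
    "author": {"x": [[0.3, 0.4]]},           # no labels, so no masks
}
print(storages_to_split(stores))  # ['paper']
```

Under this model, unlabeled node types such as "author" are left without train_mask, val_mask or test_mask.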