Dataset Splitting
Dataset splitting is a critical step in graph machine learning, where we divide a dataset into subsets for training, validation, and testing.
It ensures that models are evaluated on data they were not trained on, which lets us detect overfitting and estimate how well a model generalizes.
In this tutorial, we will explore the basics of dataset splitting, focusing on three fundamental tasks: node prediction, link prediction, and graph prediction.
We will introduce commonly used techniques, including the RandomNodeSplit and RandomLinkSplit transformations.
We will also cover how to create custom dataset splits beyond random ones.
Node Prediction
Note
In this section, we’ll learn how to use the RandomNodeSplit transform of PyG to randomly divide nodes into training, validation, and test sets.
A fully working example on the Planetoid dataset is available in examples/cora.py.
RandomNodeSplit can split the nodes of both a PyG Data object and a HeteroData object.
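Its main arguments are:
- split: the type of split (default "train_rest"). With "train_rest", num_val and num_test nodes go to validation and testing and all remaining nodes to training; with "test_rest", num_train_per_class nodes per class go to training, num_val nodes to validation and all remaining nodes to testing; with "random", num_train_per_class nodes per class, num_val nodes and num_test nodes are sampled, and any remaining nodes are left out.
- num_splits: the number of splits to add (default 1); with more than one split, each mask has shape [num_nodes, num_splits].
- num_train_per_class: the number of training nodes per class (used by "test_rest" and "random").
- num_val: the number of validation nodes; a float is interpreted as a fraction of all nodes.
- num_test: the number of test nodes; a float is interpreted as a fraction of all nodes.
- key: the name of the attribute holding the ground-truth labels (default "y").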
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import RandomNodeSplit
x = torch.randn(8, 32) # Node features of shape [num_nodes, num_features]
y = torch.randint(0, 4, (8, )) # Node labels of shape [num_nodes]
edge_index = torch.tensor([
    [2, 3, 3, 4, 5, 6, 7],  # source nodes
    [0, 0, 1, 1, 2, 3, 4],  # target nodes
])
#    0  1
#   / \/ \
#  2  3  4
#  |  |  |
#  5  6  7
data = Data(x=x, y=y, edge_index=edge_index)
node_transform = RandomNodeSplit(num_val=2, num_test=3)
node_splits = node_transform(data)
Here, we initialize a RandomNodeSplit transform that splits the graph data by nodes.
After the transform is applied, the boolean masks train_mask, val_mask and test_mask, each of shape [num_nodes], are attached to the graph data.
>>> node_splits.train_mask
tensor([ True, False, False, False, True, True, False, False])
>>> node_splits.val_mask
tensor([False, False, False, False, False, False, True, True])
>>> node_splits.test_mask
tensor([False, True, True, True, False, False, False, False])
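In this example, there are 8 nodes: we sample 2 nodes for validation, 3 nodes for testing, and use the remaining 3 nodes for training.
In the run shown above, nodes 0, 4 and 5 form the training set, nodes 6 and 7 the validation set, and nodes 1, 2 and 3 the test set.
Because the split is random, another run will generally produce different masks.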
Link Prediction
Note
In this section, we’ll learn how to use the RandomLinkSplit transform of PyG to randomly divide edges into training, validation, and test sets.
A fully working example on the Planetoid dataset is available in examples/link_pred.py.
RandomLinkSplit can split the edges of both a PyG Data object and a HeteroData object.
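Its main arguments, including those used in the example below, are:
- num_val: the number of validation edges; a float is interpreted as a fraction of all edges.
- num_test: the number of test edges; a float is interpreted as a fraction of all edges.
- is_undirected: whether the graph is assumed to be undirected; if so, an edge and its reverse always end up in the same split, so edge connectivity does not leak across splits.
- key: the name of the attribute holding the edge labels (default "edge_label"); the corresponding edge indices are stored under the same name with an "_index" suffix.
- add_negative_train_samples: whether to add negative (non-existent) edges to the training supervision edges (default True); validation and test supervision edges receive negative samples by default.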
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit
x = torch.randn(8, 32) # Node features of shape [num_nodes, num_features]
y = torch.randint(0, 4, (8, )) # Node labels of shape [num_nodes]
edge_index = torch.tensor([
    [2, 3, 3, 4, 5, 6, 7],  # source nodes
    [0, 0, 1, 1, 2, 3, 4],  # target nodes
])
edge_y = torch.tensor([0, 0, 0, 0, 1, 1, 1])  # one label per edge
#    0  1
#   / \/ \
#  2  3  4
#  |  |  |
#  5  6  7
data = Data(x=x, y=y, edge_index=edge_index, edge_y=edge_y)
edge_transform = RandomLinkSplit(num_val=0.2, num_test=0.2, key='edge_y',
                                 is_undirected=False, add_negative_train_samples=False)
train_data, val_data, test_data = edge_transform(data)
Similar to node splitting, we initialize a RandomLinkSplit transform to split the graph data by edges; it returns one Data object each for training, validation, and testing.
Below, we can see the splitting results.
>>> train_data
Data(x=[8, 32], edge_index=[2, 5], y=[8], edge_y=[5], edge_y_index=[2, 5])
>>> val_data
Data(x=[8, 32], edge_index=[2, 5], y=[8], edge_y=[2], edge_y_index=[2, 2])
>>> test_data
Data(x=[8, 32], edge_index=[2, 6], y=[8], edge_y=[2], edge_y_index=[2, 2])
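In each split, edge_index holds the edges used for message passing: train_data.edge_index and val_data.edge_index contain only the 5 training edges, while test_data.edge_index contains the 5 training edges plus the 1 validation edge.
As such, during training and validation we propagate information over the training edges only, while during testing we can propagate information over the union of training and validation edges.
The edges to supervise or evaluate on are stored separately. Because the example sets key='edge_y', they are kept in edge_y_index, with their labels in edge_y; under the default key, the same attributes are named edge_label_index and edge_label.
val_data.edge_y_index and test_data.edge_y_index each hold a batch of positive and negative samples (here, 1 held-out edge and 1 sampled negative edge) that should be used to evaluate and test our model, while train_data.edge_y_index holds only the 5 positive training edges, since add_negative_train_samples=False.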
Graph Prediction
Note
In this section, we’ll learn how to randomly divide graphs into training, validation, and test sets.
A fully working example on the PPI dataset is available in examples/ppi.py.
In graph prediction tasks, each graph is an independent sample, so a graph dataset is usually split according to a given ratio of graphs.
Some PyG datasets, such as PPI, already come with predefined training, validation and test splits.
from torch_geometric.datasets import PPI
path = './data/PPI'
train_dataset = PPI(path, split='train')
val_dataset = PPI(path, split='val')
test_dataset = PPI(path, split='test')
For datasets without predefined splits, we can randomly divide a PyG dataset ourselves, for example with scikit-learn or numpy.
Creating Custom Splits
If random splitting doesn’t suit a specific use case, we can create custom splits instead. This need often arises in real-world applications. For example, e-commerce platforms produce large-scale heterogeneous graphs whose nodes represent users, products, merchants, and so on, and we may want to separate new users from existing ones to evaluate how well the model performs on new users. Because such splits depend on the application, we do not provide a specific example here.