Multi-GPU Training in Pure PyTorch

For many large-scale, real-world datasets, it may be necessary to scale up training across multiple GPUs. This tutorial goes over how to set up a multi-GPU training and inference pipeline in PyG with pure PyTorch via torch.nn.parallel.DistributedDataParallel, without the need for any other third-party libraries.

In particular, this tutorial shows how to train a GraphSAGE GNN model on the Reddit dataset. For this, we utilize the NeighborLoader together with torch.nn.parallel.DistributedDataParallel to scale up training across all available GPUs.

Note

A runnable example of this tutorial can be found at examples/multi_gpu/distributed_sampling.py.

Defining a Spawnable Runner

DistributedDataParallel (DDP) implements data parallelism at the module level which can run across multiple machines. Applications using DDP spawn multiple processes and create a single DDP instance per process. DDP processes can be placed on the same machine or across machines.

To create a DDP module, we first need to set up process groups properly and define a spawnable runner function. Here, the world_size corresponds to the number of GPUs we will be using at once. For each GPU, the process is labeled with a process ID which we call rank. torch.multiprocessing.spawn() will take care of spawning world_size many processes:

import torch
import torch.multiprocessing as mp

from torch_geometric.datasets import Reddit

def run(rank: int, world_size: int, dataset: Reddit):
    pass

if __name__ == '__main__':
    dataset = Reddit('./data/Reddit')
    # Use one process per available GPU:
    world_size = torch.cuda.device_count()
    mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True)

Note that we initialize the dataset before spawning any processes. With this, we only initialize the dataset once, and any data inside it will be automatically moved to shared memory via torch.multiprocessing such that processes do not need to create their own replica of the data.
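Under the hood, this relies on torch.multiprocessing placing CPU tensors in shared memory. The following snippet is purely illustrative (it is not part of the example) and makes explicit what otherwise happens automatically:

import torch

from torch_geometric.datasets import Reddit

dataset = Reddit('./data/Reddit')
data = dataset[0]

# Explicitly move the node features into shared memory. When the dataset is
# passed through torch.multiprocessing, this happens automatically, so spawned
# processes read from the same underlying storage instead of copying it:
data.x.share_memory_()
print(data.x.is_shared())  # True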

With this, we can start to implement our spawnable runner function. The first step is to initialize torch.distributed. More details can be found in Writing Distributed Applications with PyTorch:

import os

import torch
import torch.distributed as dist

def run(rank: int, world_size: int, dataset: Reddit):
    # Every process connects to the same master address/port and registers
    # itself under its process ID (`rank`) using the NCCL backend:
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

Next, we split training indices into world_size many chunks for each GPU, and initialize the NeighborLoader class to only operate on its specific chunk of the training set:

from torch_geometric.loader import NeighborLoader

def run(rank: int, world_size: int, dataset: Reddit):
    ...

    data = dataset[0]

    # Select the chunk of training node indices that belongs to this rank:
    train_index = data.train_mask.nonzero().view(-1)
    train_index = train_index.split(train_index.size(0) // world_size)[rank]

    train_loader = NeighborLoader(
        data,
        input_nodes=train_index,
        num_neighbors=[25, 10],
        batch_size=1024,
        num_workers=4,
        shuffle=True,
    )

Note that our run() function is called on each rank, which means that each rank holds a separate NeighborLoader instance.
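If you want to verify this, a quick sanity check (not part of the original example) is to report how many mini-batches each rank will iterate over:

def run(rank: int, world_size: int, dataset: Reddit):
    ...

    # Each rank only sees its own chunk of the training nodes:
    print(f'[Rank {rank}] number of training batches: {len(train_loader)}')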

Similarly, we create a NeighborLoader instance for evaluation. For simplicity, we only do this on rank 0 such that the computation of metrics does not need to communicate across different processes. We recommend taking a look at the torchmetrics package for distributed metric computation (a sketch of this alternative is shown after the evaluation loop below).

def run(rank: int, world_size: int, dataset: Reddit):
    ...

    if rank == 0:
        val_index = data.val_mask.nonzero().view(-1)
        val_loader = NeighborLoader(
            data,
            input_nodes=val_index,
            num_neighbors=[25, 10],
            batch_size=1024,
            num_workers=4,
            shuffle=False,
        )

Now that we have our data loaders defined, we initialize our GraphSAGE model and wrap it inside PyTorch's DistributedDataParallel. This wrapper manages communication between ranks and reduces the loss gradients from each process before updating the model's parameters across all ranks:

from torch.nn.parallel import DistributedDataParallel
from torch_geometric.nn import GraphSAGE

def run(rank: int, world_size: int, dataset: Reddit):
    ...

    torch.manual_seed(12345)  # Same seed on every rank for identical weight initialization.
    model = GraphSAGE(
        in_channels=dataset.num_features,
        hidden_channels=256,
        num_layers=2,
        out_channels=dataset.num_classes,
    ).to(rank)
    model = DistributedDataParallel(model, device_ids=[rank])

Finally, we can set up our optimizer and define our training loop, which follows a flow similar to the usual single-GPU training loop. The actual magic of gradient and model weight synchronization across processes happens behind the scenes within DistributedDataParallel:

import torch.nn.functional as F

def run(rank: int, world_size: int, dataset: Reddit):
    ...

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(1, 11):
        model.train()
        for batch in train_loader:
            batch = batch.to(rank)
            optimizer.zero_grad()
            # Only consider predictions and labels of the seed nodes of the mini-batch:
            out = model(batch.x, batch.edge_index)[:batch.batch_size]
            loss = F.cross_entropy(out, batch.y[:batch.batch_size])
            loss.backward()
            optimizer.step()

After each training epoch, we evaluate and report validation metrics. As previously mentioned, we do this on a single GPU only. To synchronize all processes and to ensure that model weights have been updated, we need to call torch.distributed.barrier():

dist.barrier()

if rank == 0:
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')

if rank == 0:  # Evaluation is done on a single GPU.
    model.eval()
    count = correct = 0
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(rank)
            out = model(batch.x, batch.edge_index)[:batch.batch_size]
            pred = out.argmax(dim=-1)
            correct += (pred == batch.y[:batch.batch_size]).sum()
            count += batch.batch_size
    print(f'Validation Accuracy: {correct/count:.4f}')

dist.barrier()
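
As mentioned earlier, evaluation could instead be distributed across all ranks via the torchmetrics package. A rough sketch (not part of the example), assuming every rank builds its own val_loader over its chunk of the validation nodes and that torchmetrics synchronizes metric states across the process group on compute(), could look as follows:

from torchmetrics import Accuracy

def run(rank: int, world_size: int, dataset: Reddit):
    ...

    # Hypothetical distributed evaluation: every rank evaluates its own chunk
    # of the validation set using its own `val_loader`:
    acc = Accuracy(task='multiclass', num_classes=dataset.num_classes).to(rank)

    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(rank)
            out = model(batch.x, batch.edge_index)[:batch.batch_size]
            acc.update(out.argmax(dim=-1), batch.y[:batch.batch_size])

    val_acc = acc.compute()  # Synchronizes metric state across all ranks.
    if rank == 0:
        print(f'Validation Accuracy: {val_acc:.4f}')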

After finishing training, we can clean up processes and destroy the process group via:

dist.destroy_process_group()

And that's it. Putting it all together gives a working multi-GPU example that follows a training flow similar to single-GPU training. You can find the full implementation at examples/multi_gpu/distributed_sampling.py.