Multi-GPU Training in Pure PyTorch

For many large-scale, real-world datasets, it may be necessary to scale up training across multiple GPUs. This tutorial goes over how to set up a multi-GPU training pipeline via torch.nn.parallel.DistributedDataParallel, without the need for any other third-party libraries. Note that this approach is based on data parallelism: each GPU runs an identical copy of the model. If you instead want to shard your model across devices, you might want to look into PyTorch FSDP. Data parallelism allows you to increase the effective batch size of your model. Each GPU processes its own mini-batch, the resulting gradients are averaged across GPUs, and every model replica then applies the same optimizer step. This DDP + MNIST tutorial by Princeton University has some nice illustrations of the process.
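
The claim that averaging gradients across GPUs reproduces large-batch training can be checked with a tiny, framework-free sketch. It assumes a mean-reduced loss and equally sized shards per GPU. The one-parameter linear regression, the helper grad_mse, and the toy data are purely illustrative and not part of the tutorial's script.

```python
import math

def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.2, 13.8, 16.1]
w, lr, world_size = 0.5, 0.01, 4

# Single device: gradient over the full batch of 8 samples.
full = grad_mse(w, xs, ys)

# Data parallelism: each of the 4 simulated GPUs holds an equally sized shard.
shard = len(xs) // world_size
local = [
    grad_mse(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
    for r in range(world_size)
]

# All-reduce (sum) divided by world_size, i.e. the gradient averaging DDP performs.
averaged = sum(local) / world_size

# Every replica applies the same optimizer step with the averaged gradient.
w_next = w - lr * averaged
print(math.isclose(full, averaged))  # True
```

If the shards had different sizes, the plain average of per-shard gradients would no longer equal the full-batch gradient, which is one reason to give every GPU the same amount of data.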

Specifically, this tutorial shows how to train a GraphSAGE GNN model on the Reddit dataset. To scale up training across all available GPUs, we use torch.nn.parallel.DistributedDataParallel. We do this by spawning multiple processes from our code, which all execute the same function. In each process, we set up our model instance and feed data through it using the NeighborLoader. Gradients are synchronized by wrapping the model in torch.nn.parallel.DistributedDataParallel (as described in its official tutorial), which in turn relies on the inter-process communication facilities of torch.distributed.


The complete script of this tutorial can be found at examples/multi_gpu/

Defining a Spawnable Runner

To create our training script, we use torch.multiprocessing, PyTorch's wrapper around the vanilla Python multiprocessing module. Here, world_size corresponds to the number of GPUs we will be using at once. torch.multiprocessing.spawn() takes care of spawning world_size processes. Each process loads the same script as a module and subsequently executes the run() function:

from torch_geometric.datasets import Reddit
import torch
import torch.multiprocessing as mp

def run(rank: int, world_size: int, dataset: Reddit):
    pass  # Filled in over the course of this tutorial.

if __name__ == '__main__':
    dataset = Reddit('./data/Reddit')
    world_size = torch.cuda.device_count()  # One process per visible GPU.
    mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True)

Note that we initialize the dataset before spawning any processes. This way, the dataset is loaded only once, and any data inside it is automatically moved to shared memory by torch.multiprocessing, so processes do not need to create their own replica of the data. In addition, note how the run() function accepts rank as its first argument. We do not pass this argument explicitly; it is the process index (starting at 0) injected by torch.multiprocessing.spawn(). Later, we will use it to select a unique GPU for every rank.

With this, we can start to implement our spawnable runner function. The first step is to initialize a process group with torch.distributed. Up to this point, the processes are not aware of each other, so we set a hardcoded server address for rendezvous and use the NCCL backend for communication. More details can be found in the “Writing Distributed Applications with PyTorch” tutorial:

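import os
import torch.distributed as dist
import torch

def run(rank: int, world_size: int, dataset: Reddit):
    # Address and port of the rank-0 process, which all ranks use to find each other.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    # Join the process group; NCCL handles GPU-to-GPU communication.
    dist.init_process_group('nccl', rank=rank, world_size=world_size)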
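
The hardcoded MASTER_PORT of 12345 fails if that port is already in use on the machine. One common alternative, not part of this tutorial's script, is to let the operating system pick a free port in the parent process before spawning. A minimal sketch, where find_free_port is a hypothetical helper:

```python
import socket

def find_free_port() -> int:
    """Ask the OS for a currently unused TCP port on localhost (hypothetical helper)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('localhost', 0))  # Port 0 lets the OS choose any free port.
        return s.getsockname()[1]

port = find_free_port()
print(port)
```

The port would then be handed to the workers, for example via mp.spawn(run, args=(world_size, dataset, port), ...), with run() extended by a hypothetical port parameter that sets os.environ['MASTER_PORT'] = str(port). Because the socket is closed before the workers bind to the port, another program could in rare cases grab it in between.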

Next, we split the training indices into world_size chunks, one per GPU, and initialize a NeighborLoader that only operates on the chunk of the training set belonging to the current rank:

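from torch_geometric.loader import NeighborLoader

def run(rank: int, world_size: int, dataset: Reddit):
    ...

    data = dataset[0]

    # Split the training node indices into equally sized chunks; keep this rank's chunk.
    train_index = data.train_mask.nonzero().view(-1)
    train_index = train_index.split(train_index.size(0) // world_size)[rank]

    train_loader = NeighborLoader(
        data,
        input_nodes=train_index,  # Seed nodes for this rank only.
        num_neighbors=[25, 10],   # Neighbors sampled in the 1st and 2nd hop.
        batch_size=1024,          # Illustrative value.
        shuffle=True,
    )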

Note that our run() function is called for each rank, which means that each rank holds a separate NeighborLoader instance.
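
The split in the code above gives every rank exactly train_index.size(0) // world_size nodes, because torch.split() produces chunks of that size and any leftover chunk is never indexed. The following pure-Python sketch mirrors that behavior on a list (the helper shard_indices and the 10-node example are hypothetical) to make two consequences visible: the trailing size % world_size training nodes are never used, and all ranks receive equally sized shards. The latter keeps the number of training iterations equal across ranks, which matters for DDP because uneven iteration counts can leave ranks waiting on each other's gradient all-reduce.

```python
def shard_indices(indices, world_size, rank):
    """Mimic `train_index.split(len // world_size)[rank]` on a plain list.

    Assumes len(indices) >= world_size, so the chunk size is at least 1.
    """
    chunk = len(indices) // world_size
    # torch.split(chunk) yields consecutive slices of length `chunk`; the last may be shorter.
    chunks = [indices[i:i + chunk] for i in range(0, len(indices), chunk)]
    return chunks[rank]

train_index = list(range(10))  # 10 hypothetical training nodes
world_size = 4
shards = [shard_indices(train_index, world_size, r) for r in range(world_size)]
print(shards)   # [[0, 1], [2, 3], [4, 5], [6, 7]]

dropped = sorted(set(train_index) - {i for s in shards for i in s})
print(dropped)  # [8, 9]
```

For a large training set and a handful of GPUs, the dropped remainder is at most world_size - 1 nodes.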

Similarly, we create a NeighborLoader instance for evaluation. For simplicity, we only do this on rank 0 such that computation of metrics does not need to communicate across different processes. We recommend taking a look at the torchmetrics package for distributed computation of metrics.

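def run(rank: int, world_size: int, dataset: Reddit):
    ...

    if rank == 0:
        val_index = data.val_mask.nonzero().view(-1)
        val_loader = NeighborLoader(
            data,
            input_nodes=val_index,
            num_neighbors=[25, 10],
            batch_size=1024,  # Illustrative value.
            shuffle=False,    # No need to shuffle during evaluation.
        )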
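
Evaluating on rank 0 only avoids combining metrics across processes. If validation were sharded across ranks as well, the per-rank results would have to be combined from raw counts; averaging per-rank accuracies is only correct when every rank sees the same number of nodes. The framework-free sketch below uses made-up per-rank counts; in PyTorch, the summation step would typically be a dist.all_reduce() on a tensor holding [correct, count].

```python
# Hypothetical per-rank validation results: (correct, count) for each rank.
per_rank = [(90, 100), (45, 50), (10, 50)]

# Misleading: averaging per-rank accuracies over-weights the smaller shards.
mean_of_accs = sum(c / n for c, n in per_rank) / len(per_rank)

# Correct: sum-reduce the raw counts across ranks, then divide once.
total_correct = sum(c for c, _ in per_rank)
total_count = sum(n for _, n in per_rank)
global_acc = total_correct / total_count

print(f'{mean_of_accs:.4f} vs {global_acc:.4f}')  # 0.6667 vs 0.7250
```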

Now that we have our data loaders defined, we initialize our GraphSAGE model and wrap it inside torch.nn.parallel.DistributedDataParallel. We also move the model to its exclusive GPU, using the rank as shorthand for the device identifier cuda:{rank}. The wrapper broadcasts the initial model parameters from rank 0 to all other ranks and averages gradients across all ranks during the backward pass, so every rank applies the same update and the model replicas stay in sync:

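from torch.nn.parallel import DistributedDataParallel
from torch_geometric.nn import GraphSAGE

def run(rank: int, world_size: int, dataset: Reddit):
    ...

    model = GraphSAGE(
        in_channels=dataset.num_features,
        hidden_channels=256,  # Illustrative value.
        num_layers=2,         # Matches the two hops sampled by NeighborLoader.
        out_channels=dataset.num_classes,
    ).to(rank)                # Move the model to GPU `rank`.
    model = DistributedDataParallel(model, device_ids=[rank])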
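
DistributedDataParallel keeps the replicas identical through two mechanisms: it broadcasts rank 0's parameters when the model is wrapped, and it averages gradients during backward(), after which each rank runs its own optimizer.step(). The pure-Python sketch below simulates three replicas of a one-parameter model; the shard data, the differing initial weights, and the helper local_grad are illustrative, and plain SGD stands in for Adam.

```python
def local_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) on one rank's shard."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

shards = [([1.0, 2.0], [2.0, 4.1]), ([3.0, 4.0], [5.9, 8.2]), ([5.0, 6.0], [9.8, 12.1])]
world_size = len(shards)

# Each process builds its own model; without coordination the weights differ.
weights = [0.1, -0.3, 0.7]

# Done by DDP when the model is wrapped: broadcast rank 0's weights to all ranks.
weights = [weights[0]] * world_size

lr = 0.01
for step in range(5):
    grads = [local_grad(w, xs, ys) for w, (xs, ys) in zip(weights, shards)]
    avg = sum(grads) / world_size              # all-reduce + divide, during backward()
    weights = [w - lr * avg for w in weights]  # optimizer.step() runs locally per rank

print(weights)  # three identical values
```

Since every rank starts from the same weights and applies the same averaged gradient, the replicas remain bitwise identical without ever exchanging the weights themselves after the initial broadcast.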

Finally, we can set up our optimizer and define our training loop, which follows the same flow as a usual single-GPU training loop. The synchronization of gradients, and therefore of model weights, across processes happens behind the scenes within DistributedDataParallel:

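import torch.nn.functional as F

def run(rank: int, world_size: int, dataset: Reddit):
    ...

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(1, 11):
        model.train()
        for batch in train_loader:
            batch = batch.to(rank)  # Move the sampled subgraph to this rank's GPU.
            optimizer.zero_grad()
            # The first batch.batch_size nodes are the seed nodes; the rest are sampled neighbors.
            out = model(batch.x, batch.edge_index)[:batch.batch_size]
            loss = F.cross_entropy(out, batch.y[:batch.batch_size])
            loss.backward()  # DDP averages gradients across ranks here.
            optimizer.step()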
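
Because each rank only iterates over its own shard, the number of optimizer steps per epoch shrinks with world_size while each step aggregates more seed nodes. A small worked sketch, using hypothetical numbers (10,000 training nodes and a per-rank batch size of 1,024) rather than the actual Reddit split, and assuming the loader keeps the last partial batch (drop_last=False, the DataLoader default):

```python
import math

def steps_per_epoch(num_train: int, world_size: int, batch_size: int) -> int:
    """Optimizer steps per epoch when each rank gets num_train // world_size seed nodes."""
    per_rank = num_train // world_size       # size of each rank's shard
    return math.ceil(per_rank / batch_size)  # last partial batch is kept

def effective_batch_size(world_size: int, batch_size: int) -> int:
    """Upper bound on the seed nodes contributing to one synchronized optimizer step."""
    return world_size * batch_size

print(steps_per_epoch(10_000, 1, 1024), effective_batch_size(1, 1024))  # 10 1024
print(steps_per_epoch(10_000, 4, 1024), effective_batch_size(4, 1024))  # 3 4096
```

With four GPUs, each epoch in this example performs 3 optimizer steps of up to 4,096 seed nodes each, compared with 10 steps of up to 1,024 seed nodes on a single GPU.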

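After each training epoch, we evaluate and report validation metrics. As previously mentioned, we do this on a single GPU only. To synchronize all processes and to ensure that the model weights have been updated on every rank, we call torch.distributed.barrier() before evaluation, and once more afterwards so that the other ranks wait for rank 0 to finish:

        # Still inside the epoch loop of run():
        dist.barrier()  # Wait until all ranks have finished this epoch's updates.

        if rank == 0:
            print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')

        if rank == 0:
            model.eval()
            count = correct = 0
            with torch.no_grad():
                for batch in val_loader:
                    batch = batch.to(rank)
                    out = model(batch.x, batch.edge_index)[:batch.batch_size]
                    pred = out.argmax(dim=-1)
                    correct += (pred == batch.y[:batch.batch_size]).sum()
                    count += batch.batch_size
            print(f'Validation Accuracy: {correct/count:.4f}')

        dist.barrier()  # Other ranks wait here until rank 0 has finished evaluating.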
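
The effect of calling a barrier before and after the rank-0 evaluation can be illustrated with threads standing in for processes and threading.Barrier standing in for torch.distributed.barrier(). This is only an analogy: the real call is a collective over the process group, not a shared-memory primitive. Each simulated rank trains, waits for all others, lets rank 0 evaluate, and waits again before starting the next epoch:

```python
import threading

world_size = 3
barrier = threading.Barrier(world_size)  # stands in for dist.barrier()
log = []
lock = threading.Lock()

def record(event):
    with lock:
        log.append(event)

def worker(rank):
    for epoch in range(1, 3):
        record(('train', epoch, rank))
        barrier.wait()  # all ranks have finished this epoch's updates
        if rank == 0:
            record(('eval', epoch, rank))  # only rank 0 evaluates
        barrier.wait()  # other ranks wait until evaluation is done

threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(log)
```

In the recorded log, epoch 1 evaluation always appears after every rank's epoch 1 training and before any rank's epoch 2 training.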


After finishing training, we clean up by destroying the process group at the end of run():

    dist.destroy_process_group()

And that’s it. Putting it all together gives a working multi-GPU example whose training flow closely follows single-GPU training. You can run the complete script yourself by looking at examples/multi_gpu/