Compiled Graph Neural Networks

torch.compile() is the latest method to speed up your PyTorch code in torch >= 2.0.0! torch.compile() makes PyTorch code run faster by JIT-compiling it into optimized kernels, all while requiring minimal code changes.

Under the hood, torch.compile() captures programs via TorchDynamo, canonicalizes over 2,000 operators via PrimTorch, and finally generates fast code out of it across multiple accelerators and backends via the deep learning compiler TorchInductor.
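As a rough illustration of the first stage, the sketch below registers a custom backend that only records the FX graph captured by TorchDynamo and then runs it unchanged. The names inspect_backend, captured_graphs and fn are made up for this example, and no TorchInductor code generation takes place. It assumes PyTorch >= 2.0 on a platform where torch.compile() is supported.

```python
import torch
import torch.fx

captured_graphs = []  # hypothetical container for the graphs seen by the backend


def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    # TorchDynamo hands each captured graph to the backend as an FX GraphModule.
    captured_graphs.append(gm)
    # Return a callable; here the captured graph runs as-is, without optimization.
    return gm.forward


def fn(x, y):
    return torch.relu(x @ y) + 1.0


compiled_fn = torch.compile(fn, backend=inspect_backend)

x, y = torch.randn(4, 8), torch.randn(8, 3)
out = compiled_fn(x, y)  # the first call triggers graph capture

print(captured_graphs[0].graph)  # lists the captured operations
```

With the default backend, the same captured graph would instead be handed to TorchInductor for code generation.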

Note

See the official PyTorch tutorial on torch.compile() for a general introduction on how to leverage it, and the torch.compile() API documentation for a description of its interface.

In this tutorial, we show how to optimize your custom PyG model via torch.compile().

torch_geometric.compile()

By default, torch.compile() struggles to optimize a custom PyG model since its underlying MessagePassing interface is JIT-unfriendly due to its generality. As such, PyG introduces torch_geometric.compile(), a wrapper around torch.compile() with the same signature.

torch_geometric.compile() applies further optimizations to make models more compiler-friendly. Specifically, it:

  1. Temporarily disables the usage of the extension packages torch_scatter, torch_sparse and pyg_lib during GNN execution workflows (since these are not yet directly optimizable by PyTorch). These packages are now purely optional and no longer required for running PyG models (but pyg_lib may be required for graph sampling routines).

  2. Converts all instances of MessagePassing modules into their jittable instances (see torch_geometric.nn.conv.MessagePassing.jittable()).

Without these adjustments, torch.compile() may currently fail to correctly optimize your model. We are working on fully relying on torch.compile() for future releases.

Basic Usage

Leveraging torch_geometric.compile() is as simple as the usage of torch.compile(). Once you have a model defined, simply wrap it with torch_geometric.compile() to obtain its optimized version:

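import torch
import torch_geometric
from torch_geometric.nn import GraphSAGE

# in_channels, hidden_channels, num_layers and out_channels are placeholders
# for your own model configuration.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = GraphSAGE(in_channels, hidden_channels, num_layers, out_channels)
model = model.to(device)

model = torch_geometric.compile(model)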

and execute it as usual:

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root, name="Cora")
data = dataset[0].to(device)

out = model(data.x, data.edge_index)

Maximizing Performance

The torch.compile()/torch_geometric.compile() method provides two important arguments to be aware of:

  • Most of the mini-batches observed in PyG are dynamic by nature, meaning that their shapes vary across mini-batches. For these scenarios, we can enforce dynamic shape tracing in PyTorch via the dynamic=True argument:

    torch_geometric.compile(model, dynamic=True)
    

    With this, PyTorch will attempt up front to generate a kernel that is as dynamic as possible, avoiding recompilations when sizes change across mini-batches. Note that when dynamic is set to False, PyTorch will never generate dynamic kernels, leading to significant slowdowns in model execution on dynamic mini-batches. As such, you should only omit dynamic=True when graph sizes are guaranteed to never change. Note that dynamic=True requires PyTorch >= 2.1.0 to be installed.

  • In order to maximize speedup, graph breaks in the compiled model should be limited. We can force compilation to raise an error upon the first graph break encountered by using the fullgraph=True argument:

    torch_geometric.compile(model, fullgraph=True)
    

    It is generally good practice to confirm that your model does not contain any graph breaks. Importantly, there exist a few operations in PyG that currently lead to graph breaks (although workarounds exist), e.g.:

    1. global_mean_pool() (and other pooling operators) perform device synchronization in case the batch size is not passed, leading to a graph break.

    2. remove_self_loops() and add_remaining_self_loops() mask the given edge_index, leading to a device synchronization to compute its final output shape. As such, we recommend augmenting your graph before inputting it into your GNN, e.g., via the AddSelfLoops or GCNNorm transformations, and setting add_self_loops=False/normalize=False when initializing layers such as GCNConv.
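To make the first workaround more concrete, here is a simplified stand-in for a mean-pooling operator written in plain PyTorch. It is not PyG's implementation: mean_pool is a hypothetical helper that only mimics the pattern of deriving the number of graphs from batch when size is missing. The backend="eager" setting is an assumption made so that the sketch only exercises graph capture and does not need a compiler toolchain.

```python
import torch


def mean_pool(x, batch, size=None):
    # Hypothetical stand-in, not PyG's implementation of global_mean_pool().
    if size is None:
        # Reading a Python int out of a tensor forces a device synchronization,
        # which the compiler does not trace through by default.
        size = int(batch.max()) + 1
    total = x.new_zeros(size, x.size(1)).index_add_(0, batch, x)
    count = x.new_zeros(size).index_add_(0, batch, x.new_ones(batch.numel()))
    return total / count.clamp(min=1).unsqueeze(1)


# fullgraph=True turns any graph break into an error.
compiled_pool = torch.compile(mean_pool, fullgraph=True, backend="eager")

x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch = torch.tensor([0, 0, 1])  # nodes 0 and 1 form graph 0, node 2 forms graph 1

# Passing the number of graphs explicitly keeps the whole function traceable.
out = compiled_pool(x, batch, size=2)
```

Calling compiled_pool(x, batch) without size would be expected to fail under fullgraph=True, since the int(batch.max()) call cannot be captured in a single graph by default.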

Example Scripts

We have incorporated multiple examples in examples/compile that further show the practical usage of torch_geometric.compile():

  1. Node Classification via GCN (dynamic=False)

  2. Graph Classification via GIN (dynamic=True)

If you notice that compile() fails for a certain model, do not hesitate to reach out either on GitHub or Slack. We are very eager to improve compile() support across the whole code base.

Benchmark

torch.compile() works fantastically well for many PyG models. Overall, we observe speedups of up to nearly 3x (2.83x for GCN).

Specifically, we benchmark GCN, GraphSAGE and GIN and compare runtimes obtained from traditional eager mode and torch_geometric.compile(). We use a synthetic graph with 10,000 nodes and 200,000 edges, and a hidden feature dimensionality of 64. We report runtimes over 500 optimization steps:

Model      Mode      Forward  Backward  Total    Speedup
---------  --------  -------  --------  -------  -------
GCN        Eager     2.6396s  2.1697s   4.8093s
GCN        Compiled  1.1082s  0.5896s   1.6978s  2.83x
GraphSAGE  Eager     1.6023s  1.6428s   3.2451s
GraphSAGE  Compiled  0.7033s  0.7465s   1.4498s  2.24x
GIN        Eager     1.6701s  1.6990s   3.3690s
GIN        Compiled  0.7320s  0.7407s   1.4727s  2.29x

To reproduce these results, run

python test/nn/models/test_basic_gnn.py

from the root folder of your checked-out PyG repository.
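The Speedup column appears to be the ratio of the eager total runtime to the compiled total runtime. The short check below recomputes it from the totals reported in the benchmark table; the dictionary layout and variable names are chosen for illustration only.

```python
# Total runtimes in seconds, copied from the benchmark table: (eager, compiled).
totals = {
    "GCN": (4.8093, 1.6978),
    "GraphSAGE": (3.2451, 1.4498),
    "GIN": (3.3690, 1.4727),
}

# Speedup = eager total / compiled total.
speedups = {name: eager / compiled for name, (eager, compiled) in totals.items()}

for name, speedup in speedups.items():
    print(f"{name}: {speedup:.2f}x")
# GCN: 2.83x
# GraphSAGE: 2.24x
# GIN: 2.29x
```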