Compiled Graph Neural Networks

torch.compile() is the latest method to speed up your code in torch >= 2.0.0! torch.compile() makes PyTorch code run faster by JIT-compiling it into optimized kernels, all while requiring minimal code changes.

Under the hood, torch.compile() captures programs via TorchDynamo, canonicalizes over 2,000 operators via PrimTorch, and finally generates fast code out of it across multiple accelerators and backends via the deep learning compiler TorchInductor.
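As a rough illustration of the capture stage, the sketch below compiles a small function with backend="eager". This debugging backend only runs graph capture through TorchDynamo and then executes the captured graph with regular PyTorch operators, so no TorchInductor code generation is involved and no speedup should be expected. The function scaled_relu is a hypothetical toy example and not part of any library.

```python
import torch


def scaled_relu(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical toy function, used only for illustration.
    return torch.relu(x) * 2.0 + 1.0


# backend="eager" captures the program with TorchDynamo but skips
# TorchInductor code generation; it is meant for debugging the capture step.
compiled_fn = torch.compile(scaled_relu, backend="eager")

x = torch.randn(8, 4)
expected = scaled_relu(x)   # regular eager execution
result = compiled_fn(x)     # execution of the captured graph

# Both paths should agree numerically.
print(torch.allclose(result, expected))
```

Dropping the backend argument selects the default TorchInductor backend, which is the one that generates optimized kernels.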

Note

See the official PyTorch documentation for a general tutorial on how to leverage torch.compile() and for a description of its interface.

In this tutorial, we show how to optimize your custom model via torch.compile().

Introducing torch_geometric.compile()

By default, torch.compile() struggles to optimize a custom PyG model since its underlying MessagePassing interface is JIT-unfriendly due to its generality. As such, we introduce torch_geometric.compile(), a wrapper around torch.compile() with the same signature.

torch_geometric.compile() applies further optimizations to make models more compiler-friendly. Specifically, it:

  1. Temporarily disables the usage of the extension packages torch_scatter, torch_sparse and pyg_lib during GNN execution workflows (since these are not yet directly optimizable by PyTorch). These packages are purely optional and no longer required for running PyG models (although pyg_lib may be required for graph sampling routines).

  2. Converts all instances of MessagePassing modules into their jittable instances (see torch_geometric.nn.conv.MessagePassing.jittable()).

Without these adjustments, torch.compile() may currently fail to correctly optimize your model. We are working on fully relying on torch.compile() for future releases.

Basic Usage

Using torch_geometric.compile() is as simple as using torch.compile(). Once you have defined a model, wrap it with torch_geometric.compile() to obtain its optimized version:

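import torch
import torch_geometric
from torch_geometric.nn import GraphSAGE

# Run on the GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# in_channels, hidden_channels, num_layers and out_channels are placeholders.
model = GraphSAGE(in_channels, hidden_channels, num_layers, out_channels)
model = model.to(device)

# Wrap the model to obtain its compiled version.
model = torch_geometric.compile(model)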

and execute it as usual:

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root, name="Cora")
data = dataset[0].to(device)

out = model(data.x, data.edge_index)

We have incorporated multiple examples in examples/compile that further show the practical usage of torch_geometric.compile():

  1. Node Classification via GCN

  2. Graph Classification via GIN

Note that torch.compile(model, dynamic=True) does not yet work for PyG models. While static compilation via torch.compile(model, dynamic=False) works fine, it re-compiles the model every time it sees an input with a different shape. That currently does not play well with the way PyG performs mini-batching, where each batch typically contains a different number of nodes and edges, and hence leads to major slow-downs. We are working with the PyTorch team to fix this limitation (see this issue). A temporary workaround is to utilize the torch_geometric.transforms.Pad transformation to ensure that all inputs are of equal shape.

If you notice that compile() fails for a certain model, do not hesitate to reach out either on GitHub or Slack. We are very eager to improve compile() support across the whole code base.

Benchmark

torch.compile() works fantastically well for many models. Overall, we observe speedups of up to 2.83x over eager mode.

Specifically, we benchmark GCN, GraphSAGE and GIN and compare runtimes obtained from traditional eager mode and torch_geometric.compile(). We use a synthetic graph with 10,000 nodes and 200,000 edges, and a hidden feature dimensionality of 64. We report runtimes over 500 optimization steps:

Model      Mode      Forward  Backward  Total    Speedup
---------  --------  -------  --------  -------  -------
GCN        Eager     2.6396s  2.1697s   4.8093s
GCN        Compiled  1.1082s  0.5896s   1.6978s  2.83x
GraphSAGE  Eager     1.6023s  1.6428s   3.2451s
GraphSAGE  Compiled  0.7033s  0.7465s   1.4498s  2.24x
GIN        Eager     1.6701s  1.6990s   3.3690s
GIN        Compiled  0.7320s  0.7407s   1.4727s  2.29x
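The speedup column appears to be the ratio of the eager total runtime to the compiled total runtime. The short sketch below recomputes it from the totals reported above, assuming this definition; the dictionary layout is only for illustration.

```python
# Total runtimes in seconds, copied from the benchmark table.
totals = {
    "GCN": {"eager": 4.8093, "compiled": 1.6978},
    "GraphSAGE": {"eager": 3.2451, "compiled": 1.4498},
    "GIN": {"eager": 3.3690, "compiled": 1.4727},
}

# Assumed definition: speedup = eager total / compiled total.
speedups = {
    name: round(t["eager"] / t["compiled"], 2)
    for name, t in totals.items()
}

for name, value in speedups.items():
    print(f"{name}: {value:.2f}x")
```

Under this definition, the recomputed values match the reported speedups of 2.83x, 2.24x and 2.29x.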

To reproduce these results, run

python test/nn/models/test_basic_gnn.py

from the root folder of your checked-out repository.