Compiled Graph Neural Networks

torch.compile() is the latest method to speed up your PyTorch code in torch >= 2.0.0! torch.compile() makes PyTorch code run faster by JIT-compiling it into optimized kernels, all while requiring minimal code changes. Under the hood, torch.compile() captures PyTorch programs via TorchDynamo, canonicalizes over 2,000 PyTorch operators via PrimTorch, and finally generates fast code across multiple accelerators and backends via the deep learning compiler TorchInductor.
Note
See here for a general tutorial on how to leverage torch.compile(), and here for a description of its interface.
In this tutorial, we show how to optimize your custom PyG model via torch.compile().
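As a point of reference, this is how vanilla torch.compile() is applied to plain PyTorch code (a minimal sketch; the gelu function below is just an illustrative stand-in, not part of this tutorial):

import torch

def gelu(x: torch.Tensor) -> torch.Tensor:
    # A plain tensor function that TorchDynamo can capture as a whole:
    return 0.5 * x * (1.0 + torch.erf(x / 1.41421))

compiled_gelu = torch.compile(gelu)  # compilation happens lazily, on first call
out = compiled_gelu(torch.randn(1000))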
Introducing torch_geometric.compile()

By default, torch.compile() struggles to optimize a custom PyG model since its underlying MessagePassing interface is JIT-unfriendly due to its generality. As such, in PyG 2.3, we introduce torch_geometric.compile(), a wrapper around torch.compile() with the same signature. torch_geometric.compile() applies further optimizations to make PyG models more compiler-friendly. Specifically, it:
- Temporarily disables the usage of the extension packages torch_scatter, torch_sparse and pyg_lib during GNN execution workflows (since these are not yet directly optimizable by PyTorch). From PyG 2.3 onwards, these packages are purely optional and no longer required for running PyG models (but pyg_lib may be required for graph sampling routines).
- Converts all instances of MessagePassing modules into their jittable instances (see torch_geometric.nn.conv.MessagePassing.jittable() and the sketch after this list).
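For reference, the same jittable conversion can be applied manually to a single layer (a minimal sketch; torch_geometric.compile() performs this step for all layers automatically):

from torch_geometric.nn import GCNConv

conv = GCNConv(in_channels=16, out_channels=32)
conv = conv.jittable()  # returns a JIT-friendly instance of the same layer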
Without these adjustments, torch.compile() may currently fail to correctly optimize your PyG model. We are working on fully relying on torch.compile() for future releases.
Basic Usage

Leveraging torch_geometric.compile() is as simple as using torch.compile(). Once you have a PyG model defined, simply wrap it with torch_geometric.compile() to obtain its optimized version:
import torch_geometric
from torch_geometric.nn import GraphSAGE

# `in_channels`, `hidden_channels`, `num_layers` and `out_channels` are
# placeholders for your model configuration:
model = GraphSAGE(in_channels, hidden_channels, num_layers, out_channels)
model = model.to(device)
model = torch_geometric.compile(model)
and execute it as usual:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root, name="Cora")
data = dataset[0].to(device)
out = model(data.x, data.edge_index)
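Keep in mind that the first call triggers the (slow) JIT compilation; a quick way to sanity-check the speedup is to time warmed-up forward passes (a minimal sketch; the iteration count of 100 is an arbitrary choice):

import time

model(data.x, data.edge_index)  # warm-up: the first call compiles the model

start = time.perf_counter()
for _ in range(100):
    model(data.x, data.edge_index)
# Note: on a GPU, insert torch.cuda.synchronize() before reading the clock.
print(f'Average forward time: {(time.perf_counter() - start) / 100:.4f}s')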
We have incorporated multiple examples in examples/compile that further show the practical usage of torch_geometric.compile().
Note that torch.compile(model, dynamic=True) does sadly not yet work for PyG models on PyTorch 2.0. While static compilation via torch.compile(model, dynamic=False) works fine, it will re-compile the model every time it sees an input with a different shape. This currently does not play nicely with the way PyG performs mini-batching, and will hence lead to major slowdowns. We are working with the PyTorch team to fix this limitation (see this GitHub issue). A temporary workaround is to utilize the torch_geometric.transforms.Pad transformation to ensure that all inputs are of equal shape, as sketched below.
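A minimal sketch of this workaround on a graph-level dataset (the dataset choice and the padding bounds are hypothetical; the bounds must be large enough to fit the biggest graph in your dataset):

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import Pad

# Pad every graph to a fixed number of nodes and edges:
transform = Pad(max_num_nodes=32, max_num_edges=128)
dataset = TUDataset(root, name='MUTAG', transform=transform)

# With drop_last=True, every mini-batch now has an identical shape,
# so the model only needs to be compiled once:
loader = DataLoader(dataset, batch_size=32, drop_last=True)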
If you notice that compile() fails for a certain PyG model, do not hesitate to reach out either on GitHub or Slack. We are very eager to improve compile() support across the whole PyG code base.
Benchmark

torch.compile() works fantastically well for many PyG models. Overall, we observe runtime improvements of up to nearly 300%. Specifically, we benchmark GCN, GraphSAGE and GIN, and compare runtimes obtained from traditional eager mode and torch_geometric.compile(). We use a synthetic graph with 10,000 nodes and 200,000 edges, and a hidden feature dimensionality of 64. We report runtimes over 500 optimization steps:
Model | Mode | Forward | Backward | Total | Speedup
---|---|---|---|---|---
GCN | Eager | 2.6396s | 2.1697s | 4.8093s |
GCN | Compiled | 1.1082s | 0.5896s | 1.6978s | 2.83x
GraphSAGE | Eager | 1.6023s | 1.6428s | 3.2451s |
GraphSAGE | Compiled | 0.7033s | 0.7465s | 1.4498s | 2.24x
GIN | Eager | 1.6701s | 1.6990s | 3.3690s |
GIN | Compiled | 0.7320s | 0.7407s | 1.4727s | 2.29x
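For context, one optimization step above comprises a full forward and backward pass. A training loop of this shape (a minimal sketch; the Adam optimizer, learning rate and cross-entropy loss are assumptions, not the exact benchmark setup) looks as follows:

import torch
import torch.nn.functional as F

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # assumed optimizer

for step in range(500):  # 500 optimization steps, as reported above
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out, data.y)  # assumed loss function
    loss.backward()
    optimizer.step()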
To reproduce these results, run
python test/nn/models/test_basic_gnn.py
from the root folder of your PyG repository checked out from GitHub.