from typing import List, Union, Tuple, Callable
import os
import os.path as osp
from uuid import uuid1
import torch
from jinja2 import Template
from torch_geometric.nn.conv.utils.jit import class_from_module_repr
def Sequential(
    input_args: str,
    modules: List[Union[Tuple[Callable, str], Callable]],
) -> torch.nn.Module:
    r"""An extension of the :class:`torch.nn.Sequential` container in order to
    define a sequential GNN model.
    Since GNN operators take in multiple input arguments,
    :class:`torch_geometric.nn.Sequential` expects both global input
    arguments, and function header definitions of individual operators.
    If omitted, an intermediate module will operate on the *output* of its
    preceding module:

    .. code-block:: python

        from torch.nn import Linear, ReLU
        from torch_geometric.nn import Sequential, GCNConv

        model = Sequential('x, edge_index', [
            (GCNConv(in_channels, 64), 'x, edge_index -> x'),
            ReLU(inplace=True),
            (GCNConv(64, 64), 'x, edge_index -> x'),
            ReLU(inplace=True),
            Linear(64, out_channels),
        ])

    where ``'x, edge_index'`` defines the input arguments of :obj:`model`,
    and ``'x, edge_index -> x'`` defines the function header, *i.e.* input
    arguments *and* return types, of :class:`~torch_geometric.nn.conv.GCNConv`.

    In particular, this also allows creating more sophisticated models,
    such as utilizing :class:`~torch_geometric.nn.models.JumpingKnowledge`:

    .. code-block:: python

        from torch.nn import Linear, ReLU, Dropout
        from torch_geometric.nn import Sequential, GCNConv, JumpingKnowledge
        from torch_geometric.nn import global_mean_pool

        model = Sequential('x, edge_index, batch', [
            (Dropout(p=0.5), 'x -> x'),
            (GCNConv(dataset.num_features, 64), 'x, edge_index -> x1'),
            ReLU(inplace=True),
            (GCNConv(64, 64), 'x1, edge_index -> x2'),
            ReLU(inplace=True),
            (lambda x1, x2: [x1, x2], 'x1, x2 -> xs'),
            (JumpingKnowledge("cat", 64, num_layers=2), 'xs -> x'),
            (global_mean_pool, 'x, batch -> x'),
            Linear(2 * 64, dataset.num_classes),
        ])

    Args:
        input_args (str): The comma-separated input arguments of the model.
        modules ([(Callable, str) or Callable]): A list of modules (with
            optional function header definitions). The first entry must
            provide a function header, since it has no preceding module
            whose outputs it could operate on. Elements of a tuple entry
            beyond the second are ignored.

    Raises:
        ValueError: If :obj:`modules` is empty, or if its first entry is not
            a ``(module, header)`` pair.
    """
    input_names = [name.strip() for name in input_args.split(',')]

    # The first entry has no predecessor to inherit argument names from, so
    # it must come with an explicit header. Raise (rather than `assert`) so
    # the check is not stripped under `python -O`, and so a 1-tuple first
    # entry fails here instead of with an IndexError on `calls[-1]` below.
    if not modules:
        raise ValueError("'Sequential' expects a non-empty list of modules")
    first = modules[0]
    if not isinstance(first, (tuple, list)) or len(first) < 2:
        raise ValueError(
            "The first entry of 'modules' in 'Sequential' must be a tuple "
            "of the form '(module, header)', e.g., "
            "'(conv, \"x, edge_index -> x\")'")

    # A list holding the callable function and the input and output names:
    calls: List[Tuple[Callable, List[str], List[str]]] = []
    for entry in modules:
        if isinstance(entry, (tuple, list)) and len(entry) >= 2:
            fn, desc = entry[:2]
            in_desc, out_desc = parse_desc(desc)
        else:
            # No header given: the module reads and writes the names that
            # the preceding module produced.
            fn = entry[0] if isinstance(entry, (tuple, list)) else entry
            in_desc = out_desc = calls[-1][2]
        calls.append((fn, in_desc, out_desc))

    # The forward code is generated from a Jinja template stored next to
    # this file.
    root = osp.dirname(osp.realpath(__file__))
    template_path = osp.join(root, 'sequential.jinja')
    with open(template_path, 'r', encoding='utf-8') as f:
        template = Template(f.read())

    cls_name = f'Sequential_{uuid1().hex[:6]}'
    module_repr = template.render(
        cls_name=cls_name,
        input_args=input_names,
        calls=calls,
    )

    # Instantiate a class from the rendered module representation.
    model = class_from_module_repr(cls_name, module_repr)()
    # Register every callable as attribute `module_{i}`. The rendered code
    # presumably refers to these names (the template is not visible here).
    for i, (submodule, _, _) in enumerate(calls):
        setattr(model, f'module_{i}', submodule)
    return model
def parse_desc(desc: str) -> Tuple[List[str], List[str]]:
    """Parse a function header of the form ``'in1, in2 -> out1, out2'``.

    Args:
        desc (str): Comma-separated input names, then ``'->'``, then
            comma-separated output names. Whitespace around names is
            stripped.

    Returns:
        (List[str], List[str]): The input names and the output names.

    Raises:
        ValueError: If :obj:`desc` does not contain exactly one ``'->'``.
    """
    parts = desc.split('->')
    # Previously a wrong number of arrows surfaced as a cryptic tuple-unpack
    # ValueError; keep the same exception type but say what is wrong.
    if len(parts) != 2:
        raise ValueError(
            f"Expected a function header of the form 'inputs -> outputs' "
            f"containing exactly one '->' (got '{desc}')")
    in_part, out_part = parts
    in_desc = [x.strip() for x in in_part.split(',')]
    out_desc = [x.strip() for x in out_part.split(',')]
    return in_desc, out_desc