# Source code for torch_geometric.nn.sequential

from typing import List, Union, Tuple, Callable

import os
import os.path as osp
from uuid import uuid1

import torch
from jinja2 import Template

from torch_geometric.nn.conv.utils.jit import class_from_module_repr


def Sequential(
    input_args: str,
    modules: List[Union[Tuple[Callable, str], Callable]],
) -> torch.nn.Module:
    r"""An extension of the :class:`torch.nn.Sequential` container in order to
    define a sequential GNN model.
    Since GNN operators take in multiple input arguments,
    :class:`torch_geometric.nn.Sequential` expects both global input
    arguments, and function header definitions of individual operators.
    If omitted, an intermediate module will operate on the *output* of its
    preceding module:

    .. code-block:: python

        from torch.nn import Linear, ReLU
        from torch_geometric.nn import Sequential, GCNConv

        model = Sequential('x, edge_index', [
            (GCNConv(in_channels, 64), 'x, edge_index -> x'),
            ReLU(inplace=True),
            (GCNConv(64, 64), 'x, edge_index -> x'),
            ReLU(inplace=True),
            Linear(64, out_channels),
        ])

    where ``'x, edge_index'`` defines the input arguments of :obj:`model`,
    and ``'x, edge_index -> x'`` defines the function header, *i.e.* input
    arguments *and* return types, of
    :class:`~torch_geometric.nn.conv.GCNConv`.

    In particular, this also allows one to create more sophisticated models,
    such as utilizing :class:`~torch_geometric.nn.models.JumpingKnowledge`:

    .. code-block:: python

        from torch.nn import Linear, ReLU, Dropout
        from torch_geometric.nn import Sequential, GCNConv, JumpingKnowledge
        from torch_geometric.nn import global_mean_pool

        model = Sequential('x, edge_index, batch', [
            (Dropout(p=0.5), 'x -> x'),
            (GCNConv(dataset.num_features, 64), 'x, edge_index -> x1'),
            ReLU(inplace=True),
            (GCNConv(64, 64), 'x1, edge_index -> x2'),
            ReLU(inplace=True),
            (lambda x1, x2: [x1, x2], 'x1, x2 -> xs'),
            (JumpingKnowledge("cat", 64, num_layers=2), 'xs -> x'),
            (global_mean_pool, 'x, batch -> x'),
            Linear(2 * 64, dataset.num_classes),
        ])

    Args:
        input_args (str): The comma-separated input arguments of the model.
        modules ([(Callable, str) or Callable]): A list of modules (with
            optional function header definitions). An entry without a header
            (a bare callable or a 1-tuple) takes the output names of the
            preceding entry as both its inputs and its outputs. The first
            entry must provide a header.

    Returns:
        torch.nn.Module: An instance of a freshly generated class whose
        submodules are registered as ``module_0``, ``module_1``, ...

    Raises:
        AssertionError: If :obj:`modules` is empty or its first entry is not
            a ``(module, function_header)`` pair.
        ValueError: If a function header does not contain exactly one
            ``'->'``.
    """
    arg_names = [x.strip() for x in input_args.split(',')]

    # The first entry must carry an explicit function header, because there is
    # no preceding module whose outputs it could inherit. Checking only
    # `isinstance(..., tuple)` would let a 1-tuple through, which then fails
    # later with an opaque `IndexError` on `calls[-1]`.
    assert len(modules) > 0, "'modules' must not be empty"
    first = modules[0]
    assert isinstance(first, (tuple, list)) and len(first) >= 2, (
        "The first entry of 'modules' needs to be a tuple "
        "'(module, function_header)'")

    # Each call holds the callable plus its input and output names:
    calls: List[Tuple[Callable, List[str], List[str]]] = []
    for entry in modules:
        if isinstance(entry, (tuple, list)) and len(entry) >= 2:
            submodule, desc = entry[:2]
            in_desc, out_desc = parse_desc(desc)
        else:
            # No header given: operate on (and overwrite) the outputs of the
            # previous call.
            submodule = entry[0] if isinstance(entry, (tuple, list)) else entry
            in_desc = out_desc = calls[-1][2]
        calls.append((submodule, in_desc, out_desc))

    root = osp.dirname(osp.realpath(__file__))
    # Explicit encoding so reading the template does not depend on the
    # platform's locale settings.
    with open(osp.join(root, 'sequential.jinja'), 'r', encoding='utf-8') as f:
        template = Template(f.read())

    # A time-based suffix makes collisions between generated class names
    # unlikely.
    cls_name = f'Sequential_{uuid1().hex[:6]}'
    module_repr = template.render(
        cls_name=cls_name,
        input_args=arg_names,
        calls=calls,
    )

    # Instantiate a class from the rendered module representation and attach
    # each callable as `module_{i}` (presumably the attribute names the
    # rendered code refers to; verify against sequential.jinja).
    model = class_from_module_repr(cls_name, module_repr)()
    for i, (submodule, _, _) in enumerate(calls):
        setattr(model, f'module_{i}', submodule)

    return model
def parse_desc(desc: str) -> Tuple[List[str], List[str]]:
    r"""Splits a function header such as ``'x, edge_index -> x'`` into its
    input and output argument names.

    Args:
        desc (str): The function header: comma-separated input names and
            comma-separated output names, separated by a single ``'->'``.

    Returns:
        (List[str], List[str]): The whitespace-stripped input names and the
        whitespace-stripped output names.

    Raises:
        ValueError: If :obj:`desc` does not contain exactly one ``'->'``.
    """
    parts = desc.split('->')
    # Report malformed headers explicitly instead of surfacing a bare
    # tuple-unpacking error; the exception type stays ValueError.
    if len(parts) != 2:
        raise ValueError(
            f"Invalid function header '{desc}' (expected exactly one '->', "
            f"e.g., 'x, edge_index -> x')")
    in_str, out_str = parts
    in_desc = [x.strip() for x in in_str.split(',')]
    out_desc = [x.strip() for x in out_str.split(',')]
    return in_desc, out_desc