Source code for torch_geometric.nn.sequential

import os.path as osp
import random
from typing import Callable, List, NamedTuple, Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.inspector import split, type_repr
from torch_geometric.template import module_from_template


# Record describing one entry of a generated `Sequential` model. Declared in
# the functional `NamedTuple` form; field order and types are the contract
# that `Sequential` relies on when it builds and registers children.
Child = NamedTuple(
    'Child',
    [
        # Attribute name under which the callable is set on the model.
        ('name', str),
        # The operator / callable invoked for this step.
        ('module', Callable),
        # Variable names passed to `module` as positional inputs.
        ('param_names', List[str]),
        # Variable names bound to the outputs of `module`.
        ('return_names', List[str]),
    ],
)


def Sequential(
    input_args: str,
    modules: List[Union[Tuple[Callable, str], Callable]],
) -> torch.nn.Module:
    r"""An extension of the :class:`torch.nn.Sequential` container in order to
    define a sequential GNN model.

    Since GNN operators take in multiple input arguments,
    :class:`torch_geometric.nn.Sequential` additionally expects both global
    input arguments, and function header definitions of individual operators.
    If omitted, an intermediate module will operate on the *output* of its
    preceding module:

    .. code-block:: python

        from torch.nn import Linear, ReLU
        from torch_geometric.nn import Sequential, GCNConv

        model = Sequential('x, edge_index', [
            (GCNConv(in_channels, 64), 'x, edge_index -> x'),
            ReLU(inplace=True),
            (GCNConv(64, 64), 'x, edge_index -> x'),
            ReLU(inplace=True),
            Linear(64, out_channels),
        ])

    Here, :obj:`'x, edge_index'` defines the input arguments of :obj:`model`,
    and :obj:`'x, edge_index -> x'` defines the function header, *i.e.* input
    arguments *and* return types of :class:`~torch_geometric.nn.conv.GCNConv`.

    In particular, this also allows creating more sophisticated models, such
    as utilizing :class:`~torch_geometric.nn.models.JumpingKnowledge`:

    .. code-block:: python

        from torch.nn import Linear, ReLU, Dropout
        from torch_geometric.nn import Sequential, GCNConv, JumpingKnowledge
        from torch_geometric.nn import global_mean_pool

        model = Sequential('x, edge_index, batch', [
            (Dropout(p=0.5), 'x -> x'),
            (GCNConv(dataset.num_features, 64), 'x, edge_index -> x1'),
            ReLU(inplace=True),
            (GCNConv(64, 64), 'x1, edge_index -> x2'),
            ReLU(inplace=True),
            (lambda x1, x2: [x1, x2], 'x1, x2 -> xs'),
            (JumpingKnowledge("cat", 64, num_layers=2), 'xs -> x'),
            (global_mean_pool, 'x, batch -> x'),
            Linear(2 * 64, dataset.num_classes),
        ])

    Args:
        input_args (str): The input arguments of the model. A return type
            representation may optionally be appended as
            :obj:`'... -> <return type>'`; if omitted, the return type
            defaults to the representation of :class:`torch.Tensor`.
        modules ([(Callable, str) or Callable]): A list of modules (with
            optional function header definitions). Alternatively, an
            :obj:`OrderedDict` of modules (and function header definitions)
            can be passed. The first module must come with a function header.

    Returns:
        torch.nn.Module: An instance of the generated :obj:`Sequential` class,
        with every module assigned as an attribute under its name.

    Raises:
        ValueError: If :obj:`input_args` or a function header cannot be
            parsed, if :obj:`modules` is empty, if the first module lacks a
            function header, if a tuple entry has more than two elements, or
            if an entry is not callable or has a non-string header.
    """
    # Split the optional "-> return_type" suffix off the model's input spec.
    signature = input_args.split('->')
    if len(signature) == 1:
        input_args = signature[0]
        return_type = type_repr(Tensor, globals())
    elif len(signature) == 2:
        input_args, return_type = signature[0], signature[1].strip()
    else:
        raise ValueError(f"Failed to parse arguments (got '{input_args}')")

    input_types = split(input_args, sep=',')
    if len(input_types) == 0:
        raise ValueError(f"Failed to parse arguments (got '{input_args}')")

    # Normalize a plain list into a mapping with auto-generated names so both
    # accepted input forms share a single code path below.
    if not isinstance(modules, dict):
        modules = {f'module_{i}': module for i, module in enumerate(modules)}
    if len(modules) == 0:
        raise ValueError("'Sequential' expected a non-empty list of modules")

    children: List[Child] = []
    for i, (name, module) in enumerate(modules.items()):
        desc: Optional[str] = None
        if isinstance(module, (tuple, list)):
            if len(module) == 1:
                module = module[0]
            elif len(module) == 2:
                module, desc = module
            else:
                raise ValueError(f"Expected tuple of length 2 (got {module})")

        # The first module has no predecessor whose outputs it could consume,
        # so it needs an explicit header.
        if i == 0 and desc is None:
            raise ValueError("Requires signature for first module")
        if not callable(module):
            raise ValueError(f"Expected callable module (got {module})")
        if desc is not None and not isinstance(desc, str):
            raise ValueError(f"Expected type hint representation (got {desc})")

        if desc is not None:
            signature = desc.split('->')
            if len(signature) != 2:
                raise ValueError(f"Failed to parse arguments (got '{desc}')")
            param_names = [v.strip() for v in signature[0].split(',')]
            return_names = [v.strip() for v in signature[1].split(',')]
            child = Child(name, module, param_names, return_names)
        else:
            # Without a header, the module consumes the previous module's
            # outputs and its results are bound back to the same names.
            param_names = children[-1].return_names
            child = Child(name, module, param_names, param_names)

        children.append(child)

    # Inject a unique identifier into the generated module name. A private
    # `SystemRandom` instance is used instead of the module-level `random`
    # functions, so that building a model neither advances a user-seeded
    # global RNG nor makes identically-seeded processes produce identical
    # generated module names.
    uid = '%06x' % random.SystemRandom().randrange(16**6)
    # The Jinja template is expected to live next to this source file.
    root_dir = osp.dirname(osp.realpath(__file__))
    generated = module_from_template(
        module_name=f'torch_geometric.nn.sequential_{uid}',
        template_path=osp.join(root_dir, 'sequential.jinja'),
        tmp_dirname='sequential',
        # Remaining keyword arguments are the values handed to the template
        # (presumably exposed as Jinja variables -- see `module_from_template`).
        input_types=input_types,
        return_type=return_type,
        children=children,
    )
    model = generated.Sequential()
    model._module_names = [child.name for child in children]
    model._module_descs = [
        f"{', '.join(child.param_names)} -> {', '.join(child.return_names)}"
        for child in children
    ]
    # Attach every callable under its name so the generated forward can
    # reach it (and `torch.nn.Module`s get registered as submodules).
    for child in children:
        setattr(model, child.name, child.module)

    return model