Source code for torch_geometric.nn.sequential

import os
import os.path as osp
from typing import Callable, List, Tuple, Union
from uuid import uuid1

import torch

from torch_geometric.nn.conv.utils.jit import class_from_module_repr

def Sequential(
    input_args: str,
    modules: List[Union[Tuple[Callable, str], Callable]],
) -> torch.nn.Module:
    r"""An extension of the :class:`torch.nn.Sequential` container in order to
    define a sequential GNN model.
    Since GNN operators take in multiple input arguments,
    :class:`torch_geometric.nn.Sequential` expects both global input
    arguments, and function header definitions of individual operators.
    If omitted, an intermediate module will operate on the *output* of its
    preceding module:

    .. code-block:: python

        from torch.nn import Linear, ReLU
        from torch_geometric.nn import Sequential, GCNConv

        model = Sequential('x, edge_index', [
            (GCNConv(in_channels, 64), 'x, edge_index -> x'),
            ReLU(inplace=True),
            (GCNConv(64, 64), 'x, edge_index -> x'),
            ReLU(inplace=True),
            Linear(64, out_channels),
        ])

    where :obj:`'x, edge_index'` defines the input arguments of :obj:`model`,
    and :obj:`'x, edge_index -> x'` defines the function header, *i.e.* input
    arguments *and* return types, of
    :class:`~torch_geometric.nn.conv.GCNConv`.

    In particular, this also allows creating more sophisticated models, such
    as utilizing :class:`~torch_geometric.nn.models.JumpingKnowledge`:

    .. code-block:: python

        from torch.nn import Linear, ReLU, Dropout
        from torch_geometric.nn import Sequential, GCNConv, JumpingKnowledge
        from torch_geometric.nn import global_mean_pool

        model = Sequential('x, edge_index, batch', [
            (Dropout(p=0.5), 'x -> x'),
            (GCNConv(dataset.num_features, 64), 'x, edge_index -> x1'),
            ReLU(inplace=True),
            (GCNConv(64, 64), 'x1, edge_index -> x2'),
            ReLU(inplace=True),
            (lambda x1, x2: [x1, x2], 'x1, x2 -> xs'),
            (JumpingKnowledge("cat", 64, num_layers=2), 'xs -> x'),
            (global_mean_pool, 'x, batch -> x'),
            Linear(2 * 64, dataset.num_classes),
        ])

    Args:
        input_args (str): The input arguments of the model, given as a
            comma-separated string (*e.g.*, :obj:`'x, edge_index'`).
        modules ([(Callable, str) or Callable]): A list of modules (with
            optional function header definitions).
            Alternatively, an :obj:`OrderedDict` of modules (and function
            header definitions) can be passed.
            The first entry must carry a function header; entries without
            one consume the outputs of the preceding entry and forward their
            result under the same names.

    Returns:
        torch.nn.Module: An instance of a freshly generated module class whose
        sub-modules are registered under the given (or generated
        :obj:`module_{i}`) names.

    Raises:
        ModuleNotFoundError: If :obj:`jinja2` is not installed.
        AssertionError: If :obj:`modules` is empty or its first entry does not
            come with a function header.
    """
    try:
        from jinja2 import Template
    except ImportError as e:
        raise ModuleNotFoundError(
            "No module named 'jinja2' found on this machine. "
            "Run 'pip install jinja2' to install the library.") from e

    input_args = [x.strip() for x in input_args.split(',')]

    # Positional modules get generated names, which later become the
    # attribute names of the registered sub-modules:
    if not isinstance(modules, dict):
        modules = {f'module_{i}': module for i, module in enumerate(modules)}

    # We require the first entry of the input list to define arguments, since
    # header-less entries inherit the outputs of their predecessor:
    assert len(modules) > 0, "'modules' must not be empty"
    first_module = next(iter(modules.values()))
    assert (isinstance(first_module, (tuple, list))
            and len(first_module) >= 2), (
                "The first module must be given together with its function "
                "header, e.g., (module, 'x, edge_index -> x')")

    # A list holding the name, the callable function, and the input and
    # output names of every step:
    calls: List[Tuple[str, Callable, List[str], List[str]]] = []

    for name, module in modules.items():
        if isinstance(module, (tuple, list)) and len(module) >= 2:
            module, desc = module[:2]
            in_desc, out_desc = parse_desc(desc)
        else:
            if isinstance(module, (tuple, list)):
                module = module[0]
            # No header: operate on the previous step's outputs and store the
            # result under the same names.
            in_desc = out_desc = calls[-1][-1]
        calls.append((name, module, in_desc, out_desc))

    root = os.path.dirname(osp.realpath(__file__))
    with open(osp.join(root, 'sequential.jinja'), 'r', encoding='utf-8') as f:
        template = Template(f.read())

    cls_name = f'Sequential_{uuid1().hex[:6]}'
    module_repr = template.render(
        cls_name=cls_name,
        input_args=input_args,
        calls=calls,
    )

    # Instantiate a class from the rendered module representation.
    module = class_from_module_repr(cls_name, module_repr)()
    module._names = list(modules.keys())
    for name, submodule, _, _ in calls:
        setattr(module, name, submodule)

    return module


def parse_desc(desc: str) -> Tuple[List[str], List[str]]:
    r"""Parses a function header of the form :obj:`'a, b -> c, d'`.

    Args:
        desc (str): The function header; input and output argument names are
            comma-separated and divided by exactly one :obj:`'->'`.

    Returns:
        (List[str], List[str]): The whitespace-stripped input names and
        output names.

    Raises:
        ValueError: If :obj:`desc` does not contain exactly one :obj:`'->'`.
    """
    in_part, arrow, out_part = desc.partition('->')
    # Reject missing or repeated arrows with a descriptive message instead of
    # a generic tuple-unpacking error:
    if not arrow or '->' in out_part:
        raise ValueError(
            f"Invalid function header '{desc}': expected exactly one '->' "
            f"separating input and output arguments "
            f"(e.g., 'x, edge_index -> x')")
    in_desc = [x.strip() for x in in_part.split(',')]
    out_desc = [x.strip() for x in out_part.split(',')]
    return in_desc, out_desc