Source code for torch_geometric.nn.to_hetero_transformer

import copy
import warnings
from collections import defaultdict, deque
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch.nn import Module

from torch_geometric.nn.fx import Transformer, get_submodule
from torch_geometric.typing import EdgeType, Metadata, NodeType
from torch_geometric.utils.hetero import (
    check_add_self_loops,
    get_unused_node_types,
)

try:
    from torch.fx import Graph, GraphModule, Node
except (ImportError, ModuleNotFoundError, AttributeError):
    GraphModule, Graph, Node = 'GraphModule', 'Graph', 'Node'


def to_hetero(module: Module, metadata: Metadata, aggr: str = "sum",
              input_map: Optional[Dict[str, str]] = None,
              debug: bool = False) -> GraphModule:
    r"""Converts a homogeneous GNN model into its heterogeneous equivalent in
    which node representations are learned for each node type in
    :obj:`metadata[0]`, and messages are exchanged between each edge type in
    :obj:`metadata[1]`, as denoted in the `"Modeling Relational Data with Graph
    Convolutional Networks" <https://arxiv.org/abs/1703.06103>`_ paper:

    .. code-block:: python

        import torch
        from torch_geometric.nn import SAGEConv, to_hetero

        class GNN(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = SAGEConv((-1, -1), 32)
                self.conv2 = SAGEConv((32, 32), 32)

            def forward(self, x, edge_index):
                x = self.conv1(x, edge_index).relu()
                x = self.conv2(x, edge_index).relu()
                return x

        model = GNN()

        node_types = ['paper', 'author']
        edge_types = [
            ('paper', 'cites', 'paper'),
            ('paper', 'written_by', 'author'),
            ('author', 'writes', 'paper'),
        ]
        metadata = (node_types, edge_types)

        model = to_hetero(model, metadata)
        model(x_dict, edge_index_dict)

    where :obj:`x_dict` and :obj:`edge_index_dict` denote dictionaries that
    hold node features and edge connectivity information for each node type and
    edge type, respectively.

    The below illustration shows the original computation graph of the
    homogeneous model on the left, and the newly obtained computation graph of
    the heterogeneous model on the right:

    .. figure:: ../_figures/to_hetero.svg
      :align: center
      :width: 90%

      Transforming a model via :func:`to_hetero`.

    Here, each :class:`~torch_geometric.nn.conv.MessagePassing` instance
    :math:`f_{\theta}^{(\ell)}` is duplicated and stored in a set
    :math:`\{ f_{\theta}^{(\ell, r)} : r \in \mathcal{R} \}` (one instance for
    each relation in :math:`\mathcal{R}`), and message passing in layer
    :math:`\ell` is performed via

    .. math::

        \mathbf{h}^{(\ell)}_v = \bigoplus_{r \in \mathcal{R}}
        f_{\theta}^{(\ell, r)} ( \mathbf{h}^{(\ell - 1)}_v, \{
        \mathbf{h}^{(\ell - 1)}_w : w \in \mathcal{N}^{(r)}(v) \}),

    where :math:`\mathcal{N}^{(r)}(v)` denotes the neighborhood of :math:`v \in
    \mathcal{V}` under relation :math:`r \in \mathcal{R}`, and
    :math:`\bigoplus` denotes the aggregation scheme :attr:`aggr` to use for
    grouping node embeddings generated by different relations
    (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"` or :obj:`"mul"`).

    Args:
        module (torch.nn.Module): The homogeneous model to transform.
        metadata (Tuple[List[str], List[Tuple[str, str, str]]]): The metadata
            of the heterogeneous graph, *i.e.* its node and edge types given
            by a list of strings and a list of string triplets, respectively.
            See :meth:`torch_geometric.data.HeteroData.metadata` for more
            information.
        aggr (string, optional): The aggregation scheme to use for grouping
            node embeddings generated by different relations.
            (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
            :obj:`"mul"`). (default: :obj:`"sum"`)
        input_map (Dict[str, str], optional): A dictionary holding information
            about the type of input arguments of :obj:`module.forward`.
            For example, in case :obj:`arg` is a node-level argument, then
            :obj:`input_map['arg'] = 'node'`, and
            :obj:`input_map['arg'] = 'edge'` otherwise.
            In case :obj:`input_map` is not further specified, will try to
            automatically determine the correct type of input arguments.
            (default: :obj:`None`)
        debug (bool, optional): If set to :obj:`True`, will perform
            transformation in debug mode. (default: :obj:`False`)

    Returns:
        torch.fx.GraphModule: The result of
        :meth:`ToHeteroTransformer.transform`, *i.e.* the transformed model
        that operates on node type-wise and edge type-wise dictionaries.
    """
    # All of the work is delegated to the transformer; this function only
    # exists as the public, functional entry point.
    transformer = ToHeteroTransformer(module, metadata, aggr, input_map, debug)
    return transformer.transform()
class ToHeteroTransformer(Transformer):
    r"""Graph transformer that rewrites the traced graph of a homogeneous
    module so that every node-level or edge-level operation is replicated once
    per node type or edge type given in :obj:`metadata`.

    Each replicated graph node is named ``f'{name}__{key2str(key)}'``, and the
    other methods of this class rely on that naming scheme to look up
    per-type values again via :meth:`find_by_name`.

    The per-op hook methods (:meth:`placeholder`, :meth:`call_module`, ...)
    are presumably dispatched by the :class:`Transformer` base class while it
    walks the graph -- the base class is not visible here, confirm there.
    Helpers such as :meth:`is_edge_level`, :meth:`is_node_level`,
    :meth:`is_graph_level`, :meth:`find_by_name`, :meth:`find_by_target` and
    :meth:`replace_all_uses_with` are inherited and not defined in this file.
    """

    # Maps each supported aggregation name to the binary function that is used
    # to combine two per-relation outputs at a time.
    aggrs = {
        'sum': torch.add,
        # For 'mean' aggregation, we first sum up all feature matrices, and
        # divide by the number of matrices in a later step.
        'mean': torch.add,
        'max': torch.max,
        'min': torch.min,
        'mul': torch.mul,
    }

    def __init__(
        self,
        module: Module,
        metadata: Metadata,
        aggr: str = 'sum',
        input_map: Optional[Dict[str, str]] = None,
        debug: bool = False,
    ):
        r"""Stores the heterogeneous metadata and the aggregation scheme.

        Args:
            module (torch.nn.Module): The homogeneous model; forwarded to the
                base class constructor.
            metadata (Metadata): Pair of ``(node_types, edge_types)``; both
                entries must be non-empty.
            aggr (str, optional): One of the keys of :attr:`aggrs`.
                (default: :obj:`"sum"`)
            input_map (Dict[str, str], optional): Forwarded to the base class
                constructor. (default: :obj:`None`)
            debug (bool, optional): Forwarded to the base class constructor.
                (default: :obj:`False`)

        Raises:
            AssertionError: If :obj:`metadata` does not have exactly two
                entries, if either entry is empty, or if :obj:`aggr` is not
                a key of :attr:`aggrs`.
        """
        super().__init__(module, input_map, debug)

        # NOTE(review): `metadata` is unpacked here before the length
        # assertion below runs, so a `metadata` of the wrong length fails
        # inside this call rather than on the assertion.
        unused_node_types = get_unused_node_types(*metadata)
        if len(unused_node_types) > 0:
            warnings.warn(
                f"There exist node types ({unused_node_types}) whose "
                f"representations do not get updated during message passing "
                f"as they do not occur as destination type in any edge type. "
                f"This may lead to unexpected behaviour.")

        self.metadata = metadata
        self.aggr = aggr
        assert len(metadata) == 2
        assert len(metadata[0]) > 0 and len(metadata[1]) > 0
        assert aggr in self.aggrs.keys()

    def placeholder(self, node: Node, target: Any, name: str):
        r"""Turns an input placeholder into a dictionary input and adds one
        ``get`` call per node type or edge type that extracts its entry.

        Args:
            node (Node): The placeholder node.
            target (Any): Unused in this method.
            name (str): Base name used for the created ``get`` nodes.
        """
        # Adds a `get` call to the input dictionary for every node-type or
        # edge-type.
        if node.type is not None:
            # Keep an existing annotation consistent with the new dictionary
            # input by wrapping it into `Dict[<key type>, <old type>]`.
            Type = EdgeType if self.is_edge_level(node) else NodeType
            node.type = Dict[Type, node.type]

        # `metadata[0]` holds node types and `metadata[1]` holds edge types,
        # so the boolean result of `is_edge_level` selects the right list.
        # Moving the insertion point after each created node keeps the new
        # nodes in metadata order.
        self.graph.inserting_after(node)
        for key in self.metadata[int(self.is_edge_level(node))]:
            out = self.graph.create_node('call_method', target='get',
                                         args=(node, key),
                                         name=f'{name}__{key2str(key)}')
            self.graph.inserting_after(out)

    def get_attr(self, node: Node, target: Any, name: str):
        r"""Not supported by this transformer.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def call_message_passing_module(self, node: Node, target: Any, name: str):
        r"""Replaces a message passing call by one call per edge type and
        combines all outputs that share a destination node type via
        :attr:`aggr`.

        Args:
            node (Node): The original ``call_module`` node.
            target (Any): Qualified name of the message passing submodule;
                the per-edge-type copies are addressed as
                ``f'{target}.{key2str(key)}'``.
            name (str): Base name; the final per-destination result is named
                ``f'{name}__{dst}'``.
        """
        # Add calls to edge type-wise `MessagePassing` modules and aggregate
        # the outputs to node type-wise embeddings afterwards.

        module = get_submodule(self.module, target)
        check_add_self_loops(module, self.metadata[1])

        # Group edge-wise keys per destination:
        # Every edge type first receives an intermediate name of the form
        # `{name}__{dst}{k}`, where `k` is its 1-based position among the edge
        # types that point to the same destination.
        # NOTE(review): the destination type is used verbatim in these names,
        # whereas lookups elsewhere in this class go through `key2str` --
        # confirm behaviour for node types containing ' ', '-' or ':'.
        key_name, keys_per_dst = {}, defaultdict(list)
        for key in self.metadata[1]:
            keys_per_dst[key[-1]].append(key)
            key_name[key] = f'{name}__{key[-1]}{len(keys_per_dst[key[-1]])}'

        # Iterate over a copy since entries are deleted inside the loop.
        for dst, keys in dict(keys_per_dst).items():
            # In case there is only a single edge-wise connection, there is no
            # need for any destination-wise aggregation, and we can already set
            # the intermediate variable name to the final output name.
            if len(keys) == 1:
                key_name[keys[0]] = f'{name}__{dst}'
                del keys_per_dst[dst]

        self.graph.inserting_after(node)
        for key in self.metadata[1]:
            args, kwargs = self.map_args_kwargs(node, key)
            out = self.graph.create_node('call_module',
                                         target=f'{target}.{key2str(key)}',
                                         args=args, kwargs=kwargs,
                                         name=key_name[key])
            self.graph.inserting_after(out)

        # Perform destination-wise aggregation.
        # Here, we aggregate in pairs, popping the first two elements of
        # `keys_per_dst` and append the result to the list.
        for dst, keys in keys_per_dst.items():
            queue = deque([key_name[key] for key in keys])
            i = 1
            while len(queue) >= 2:
                key1, key2 = queue.popleft(), queue.popleft()
                args = (self.find_by_name(key1), self.find_by_name(key2))

                # Only the very last node of the reduction may carry the final
                # name `{name}__{dst}`. For 'mean', that last node is the
                # division created below, so every pairwise result gets a
                # numbered suffix here.
                new_name = f'{name}__{dst}'
                if self.aggr == 'mean' or len(queue) > 0:
                    new_name = f'{new_name}_{i}'

                out = self.graph.create_node('call_function',
                                             target=self.aggrs[self.aggr],
                                             args=args, name=new_name)
                self.graph.inserting_after(out)
                queue.append(new_name)
                i += 1

            if self.aggr == 'mean':
                # Turn the accumulated sum into a mean by dividing by the
                # number of edge types that point to `dst`.
                key = queue.popleft()
                out = self.graph.create_node(
                    'call_function', target=torch.div,
                    args=(self.find_by_name(key), len(keys_per_dst[dst])),
                    name=f'{name}__{dst}')
                self.graph.inserting_after(out)

    def call_global_pooling_module(self, node: Node, target: Any, name: str):
        r"""Replaces a global pooling call by one call per node type, reduces
        the per-type outputs to a single value via :attr:`aggr`, and redirects
        all uses of the original node to that value.

        Args:
            node (Node): The original ``call_module`` node.
            target (Any): Qualified name of the pooling submodule; the
                per-node-type copies are addressed as
                ``f'{target}.{key2str(key)}'``.
            name (str): Base name used for the aggregation nodes
                (``f'{name}_{i}'``).
        """
        # Add calls to node type-wise `GlobalPooling` modules and aggregate
        # the outputs to graph type-wise embeddings afterwards.
        self.graph.inserting_after(node)
        for key in self.metadata[0]:
            args, kwargs = self.map_args_kwargs(node, key)
            out = self.graph.create_node('call_module',
                                         target=f'{target}.{key2str(key)}',
                                         args=args, kwargs=kwargs,
                                         name=f'{node.name}__{key2str(key)}')
            self.graph.inserting_after(out)

        # Perform node-wise aggregation.
        # Pairwise reduction: combine the two oldest entries and append the
        # result until a single entry is left.
        queue = deque(
            [f'{node.name}__{key2str(key)}' for key in self.metadata[0]])
        i = 1
        while len(queue) >= 2:
            key1, key2 = queue.popleft(), queue.popleft()
            args = (self.find_by_name(key1), self.find_by_name(key2))
            out = self.graph.create_node('call_function',
                                         target=self.aggrs[self.aggr],
                                         args=args, name=f'{name}_{i}')
            self.graph.inserting_after(out)
            queue.append(f'{name}_{i}')
            i += 1

        if self.aggr == 'mean':
            # Turn the accumulated sum into a mean over all node types.
            key = queue.popleft()
            out = self.graph.create_node(
                'call_function', target=torch.div,
                args=(self.find_by_name(key), len(self.metadata[0])),
                name=f'{name}_{i}')
            self.graph.inserting_after(out)
        # `out` is the most recently created node: the division for 'mean',
        # otherwise the last pairwise aggregation, or the last per-type
        # pooling call in case there is only a single node type.
        self.replace_all_uses_with(node, out)

    def call_module(self, node: Node, target: Any, name: str):
        r"""Replicates a module call once per node type or edge type.

        Graph-level nodes are left untouched.

        Args:
            node (Node): The original ``call_module`` node.
            target (Any): Qualified name of the submodule; the per-type
                copies are addressed as ``f'{target}.{key2str(key)}'``.
            name (str): Base name used for the created nodes.
        """
        if self.is_graph_level(node):
            return

        # Add calls to node type-wise or edge type-wise modules.
        self.graph.inserting_after(node)
        for key in self.metadata[int(self.is_edge_level(node))]:
            args, kwargs = self.map_args_kwargs(node, key)
            out = self.graph.create_node('call_module',
                                         target=f'{target}.{key2str(key)}',
                                         args=args, kwargs=kwargs,
                                         name=f'{name}__{key2str(key)}')
            self.graph.inserting_after(out)

    def call_method(self, node: Node, target: Any, name: str):
        r"""Replicates a method call once per node type or edge type.

        Graph-level nodes are left untouched. In contrast to
        :meth:`call_module`, :obj:`target` is reused unchanged for every
        created node.

        Args:
            node (Node): The original ``call_method`` node.
            target (Any): The method target, passed through as-is.
            name (str): Base name used for the created nodes.
        """
        if self.is_graph_level(node):
            return

        # Add calls to node type-wise or edge type-wise methods.
        self.graph.inserting_after(node)
        for key in self.metadata[int(self.is_edge_level(node))]:
            args, kwargs = self.map_args_kwargs(node, key)
            out = self.graph.create_node('call_method', target=target,
                                         args=args, kwargs=kwargs,
                                         name=f'{name}__{key2str(key)}')
            self.graph.inserting_after(out)

    def call_function(self, node: Node, target: Any, name: str):
        r"""Replicates a function call once per node type or edge type.

        Graph-level nodes are left untouched. :obj:`target` is reused
        unchanged for every created node.

        Args:
            node (Node): The original ``call_function`` node.
            target (Any): The function target, passed through as-is.
            name (str): Base name used for the created nodes.
        """
        if self.is_graph_level(node):
            return

        # Add calls to node type-wise or edge type-wise functions.
        self.graph.inserting_after(node)
        for key in self.metadata[int(self.is_edge_level(node))]:
            args, kwargs = self.map_args_kwargs(node, key)
            out = self.graph.create_node('call_function', target=target,
                                         args=args, kwargs=kwargs,
                                         name=f'{name}__{key2str(key)}')
            self.graph.inserting_after(out)

    def output(self, node: Node, target: Any, name: str):
        r"""Rewrites the output node so that every node-level or edge-level
        value is returned as a dictionary keyed by node type or edge type.

        Args:
            node (Node): The output node; its :obj:`args` and :obj:`type` are
                modified in-place.
            target (Any): Unused in this method.
            name (str): Unused in this method.
        """
        # Replace the output by dictionaries, holding either node type-wise or
        # edge type-wise data.
        def _recurse(value: Any) -> Any:
            # Walks through nested dicts, lists and tuples and replaces every
            # non graph-level `Node` by a dictionary of its per-type
            # counterparts. Anything else is returned unchanged.
            if isinstance(value, Node):
                if self.is_graph_level(value):
                    return value
                return {
                    key: self.find_by_name(f'{value.name}__{key2str(key)}')
                    for key in self.metadata[int(self.is_edge_level(value))]
                }
            elif isinstance(value, dict):
                return {k: _recurse(v) for k, v in value.items()}
            elif isinstance(value, list):
                return [_recurse(v) for v in value]
            elif isinstance(value, tuple):
                return tuple(_recurse(v) for v in value)
            else:
                return value

        if node.type is not None and isinstance(node.args[0], Node):
            # A single annotated value is returned: wrap its annotation in
            # the matching dictionary type. A graph-level output keeps its
            # annotation as-is.
            output = node.args[0]
            if self.is_node_level(output):
                node.type = Dict[NodeType, node.type]
            elif self.is_edge_level(output):
                node.type = Dict[EdgeType, node.type]
        else:
            # Outputs without an annotation or with a nested structure do not
            # receive a return annotation.
            node.type = None

        node.args = (_recurse(node.args[0]), )

    def init_submodule(self, module: Module, target: str) -> Module:
        r"""Returns the submodule to store under :obj:`target` in the
        transformed model.

        Args:
            module (torch.nn.Module): The original submodule.
            target (str): Qualified name of the submodule.

        Returns:
            torch.nn.Module: :obj:`module` itself in case the graph holds no
            per-type call for it, otherwise a :class:`torch.nn.ModuleDict`
            holding one deep copy per node type or edge type, keyed by
            :func:`key2str`.
        """
        # Replicate each module for each node type or edge type.
        # Replication is detected by probing the graph for a call that
        # targets the copy of the *first* node type / *first* edge type only.
        has_node_level_target = bool(
            self.find_by_target(f'{target}.{key2str(self.metadata[0][0])}'))
        has_edge_level_target = bool(
            self.find_by_target(f'{target}.{key2str(self.metadata[1][0])}'))

        if not has_node_level_target and not has_edge_level_target:
            return module

        module_dict = torch.nn.ModuleDict()
        for key in self.metadata[int(has_edge_level_target)]:
            module_dict[key2str(key)] = copy.deepcopy(module)
            # A single copy keeps the parameters of the original module.
            if len(self.metadata[int(has_edge_level_target)]) <= 1:
                continue
            # Presumably re-initialized so that the copies do not all start
            # from identical parameters -- confirm intent.
            if hasattr(module, 'reset_parameters'):
                module_dict[key2str(key)].reset_parameters()
            elif sum([p.numel() for p in module.parameters()]) > 0:
                # NOTE(review): this branch is reached once per key, so the
                # warning is issued for every copy.
                warnings.warn(
                    f"'{target}' will be duplicated, but its parameters "
                    f"cannot be reset. To suppress this warning, add a "
                    f"'reset_parameters()' method to '{target}'")

        return module_dict

    # Helper methods ##########################################################

    def map_args_kwargs(self, node: Node,
                        key: Union[NodeType, EdgeType]) -> Tuple[Tuple, Dict]:
        r"""Maps the arguments of :obj:`node` to their counterparts for the
        node type or edge type :obj:`key`.

        Args:
            node (Node): The node whose :obj:`args` and :obj:`kwargs` are
                mapped.
            key (NodeType or EdgeType): The type to map the arguments to.

        Returns:
            Tuple[Tuple, Dict]: The mapped positional and keyword arguments.

        Raises:
            NotImplementedError: If an argument has no counterpart named
                after :obj:`key` and :obj:`key` is not a tuple.
        """
        def _recurse(value: Any) -> Any:
            if isinstance(value, Node):
                # Prefer the counterpart that was created for exactly `key`.
                out = self.find_by_name(f'{value.name}__{key2str(key)}')
                if out is not None:
                    return out
                elif isinstance(key, tuple) and key[0] == key[-1]:
                    # Edge type with identical source and destination type:
                    # fall back to the counterpart of that node type.
                    name = f'{value.name}__{key2str(key[0])}'
                    return self.find_by_name(name)
                elif isinstance(key, tuple) and key[0] != key[-1]:
                    # Edge type with different source and destination types:
                    # pass a `(source, destination)` pair of counterparts.
                    return (
                        self.find_by_name(f'{value.name}__{key2str(key[0])}'),
                        self.find_by_name(f'{value.name}__{key2str(key[-1])}'),
                    )
                else:
                    raise NotImplementedError
            elif isinstance(value, dict):
                return {k: _recurse(v) for k, v in value.items()}
            elif isinstance(value, list):
                return [_recurse(v) for v in value]
            elif isinstance(value, tuple):
                return tuple(_recurse(v) for v in value)
            else:
                # Non-node leaves (e.g. constants) are passed through as-is.
                return value

        args = tuple(_recurse(v) for v in node.args)
        kwargs = {k: _recurse(v) for k, v in node.kwargs.items()}
        return args, kwargs


def key2str(key: Union[NodeType, EdgeType]) -> str:
    r"""Converts a node type or edge type into a single string.

    Tuples are joined by ``'__'``; afterwards, every ``' '``, ``'-'`` and
    ``':'`` is replaced by ``'_'``.

    Args:
        key (NodeType or EdgeType): The node type or edge type.

    Returns:
        str: The converted string.
    """
    key = '__'.join(key) if isinstance(key, tuple) else key
    return key.replace(' ', '_').replace('-', '_').replace(':', '_')