Source code for torch_geometric.nn.fx
from typing import Optional, Dict, Any
import copy
import torch
from torch.nn import Module, ModuleList, ModuleDict, Sequential
from torch_geometric.nn.dense import Linear
from torch_geometric.nn.conv import MessagePassing
try:
from torch.fx import GraphModule, Graph, Node
except (ImportError, ModuleNotFoundError):
GraphModule, Graph, Node = 'GraphModule', 'Graph', 'Node'
class Transformer(object):
    r"""A :class:`Transformer` executes an FX graph node-by-node, applies
    transformations to each node, and produces a new :class:`torch.nn.Module`.
    It exposes a :func:`transform` method that returns the transformed
    :class:`~torch.nn.Module`.
    :class:`Transformer` works entirely symbolically.

    Methods in the :class:`Transformer` class can be overridden to customize
    the behavior of transformation.

    .. code-block:: none

        transform()
            +-- Iterate over each node in the graph
                +-- placeholder()
                +-- get_attr()
                +-- call_function()
                +-- call_method()
                +-- call_module()
                +-- call_message_passing_module()
                +-- output()
            +-- Erase unused nodes in the graph
            +-- Iterate over each child module
                +-- init_submodule()

    In contrast to the :class:`torch.fx.Transformer` class, the
    :class:`Transformer` exposes additional functionality:

    #. It subdivides :func:`call_module` into nodes that call a regular
       :class:`torch.nn.Module` (:func:`call_module`) or a
       :class:`MessagePassing` module (:func:`call_message_passing_module`).

    #. It allows customizing or initializing new child modules via
       :func:`init_submodule`.

    #. It allows inferring whether a node returns node-level or edge-level
       information via :meth:`is_edge_level`.

    Args:
        module (torch.nn.Module): The module to be transformed.
        input_map (Dict[str, str], optional): A dictionary holding information
            about the type of input arguments of :obj:`module.forward`.
            For example, in case :obj:`arg` is a node-level argument, then
            :obj:`input_map['arg'] = 'node'`, and
            :obj:`input_map['arg'] = 'edge'` otherwise.
            In case :obj:`input_map` is not further specified, will try to
            automatically determine the correct type of input arguments.
            (default: :obj:`None`)
        debug (bool, optional): If set to :obj:`True`, will perform
            transformation in debug mode. (default: :obj:`False`)
    """
    def __init__(
        self,
        module: Module,
        input_map: Optional[Dict[str, str]] = None,
        debug: bool = False,
    ):
        self.module = module
        self.gm = symbolic_trace(module)
        self.input_map = input_map
        self.debug = debug

    # Methods to override #####################################################

    def placeholder(self, node: Node, target: Any, name: str):
        r"""Called for every ``placeholder`` node. Does nothing by default."""
        pass

    def get_attr(self, node: Node, target: Any, name: str):
        r"""Called for every ``get_attr`` node. Does nothing by default."""
        pass

    def call_message_passing_module(self, node: Node, target: Any, name: str):
        r"""Called for every ``call_module`` node whose target is a
        :class:`MessagePassing` module. Does nothing by default."""
        pass

    def call_module(self, node: Node, target: Any, name: str):
        r"""Called for every ``call_module`` node whose target is not a
        :class:`MessagePassing` module. Does nothing by default."""
        pass

    def call_method(self, node: Node, target: Any, name: str):
        r"""Called for every ``call_method`` node. Does nothing by default."""
        pass

    def call_function(self, node: Node, target: Any, name: str):
        r"""Called for every ``call_function`` node. Does nothing by
        default."""
        pass

    def output(self, node: Node, target: Any, name: str):
        r"""Called for the ``output`` node. Does nothing by default."""
        pass

    def init_submodule(self, module: Module, target: str) -> Module:
        r"""Returns the module that replaces the (non-container) child
        module :obj:`module` registered under the dotted name :obj:`target`.
        Returns :obj:`module` unchanged by default."""
        return module

    # Internal functionality ##################################################

    @property
    def graph(self) -> Graph:
        r"""The :class:`torch.fx.Graph` of the traced module."""
        return self.gm.graph

    def transform(self) -> GraphModule:
        r"""Transforms :obj:`self.module` and returns a transformed
        :class:`torch.fx.GraphModule`.

        First infers whether each node returns node-level or edge-level
        information. Then it calls the method matching each node's opcode and
        erases nodes that are no longer used. Finally, it passes every child
        module of :obj:`self.module` through :meth:`init_submodule`."""
        if self.debug:
            self._debug_print_graph()

        # We create a private dictionary `self._state` which holds information
        # about whether a node returns node-level or edge-level information:
        # `self._state[node.name] in { 'node', 'edge' }`
        self._state = copy.copy(self.input_map or {})
        try:
            self._infer_node_levels()

            # Call the corresponding `Transformer` method for each `node.op`,
            # e.g.: `call_module(...)`, `call_function(...)`, ...
            # Iterate over a snapshot since transformations may add nodes.
            for node in list(self.graph.nodes):
                op = node.op
                if is_message_passing_op(self.module, op, node.target):
                    op = 'call_message_passing_module'
                getattr(self, op)(node, node.target, node.name)

            self._erase_unused_nodes()

            for target, submodule in dict(self.module._modules).items():
                self.gm._modules[target] = self._init_submodule(
                    submodule, target)
        finally:
            # Never leave transformation-local state behind, even if one of
            # the (possibly user-overridden) hooks above raised.
            del self._state

        if self.debug:
            self._debug_print_graph()

        self.gm.graph.lint()
        self.gm.recompile()
        return self.gm

    def _debug_print_graph(self):
        # Prints the current graph as a table, followed by its generated
        # Python code.
        self.graph.print_tabular()
        print()
        code = self.graph.python_code('self')
        # Depending on the `torch.fx` version, `python_code` returns either a
        # plain string or an object exposing the source via `.src`.
        print(code.src if hasattr(code, 'src') else code)

    def _infer_node_levels(self):
        # Fills `self._state` with the output level ('node' or 'edge') of
        # every placeholder and call node, in graph (topological) order.
        for node in list(self.graph.nodes):
            if node.op == 'placeholder':
                if node.name not in self._state:
                    # Heuristic for inputs missing from `input_map`: names
                    # mentioning 'edge' or 'adj' are treated as edge-level.
                    is_edge = 'edge' in node.name or 'adj' in node.name
                    self._state[node.name] = 'edge' if is_edge else 'node'
            elif is_message_passing_op(self.module, node.op, node.target):
                # Outputs of `MessagePassing` modules are treated as
                # node-level.
                self._state[node.name] = 'node'
            elif node.op in ('call_module', 'call_method', 'call_function'):
                # Any edge-level input makes the output edge-level.
                level = 'edge' if self.has_edge_level_arg(node) else 'node'
                self._state[node.name] = level

    def _erase_unused_nodes(self):
        # Remove all unused nodes in the computation graph, i.e., all nodes
        # which have been replaced by node type-wise or edge type-wise
        # variants but which are still present in the computation graph.
        # Iterating in reverse order lets a node's users be erased before the
        # node itself. `erase_node` raises `RuntimeError` for nodes that still
        # have users; such nodes are deliberately kept.
        for node in reversed(list(self.graph.nodes)):
            if node.op in ('placeholder', 'output'):
                continue
            try:
                self.graph.erase_node(node)
            except RuntimeError:
                pass

    def _init_submodule(self, module: Module, target: str) -> Module:
        # Recursively unpacks container modules so that `init_submodule` is
        # only called on leaf (non-container) modules. Note that `Sequential`
        # containers are returned as `ModuleList`.
        if isinstance(module, (ModuleList, Sequential)):
            return ModuleList([
                self._init_submodule(submodule, f'{target}.{i}')
                for i, submodule in enumerate(module)
            ])
        elif isinstance(module, ModuleDict):
            return ModuleDict({
                key: self._init_submodule(submodule, f'{target}.{key}')
                for key, submodule in module.items()
            })
        else:
            return self.init_submodule(module, target)

    def is_edge_level(self, node: Node) -> bool:
        r"""Returns :obj:`True` if :obj:`node` was inferred to return
        edge-level information. Only valid during :meth:`transform`.

        Nodes without a recorded level (e.g., ``get_attr`` nodes) are
        treated as not edge-level."""
        return self._state.get(node.name) == 'edge'

    def has_edge_level_arg(self, node: Node) -> bool:
        r"""Returns :obj:`True` if any positional or keyword argument of
        :obj:`node` is an edge-level :class:`torch.fx.Node`. Arguments nested
        inside lists, tuples and dictionaries are searched recursively."""
        def _recurse(value: Any) -> bool:
            if isinstance(value, Node):
                return self.is_edge_level(value)
            elif isinstance(value, dict):
                return any(_recurse(v) for v in value.values())
            elif isinstance(value, (list, tuple)):
                return any(_recurse(v) for v in value)
            else:
                return False

        return (any(_recurse(value) for value in node.args)
                or any(_recurse(value) for value in node.kwargs.values()))

    def find_by_name(self, name: str) -> Optional[Node]:
        r"""Returns the first node in the graph named :obj:`name`, or
        :obj:`None` if there is no such node."""
        for node in self.graph.nodes:
            if node.name == name:
                return node
        return None

    def find_by_target(self, target: Any) -> Optional[Node]:
        r"""Returns the first node in the graph whose target equals
        :obj:`target`, or :obj:`None` if there is no such node."""
        for node in self.graph.nodes:
            if node.target == target:
                return node
        return None

    def replace_all_uses_with(self, to_replace: Node, replace_with: Node):
        r"""Replaces every use of :obj:`to_replace` by :obj:`replace_with` in
        the arguments of all nodes that come after :obj:`replace_with` in the
        graph. Uses in earlier nodes are left untouched."""
        def maybe_replace_node(n: Node) -> Node:
            return replace_with if n == to_replace else n

        node = replace_with.next
        # `torch.fx` terminates its node list with a sentinel node whose op
        # is 'root', so the walk stops once it wraps around to it.
        while node.op != 'root':
            node.args = torch.fx.map_arg(node.args, maybe_replace_node)
            node.kwargs = torch.fx.map_arg(node.kwargs, maybe_replace_node)
            node = node.next
def symbolic_trace(
        module: Module,
        concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule:
    r"""Symbolically traces :obj:`module` into a
    :class:`torch.fx.GraphModule`.

    A custom :class:`torch.fx.Tracer` treats :class:`MessagePassing` and
    (lazy) :class:`Linear` modules as leaves. Calls to them therefore show up
    as single ``call_module`` nodes instead of being traced into.

    Args:
        module (torch.nn.Module): The module to trace.
        concrete_args (Dict[str, Any], optional): Arguments forwarded as
            :obj:`concrete_args` to :meth:`torch.fx.Tracer.trace`.
            (default: :obj:`None`)
    """
    class _PyGTracer(torch.fx.Tracer):
        def is_leaf_module(self, module: Module, *args, **kwargs) -> bool:
            # Do not trace into `MessagePassing` and lazy `Linear` modules.
            if isinstance(module, (MessagePassing, Linear)):
                return True
            return super().is_leaf_module(module, *args, **kwargs)

    tracer = _PyGTracer()
    graph = tracer.trace(module, concrete_args)
    return GraphModule(module, graph)
def get_submodule(module: Module, target: str) -> Module:
    r"""Resolves the dotted attribute path :obj:`target` (e.g.
    ``'convs.0'``) starting from :obj:`module` and returns the result.

    Raises :class:`AttributeError` if any path component does not exist."""
    current, remaining = module, target
    while True:
        head, dot, remaining = remaining.partition('.')
        current = getattr(current, head)
        if not dot:
            return current
def is_message_passing_op(module: Module, op: str, target: str) -> bool:
    r"""Returns :obj:`True` if an FX node with opcode :obj:`op` and target
    :obj:`target` calls a :class:`MessagePassing` submodule of
    :obj:`module`.

    Only ``call_module`` nodes can qualify; for other opcodes the target is
    not looked up at all."""
    return (op == 'call_module'
            and isinstance(get_submodule(module, target), MessagePassing))