import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import torch
# TODO (matthias): This file currently requires the manual wildcard import
# below to let TorchScript work on decorated functions. The root cause is
# still unclear.
from torch_geometric.utils import * # noqa
# Global registry mapping every experimental option name to whether it is
# currently enabled. All options start out disabled.
__experimental_flag__: Dict[str, bool] = dict(disable_dynamic_shapes=False)

# Accepted ways of selecting experimental options: a single option name, a
# list of option names, or ``None`` (which selects every registered option).
Options = Optional[Union[str, List[str]]]
def get_options(options: Options) -> List[str]:
    r"""Normalizes :obj:`options` into a new list of option names.

    Args:
        options (str or list, optional): A single option name, a list of
            option names, or :obj:`None` to select every registered
            experimental option.

    Returns:
        A freshly allocated list of option names.
    """
    if options is None:
        return list(__experimental_flag__.keys())
    if isinstance(options, str):
        return [options]
    # Always copy: callers such as the context-managers store the result, so
    # returning the caller's own list would let later mutations of that list
    # silently change which options get toggled/restored.
    return list(options)
def is_experimental_mode_enabled(options: Options = None) -> bool:
    r"""Checks whether every requested experimental option is switched on.
    See :class:`torch_geometric.experimental_mode` for a list of (optional)
    options.

    Args:
        options (str or list, optional): The option(s) to query. If set to
            :obj:`None`, every registered option is queried.

    Returns:
        :obj:`True` only if all selected options are enabled. Always
        :obj:`False` while TorchScript is scripting or tracing.
    """
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return False
    selected = get_options(options)
    # Look up every flag up-front (an unknown option raises ``KeyError``).
    enabled_flags = [__experimental_flag__[name] for name in selected]
    return all(enabled_flags)
def set_experimental_mode_enabled(mode: bool, options: Options = None) -> None:
    r"""Sets every selected experimental option to :obj:`mode`.

    Args:
        mode (bool): The value assigned to each selected option.
        options (str or list, optional): The option(s) to update. If set to
            :obj:`None`, every registered option is updated.

    Raises:
        KeyError: If any selected option is not a registered experimental
            option. No flag is modified in that case.
    """
    option_names = get_options(options)
    # Validate before mutating: previously a typo silently registered a new,
    # meaningless flag, and a partial failure could leave flags half-updated.
    unknown = [name for name in option_names
               if name not in __experimental_flag__]
    if unknown:
        raise KeyError(f"Unknown experimental option(s) {unknown} "
                       f"(available: {list(__experimental_flag__.keys())})")
    for option in option_names:
        __experimental_flag__[option] = mode
class experimental_mode:
    r"""Context-manager that enables the experimental mode to test new but
    potentially unstable features.

    .. code-block:: python

        with torch_geometric.experimental_mode():
            out = model(data.x, data.edge_index)

    On exit, the selected options are restored to the values they had right
    before the :obj:`with` block was entered.

    Args:
        options (str or list, optional): The experimental option(s) to
            enable, *e.g.*, :obj:`"disable_dynamic_shapes"`. If set to
            :obj:`None`, all available options are enabled.
            (default: :obj:`None`)
    """
    def __init__(self, options: Options = None) -> None:
        self.options = get_options(options)
        # Snapshot once here so that unknown options still raise a
        # ``KeyError`` at construction time.
        self.previous_state = {
            option: __experimental_flag__[option]
            for option in self.options
        }

    def __enter__(self) -> None:
        # Re-snapshot on entry: flags may have changed between construction
        # and entry, and ``__exit__`` must restore the state that was active
        # immediately before this block (mirrors ``torch.no_grad``).
        self.previous_state = {
            option: __experimental_flag__[option]
            for option in self.options
        }
        set_experimental_mode_enabled(True, self.options)

    def __exit__(self, *args: Any) -> None:
        for option, value in self.previous_state.items():
            __experimental_flag__[option] = value
class set_experimental_mode:
    r"""Context-manager that switches the experimental mode on or off.

    The requested :attr:`mode` is applied as soon as the object is created.
    Hence, :class:`set_experimental_mode` can be used as a plain function call
    (the change persists) or as a context-manager (the previous values of the
    affected options are restored on exit).
    See :class:`experimental_mode` above for more details.

    Args:
        mode (bool): Whether to enable (:obj:`True`) or disable
            (:obj:`False`) the selected options.
        options (str or list, optional): The option(s) to update. If set to
            :obj:`None`, every registered option is updated.
    """
    def __init__(self, mode: bool, options: Options = None) -> None:
        selected = get_options(options)
        self.options = selected
        # Remember the current values before overwriting them.
        snapshot: Dict[str, bool] = {}
        for name in selected:
            snapshot[name] = __experimental_flag__[name]
        self.previous_state = snapshot
        set_experimental_mode_enabled(mode, selected)

    def __enter__(self) -> None:
        # The new mode was already applied in ``__init__``.
        return None

    def __exit__(self, *args: Any) -> None:
        __experimental_flag__.update(self.previous_state)
def disable_dynamic_shapes(required_args: List[str]) -> Callable:
    r"""A decorator that disables the usage of dynamic shapes for the given
    arguments, *i.e.*, it will raise an error in case :obj:`required_args`
    are not passed and would otherwise need to be inferred automatically.

    The check only runs while the :obj:`"disable_dynamic_shapes"`
    experimental option is enabled; otherwise the decorated function is
    called unchanged.

    Args:
        required_args (List[str]): Names of the arguments that must resolve
            to a non-:obj:`None` value. Both regular (positional-or-keyword)
            and keyword-only arguments are supported.

    Raises:
        ValueError: At decoration time, if the function has no argument with
            one of the given names. At call time, if a required argument
            resolves to :obj:`None` (or is missing entirely).
    """
    def decorator(func: Callable) -> Callable:
        spec = inspect.getfullargspec(func)
        defaults = spec.defaults or ()
        kwonly_defaults = spec.kwonlydefaults or {}
        # ``defaults`` is aligned with the *trailing* entries of
        # ``spec.args``; the leading ``num_positional_args`` have no default.
        num_positional_args = len(spec.args) - len(defaults)

        # Resolve once, at decoration time, where each required argument may
        # appear positionally and which default it falls back to.
        required_args_pos: Dict[str, Optional[int]] = {}
        required_args_default: Dict[str, Any] = {}
        for arg_name in required_args:
            if arg_name in spec.args:
                index = spec.args.index(arg_name)
                offset = index - num_positional_args
                required_args_pos[arg_name] = index
                # A negative offset means "no default"; indexing with it
                # would wrongly pick a default from the end of ``defaults``.
                required_args_default[arg_name] = (defaults[offset]
                                                   if offset >= 0 else None)
            elif arg_name in spec.kwonlyargs:
                # Keyword-only arguments can never be passed positionally.
                required_args_pos[arg_name] = None
                required_args_default[arg_name] = kwonly_defaults.get(
                    arg_name)
            else:
                raise ValueError(f"The function '{func}' does not have a "
                                 f"'{arg_name}' argument")

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if not is_experimental_mode_enabled('disable_dynamic_shapes'):
                return func(*args, **kwargs)

            for required_arg in required_args:
                index = required_args_pos[required_arg]
                if index is not None and index < len(args):
                    value = args[index]
                elif required_arg in kwargs:
                    value = kwargs[required_arg]
                else:
                    value = required_args_default[required_arg]

                if value is None:
                    raise ValueError(f"Dynamic shapes disabled. Argument "
                                     f"'{required_arg}' needs to be set")

            return func(*args, **kwargs)

        return wrapper

    return decorator