Source code for torch_geometric.experimental

from typing import List, Optional, Union

__experimental_flag__ = {}

Options = Optional[Union[str, List[str]]]


def get_options(options: Options) -> List[str]:
    if options is None:
        options = list(__experimental_flag__.keys())
    if isinstance(options, str):
        options = [options]
    return options


[docs]def is_experimental_mode_enabled(options: Options = None) -> bool: r"""Returns :obj:`True` if the experimental mode is enabled. See :class:`torch_geometric.experimental_mode` for a list of (optional) options.""" options = get_options(options) return all([__experimental_flag__[option] for option in options])
def set_experimental_mode_enabled(mode: bool, options: Options = None): for option in get_options(options): __experimental_flag__[option] = mode
[docs]class experimental_mode: r"""Context-manager that enables the experimental mode to test new but potentially unstable features. .. code-block:: python with torch_geometric.experimental_mode(): out = model(data.x, data.edge_index) Args: options (str or list, optional): Currently there are no experimental features. """ def __init__(self, options: Options = None): self.options = get_options(options) self.previous_state = { option: __experimental_flag__[option] for option in self.options } def __enter__(self): set_experimental_mode_enabled(True, self.options) def __exit__(self, *args) -> bool: for option, value in self.previous_state.items(): __experimental_flag__[option] = value
[docs]class set_experimental_mode: r"""Context-manager that sets the experimental mode on or off. :class:`set_experimental_mode` will enable or disable the experimental mode based on its argument :attr:`mode`. It can be used as a context-manager or as a function. See :class:`experimental_mode` above for more details. """ def __init__(self, mode: bool, options: Options = None): self.options = get_options(options) self.previous_state = { option: __experimental_flag__[option] for option in self.options } set_experimental_mode_enabled(mode, self.options) def __enter__(self): pass def __exit__(self, *args): for option, value in self.previous_state.items(): __experimental_flag__[option] = value