# Source code for torch_geometric.graphgym.register

from typing import Dict, Any, Union, Callable

# Module-level registries. Each dictionary maps a string key to the object
# registered under that key; entries are added by `register_base`, which the
# matching `register_*` function in this file calls with the dictionary below.

# Filled by `register_act`.
act_dict: Dict[str, Any] = {}
# Filled by `register_node_encoder`.
node_encoder_dict: Dict[str, Any] = {}
# Filled by `register_edge_encoder`.
edge_encoder_dict: Dict[str, Any] = {}
# Filled by `register_stage`.
stage_dict: Dict[str, Any] = {}
# Filled by `register_head`.
head_dict: Dict[str, Any] = {}
# Filled by `register_layer`.
layer_dict: Dict[str, Any] = {}
# Filled by `register_pooling`.
pooling_dict: Dict[str, Any] = {}
# Filled by `register_network`.
network_dict: Dict[str, Any] = {}
# Filled by `register_config`.
config_dict: Dict[str, Any] = {}
# Filled by `register_loader`.
loader_dict: Dict[str, Any] = {}
# Filled by `register_optimizer`.
optimizer_dict: Dict[str, Any] = {}
# Filled by `register_scheduler`.
scheduler_dict: Dict[str, Any] = {}
# Filled by `register_loss`.
loss_dict: Dict[str, Any] = {}
# Filled by `register_train`.
train_dict: Dict[str, Any] = {}


def register_base(mapping: Dict[str, Any], key: str,
                  module: Any = None) -> Union[None, Callable]:
    r"""Base function for registering a module in GraphGym.

    Can be used in two ways: called directly with :obj:`module` to register
    it immediately, or called without :obj:`module` to obtain a decorator
    that registers whatever it is applied to.

    Args:
        mapping (dict): Python dictionary hosting all the registered
            modules; the new entry is stored in it.
        key (string): The name of the module.
        module (any, optional): The module. If set to :obj:`None`, will
            return a decorator to register a module.

    Returns:
        :obj:`None` when :obj:`module` is given, otherwise a decorator that
        registers its argument under :obj:`key` and returns that argument
        unchanged.

    Raises:
        KeyError: If :obj:`key` is already present in :obj:`mapping`.
    """
    if module is not None:
        # Refuse to silently overwrite an existing registration.
        if key in mapping:
            raise KeyError(f"Module with '{key}' already defined")
        mapping[key] = module
        return None

    # Otherwise, use it as a decorator: registration is deferred until the
    # returned wrapper is applied to the object being decorated.
    def wrapper(module):
        register_base(mapping, key, module)
        # Hand the decorated object back so its name stays bound to it.
        return module

    return wrapper
def register_act(key: str, module: Any = None) -> Union[None, Callable]:
    r"""Registers an activation function in GraphGym.

    Args:
        key (string): The name to register the activation function under.
        module (any, optional): The object to register. If set to
            :obj:`None`, a decorator performing the registration is returned.

    Returns:
        Whatever :func:`register_base` returns for :obj:`act_dict`:
        :obj:`None` when :obj:`module` is given, otherwise a decorator.

    Raises:
        KeyError: If :obj:`key` is already present in :obj:`act_dict`.
    """
    return register_base(act_dict, key, module)
def register_node_encoder(key: str,
                          module: Any = None) -> Union[None, Callable]:
    r"""Registers a node feature encoder in GraphGym.

    Args:
        key (string): The name to register the node feature encoder under.
        module (any, optional): The object to register. If set to
            :obj:`None`, a decorator performing the registration is returned.

    Returns:
        Whatever :func:`register_base` returns for :obj:`node_encoder_dict`:
        :obj:`None` when :obj:`module` is given, otherwise a decorator.

    Raises:
        KeyError: If :obj:`key` is already present in
            :obj:`node_encoder_dict`.
    """
    return register_base(node_encoder_dict, key, module)
def register_edge_encoder(key: str,
                          module: Any = None) -> Union[None, Callable]:
    r"""Registers an edge feature encoder in GraphGym.

    Args:
        key (string): The name to register the edge feature encoder under.
        module (any, optional): The object to register. If set to
            :obj:`None`, a decorator performing the registration is returned.

    Returns:
        Whatever :func:`register_base` returns for :obj:`edge_encoder_dict`:
        :obj:`None` when :obj:`module` is given, otherwise a decorator.

    Raises:
        KeyError: If :obj:`key` is already present in
            :obj:`edge_encoder_dict`.
    """
    return register_base(edge_encoder_dict, key, module)
def register_stage(key: str, module: Any = None) -> Union[None, Callable]:
    r"""Registers a customized GNN stage in GraphGym.

    Args:
        key (string): The name to register the GNN stage under.
        module (any, optional): The object to register. If set to
            :obj:`None`, a decorator performing the registration is returned.

    Returns:
        Whatever :func:`register_base` returns for :obj:`stage_dict`:
        :obj:`None` when :obj:`module` is given, otherwise a decorator.

    Raises:
        KeyError: If :obj:`key` is already present in :obj:`stage_dict`.
    """
    return register_base(stage_dict, key, module)
def register_head(key: str, module: Any = None) -> Union[None, Callable]:
    r"""Registers a GNN prediction head in GraphGym.

    Args:
        key (string): The name to register the prediction head under.
        module (any, optional): The object to register. If set to
            :obj:`None`, a decorator performing the registration is returned.

    Returns:
        Whatever :func:`register_base` returns for :obj:`head_dict`:
        :obj:`None` when :obj:`module` is given, otherwise a decorator.

    Raises:
        KeyError: If :obj:`key` is already present in :obj:`head_dict`.
    """
    return register_base(head_dict, key, module)
def register_layer(key: str, module: Any = None) -> Union[None, Callable]:
    r"""Registers a GNN layer in GraphGym.

    Args:
        key (string): The name to register the GNN layer under.
        module (any, optional): The object to register. If set to
            :obj:`None`, a decorator performing the registration is returned.

    Returns:
        Whatever :func:`register_base` returns for :obj:`layer_dict`:
        :obj:`None` when :obj:`module` is given, otherwise a decorator.

    Raises:
        KeyError: If :obj:`key` is already present in :obj:`layer_dict`.
    """
    return register_base(layer_dict, key, module)
def register_pooling(key: str, module: Any = None) -> Union[None, Callable]:
    r"""Registers a GNN global pooling/readout layer in GraphGym.

    Args:
        key (string): The name to register the pooling/readout layer under.
        module (any, optional): The object to register. If set to
            :obj:`None`, a decorator performing the registration is returned.

    Returns:
        Whatever :func:`register_base` returns for :obj:`pooling_dict`:
        :obj:`None` when :obj:`module` is given, otherwise a decorator.

    Raises:
        KeyError: If :obj:`key` is already present in :obj:`pooling_dict`.
    """
    return register_base(pooling_dict, key, module)
def register_network(key: str, module: Any = None) -> Union[None, Callable]:
    r"""Registers a GNN model in GraphGym.

    Args:
        key (string): The name to register the GNN model under.
        module (any, optional): The object to register. If set to
            :obj:`None`, a decorator performing the registration is returned.

    Returns:
        Whatever :func:`register_base` returns for :obj:`network_dict`:
        :obj:`None` when :obj:`module` is given, otherwise a decorator.

    Raises:
        KeyError: If :obj:`key` is already present in :obj:`network_dict`.
    """
    return register_base(network_dict, key, module)
def register_config(key: str, module: Any = None) -> Union[None, Callable]:
    r"""Registers a configuration group in GraphGym.

    Args:
        key (string): The name to register the configuration group under.
        module (any, optional): The object to register. If set to
            :obj:`None`, a decorator performing the registration is returned.

    Returns:
        Whatever :func:`register_base` returns for :obj:`config_dict`:
        :obj:`None` when :obj:`module` is given, otherwise a decorator.

    Raises:
        KeyError: If :obj:`key` is already present in :obj:`config_dict`.
    """
    return register_base(config_dict, key, module)
def register_loader(key: str, module: Any = None) -> Union[None, Callable]:
    r"""Registers a data loader in GraphGym.

    Args:
        key (string): The name to register the data loader under.
        module (any, optional): The object to register. If set to
            :obj:`None`, a decorator performing the registration is returned.

    Returns:
        Whatever :func:`register_base` returns for :obj:`loader_dict`:
        :obj:`None` when :obj:`module` is given, otherwise a decorator.

    Raises:
        KeyError: If :obj:`key` is already present in :obj:`loader_dict`.
    """
    return register_base(loader_dict, key, module)
def register_optimizer(key: str,
                       module: Any = None) -> Union[None, Callable]:
    r"""Registers an optimizer in GraphGym.

    Args:
        key (string): The name to register the optimizer under.
        module (any, optional): The object to register. If set to
            :obj:`None`, a decorator performing the registration is returned.

    Returns:
        Whatever :func:`register_base` returns for :obj:`optimizer_dict`:
        :obj:`None` when :obj:`module` is given, otherwise a decorator.

    Raises:
        KeyError: If :obj:`key` is already present in :obj:`optimizer_dict`.
    """
    return register_base(optimizer_dict, key, module)
def register_scheduler(key: str,
                       module: Any = None) -> Union[None, Callable]:
    r"""Registers a learning rate scheduler in GraphGym.

    Args:
        key (string): The name to register the learning rate scheduler under.
        module (any, optional): The object to register. If set to
            :obj:`None`, a decorator performing the registration is returned.

    Returns:
        Whatever :func:`register_base` returns for :obj:`scheduler_dict`:
        :obj:`None` when :obj:`module` is given, otherwise a decorator.

    Raises:
        KeyError: If :obj:`key` is already present in :obj:`scheduler_dict`.
    """
    return register_base(scheduler_dict, key, module)
def register_loss(key: str, module: Any = None) -> Union[None, Callable]:
    r"""Registers a loss function in GraphGym.

    Args:
        key (string): The name to register the loss function under.
        module (any, optional): The object to register. If set to
            :obj:`None`, a decorator performing the registration is returned.

    Returns:
        Whatever :func:`register_base` returns for :obj:`loss_dict`:
        :obj:`None` when :obj:`module` is given, otherwise a decorator.

    Raises:
        KeyError: If :obj:`key` is already present in :obj:`loss_dict`.
    """
    return register_base(loss_dict, key, module)
def register_train(key: str, module: Any = None) -> Union[None, Callable]:
    r"""Registers a training function in GraphGym.

    Args:
        key (string): The name to register the training function under.
        module (any, optional): The object to register. If set to
            :obj:`None`, a decorator performing the registration is returned.

    Returns:
        Whatever :func:`register_base` returns for :obj:`train_dict`:
        :obj:`None` when :obj:`module` is given, otherwise a decorator.

    Raises:
        KeyError: If :obj:`key` is already present in :obj:`train_dict`.
    """
    return register_base(train_dict, key, module)