Source code for torch_geometric.graphgym.optim

from dataclasses import dataclass, field
from typing import Any, Iterator, List, Optional

from torch.nn import Parameter
from torch.optim import SGD, Adam, Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR, StepLR

import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import from_config


@dataclass
class OptimizerConfig:
    optimizer: str = 'adam'  # ['sgd', 'adam']
    base_lr: float = 0.01
    weight_decay: float = 5e-4
    momentum: float = 0.9  # 'sgd' policy


@register.register_optimizer('adam')
def adam_optimizer(params: Iterator[Parameter], base_lr: float,
                   weight_decay: float) -> Adam:
    return Adam(params, lr=base_lr, weight_decay=weight_decay)


@register.register_optimizer('sgd')
def sgd_optimizer(params: Iterator[Parameter], base_lr: float, momentum: float,
                  weight_decay: float) -> SGD:
    return SGD(params, lr=base_lr, momentum=momentum,
               weight_decay=weight_decay)



def create_optimizer(params: Iterator[Parameter], cfg: Any) -> Any:
    r"""Creates a config-driven optimizer."""
    params = filter(lambda p: p.requires_grad, params)
    func = register.optimizer_dict.get(cfg.optimizer, None)
    if func is not None:
        return from_config(func)(params, cfg=cfg)
    raise ValueError(f"Optimizer '{cfg.optimizer}' not supported")
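

# Usage sketch (illustrative, not part of this module): `_build_optimizer`
# and its `model` argument are hypothetical stand-ins; any `torch.nn.Module`
# whose parameters require gradients would do.
def _build_optimizer(model):
    opt_cfg = OptimizerConfig(optimizer='adam', base_lr=0.01,
                              weight_decay=5e-4)
    return create_optimizer(model.parameters(), opt_cfg)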


@dataclass
class SchedulerConfig:
    scheduler: Optional[str] = 'cos'  # [None, 'step', 'cos']
    steps: List[int] = field(default_factory=lambda: [30, 60, 90])  # 'step'
    lr_decay: float = 0.1  # 'step' policy
    max_epoch: int = 200


@register.register_scheduler(None)
@register.register_scheduler('none')
def none_scheduler(optimizer: Optimizer, max_epoch: int) -> StepLR:
    return StepLR(optimizer, step_size=max_epoch + 1)


@register.register_scheduler('step')
def step_scheduler(optimizer: Optimizer, steps: List[int],
                   lr_decay: float) -> MultiStepLR:
    return MultiStepLR(optimizer, milestones=steps, gamma=lr_decay)


@register.register_scheduler('cos')
def cos_scheduler(optimizer: Optimizer, max_epoch: int) -> CosineAnnealingLR:
    return CosineAnnealingLR(optimizer, T_max=max_epoch)
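

# Illustrative sketch (not part of this module): additional learning rate
# policies can be registered in the same way. The 'exp' key and the reuse of
# `lr_decay` as the per-epoch decay factor are assumptions for demonstration.
from torch.optim.lr_scheduler import ExponentialLR


@register.register_scheduler('exp')
def exp_scheduler(optimizer: Optimizer, lr_decay: float) -> ExponentialLR:
    return ExponentialLR(optimizer, gamma=lr_decay)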


def create_scheduler(optimizer: Optimizer, cfg: Any) -> Any:
    r"""Creates a config-driven learning rate scheduler."""
    func = register.scheduler_dict.get(cfg.scheduler, None)
    if func is not None:
        return from_config(func)(optimizer, cfg=cfg)
    raise ValueError(f"Scheduler '{cfg.scheduler}' not supported")
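

# Usage sketch (illustrative, not part of this module): pair the optimizer
# with a cosine schedule and advance it once per epoch. `_fit` and its
# `train_epoch` callback are hypothetical placeholders.
def _fit(model, optimizer, train_epoch):
    sched_cfg = SchedulerConfig(scheduler='cos', max_epoch=200)
    scheduler = create_scheduler(optimizer, sched_cfg)
    for epoch in range(sched_cfg.max_epoch):
        train_epoch(model, optimizer)  # one pass over the training data
        scheduler.step()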