from dataclasses import dataclass, field

import torch.optim as optim

import torch_geometric.graphgym.register as register


@dataclass
class OptimizerConfig:
    # optimizer: sgd, adam
    optimizer: str = 'adam'
    # Base learning rate
    base_lr: float = 0.01
    # L2 regularization
    weight_decay: float = 5e-4
    # SGD momentum
    momentum: float = 0.9
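

# A minimal illustration (not part of the original module): overriding the
# defaults above for an SGD run. The particular values are arbitrary examples.
sgd_optimizer_config = OptimizerConfig(optimizer='sgd', base_lr=0.1,
                                       weight_decay=5e-4, momentum=0.9)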


def create_optimizer(params, optimizer_config: OptimizerConfig):
    r"""Creates a PyTorch optimizer from an ``OptimizerConfig``.

    Args:
        params: The parameters of the PyTorch model.
        optimizer_config (OptimizerConfig): The optimizer configuration
            (optimizer type, base learning rate, weight decay, momentum).

    Returns:
        A PyTorch optimizer.
    """
    # Only optimize parameters that require gradients.
    params = filter(lambda p: p.requires_grad, params)
    # Try to load a customized optimizer first.
    for func in register.optimizer_dict.values():
        optimizer = func(params, optimizer_config)
        if optimizer is not None:
            return optimizer

    if optimizer_config.optimizer == 'adam':
        optimizer = optim.Adam(params, lr=optimizer_config.base_lr,
                               weight_decay=optimizer_config.weight_decay)
    elif optimizer_config.optimizer == 'sgd':
        optimizer = optim.SGD(params, lr=optimizer_config.base_lr,
                              momentum=optimizer_config.momentum,
                              weight_decay=optimizer_config.weight_decay)
    else:
        raise ValueError('Optimizer {} not supported'.format(
            optimizer_config.optimizer))
    return optimizer
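

# Illustrative sketch (not part of the original module): a customized
# optimizer hook that the registration loop in `create_optimizer` above can
# discover. A registered function receives `(params, optimizer_config)` and
# should return `None` when the requested optimizer is not the one it
# provides, so that the built-in choices stay reachable. The 'adagrad' key
# and the direct entry in `register.optimizer_dict` are assumptions made for
# this example.
def create_adagrad_optimizer(params, optimizer_config: OptimizerConfig):
    if optimizer_config.optimizer != 'adagrad':
        return None
    return optim.Adagrad(params, lr=optimizer_config.base_lr,
                         weight_decay=optimizer_config.weight_decay)


register.optimizer_dict['adagrad'] = create_adagrad_optimizer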


@dataclass
class SchedulerConfig:
    # scheduler: none, step, cos
    scheduler: str = 'cos'
    # Decay milestones (in epochs) for the 'step' policy
    steps: list = field(default_factory=lambda: [30, 60, 90])
    # Learning rate decay factor for the 'step' policy
    lr_decay: float = 0.1
    # Maximal number of epochs
    max_epoch: int = 200
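

# A minimal illustration (not part of the original module): a 'step' policy
# configuration that multiplies the learning rate by `lr_decay` at the given
# epoch milestones. The milestone values are arbitrary examples.
step_scheduler_config = SchedulerConfig(scheduler='step', steps=[50, 100, 150],
                                        lr_decay=0.1, max_epoch=200)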


def create_scheduler(optimizer, scheduler_config: SchedulerConfig):
    r"""Creates a PyTorch learning rate scheduler from a ``SchedulerConfig``.

    Args:
        optimizer: The PyTorch optimizer to schedule.
        scheduler_config (SchedulerConfig): The scheduler configuration
            (scheduler type, decay milestones, decay factor, maximal number
            of epochs).

    Returns:
        A PyTorch learning rate scheduler.
    """
    # Try to load a customized scheduler first.
    for func in register.scheduler_dict.values():
        scheduler = func(optimizer, scheduler_config)
        if scheduler is not None:
            return scheduler

    if scheduler_config.scheduler == 'none':
        # A step size beyond the last epoch keeps the learning rate constant.
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=scheduler_config.max_epoch + 1)
    elif scheduler_config.scheduler == 'step':
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=scheduler_config.steps,
            gamma=scheduler_config.lr_decay)
    elif scheduler_config.scheduler == 'cos':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=scheduler_config.max_epoch)
    else:
        raise ValueError('Scheduler {} not supported'.format(
            scheduler_config.scheduler))
    return scheduler
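

# End-to-end usage sketch (not part of the original module): wiring a model's
# parameters to a config-driven optimizer and scheduler, then stepping the
# scheduler once per epoch. The toy model, data, and training loop below are
# assumptions made purely for illustration.
def _example_training_run(max_epoch: int = 5):
    import torch
    import torch.nn.functional as F
    from torch import nn

    model = nn.Linear(16, 2)
    data, target = torch.randn(8, 16), torch.randn(8, 2)
    optimizer = create_optimizer(
        model.parameters(), OptimizerConfig(optimizer='adam', base_lr=0.01))
    scheduler = create_scheduler(
        optimizer, SchedulerConfig(scheduler='cos', max_epoch=max_epoch))
    for epoch in range(max_epoch):
        optimizer.zero_grad()
        loss = F.mse_loss(model(data), target)
        loss.backward()
        optimizer.step()
        scheduler.step()  # decay the learning rate once per epoch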