Source code for torch_geometric.graphgym.optimizer

from dataclasses import dataclass, field

import torch.optim as optim
import torch_geometric.graphgym.register as register


@dataclass
class OptimizerConfig:
    """Hyper-parameters read by ``create_optimizer``.

    Attributes:
        optimizer: Name of the optimizer; the built-in branches of
            ``create_optimizer`` handle ``'adam'`` and ``'sgd'``.
        base_lr: Base learning rate, passed to the optimizer as ``lr``.
        weight_decay: L2 regularization strength.
        momentum: Momentum, passed to the optimizer only for ``'sgd'``.
    """
    optimizer: str = 'adam'
    base_lr: float = 0.01
    weight_decay: float = 5e-4
    momentum: float = 0.9


def create_optimizer(params, optimizer_config: OptimizerConfig):
    r"""Create an optimizer for the model.

    Registered custom optimizer factories are tried first, in registry
    order; the first one returning a non-``None`` value wins. If none does,
    the built-in ``'adam'`` and ``'sgd'`` choices are used.

    Args:
        params: Iterable of PyTorch model parameters. Only those with
            ``requires_grad`` set are handed to the optimizer.
        optimizer_config: Configuration providing ``optimizer``,
            ``base_lr``, ``weight_decay`` and ``momentum``.

    Returns:
        PyTorch optimizer

    Raises:
        ValueError: If no custom factory produced an optimizer and
            ``optimizer_config.optimizer`` is neither ``'adam'`` nor
            ``'sgd'``.
    """
    # Materialize the trainable parameters once. A lazy ``filter`` object is
    # single-use: a custom factory that iterates it and then declines (returns
    # None) would leave every later factory and the built-in optimizers with
    # an empty parameter stream.
    params = [p for p in params if p.requires_grad]

    # Try to load customized optimizer
    for func in register.optimizer_dict.values():
        optimizer = func(params, optimizer_config)
        if optimizer is not None:
            return optimizer

    if optimizer_config.optimizer == 'adam':
        return optim.Adam(params, lr=optimizer_config.base_lr,
                          weight_decay=optimizer_config.weight_decay)
    if optimizer_config.optimizer == 'sgd':
        return optim.SGD(params, lr=optimizer_config.base_lr,
                         momentum=optimizer_config.momentum,
                         weight_decay=optimizer_config.weight_decay)
    raise ValueError('Optimizer {} not supported'.format(
        optimizer_config.optimizer))
@dataclass
class SchedulerConfig:
    """Settings read by ``create_scheduler``."""
    # scheduler: none, step, cos
    scheduler: str = 'cos'
    # Milestones for the 'step' policy (in epochs).
    # ``default_factory`` must be a zero-argument callable; the lambda builds
    # a fresh list per instance so instances never share (or mutate) one.
    steps: list = field(default_factory=lambda: [30, 60, 90])
    # Learning rate multiplier for the 'step' policy
    lr_decay: float = 0.1
    # Maximal number of epochs
    max_epoch: int = 200
def create_scheduler(optimizer, scheduler_config: SchedulerConfig):
    r"""Create a learning rate scheduler for the optimizer.

    Registered custom scheduler factories are tried first, in registry
    order; the first one returning a non-``None`` value wins. If none does,
    the built-in policies are used.

    Args:
        optimizer: PyTorch optimizer
        scheduler_config: Configuration providing ``scheduler``, ``steps``,
            ``lr_decay`` and ``max_epoch``.

    Returns:
        PyTorch scheduler

    Raises:
        ValueError: If no custom factory produced a scheduler and
            ``scheduler_config.scheduler`` is not one of ``'none'``,
            ``'step'`` (alias ``'steps'``) or ``'cos'``.
    """
    # Try to load customized scheduler
    for func in register.scheduler_dict.values():
        scheduler = func(optimizer, scheduler_config)
        if scheduler is not None:
            return scheduler

    policy = scheduler_config.scheduler
    if policy == 'none':
        # A step size one past max_epoch means the decay is not reached
        # within max_epoch scheduler steps.
        return optim.lr_scheduler.StepLR(
            optimizer, step_size=scheduler_config.max_epoch + 1)
    # 'steps' is the spelling used in the config comments; accept it next to
    # the original 'step' so the documented name does not raise.
    if policy in ('step', 'steps'):
        return optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=scheduler_config.steps,
            gamma=scheduler_config.lr_decay)
    if policy == 'cos':
        return optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=scheduler_config.max_epoch)
    raise ValueError('Scheduler {} not supported'.format(policy))