# Source code for torch_geometric.graphgym.checkpoint

from torch_geometric.graphgym.config import cfg

import torch
import os
import logging


def get_ckpt_dir():
    r'''
    Return the path of the checkpoint directory.

    Returns:
        str: ``cfg.run_dir`` followed by ``'/ckpt'``.
    '''
    return f'{cfg.run_dir}/ckpt'


def get_all_epoch():
    r'''
    List the epochs for which a checkpoint file is present.

    Each entry of the checkpoint directory whose name, up to the first
    ``'.'``, parses as an integer (e.g. ``'12.ckpt'``) contributes that
    integer. Entries that do not parse (hidden files, ``'best.ckpt'``, ...)
    are skipped instead of aborting the scan with a ``ValueError``.

    Returns:
        list[int]: The epochs found, in the order ``os.listdir`` reports
        them, or ``[0]`` when the checkpoint directory is missing or holds
        no parsable entry.
    '''
    ckpt_dir = get_ckpt_dir()
    # isdir (rather than exists) keeps os.listdir from raising when the
    # path exists but is not a directory.
    names = os.listdir(ckpt_dir) if os.path.isdir(ckpt_dir) else []

    epochs = []
    for name in names:
        try:
            epochs.append(int(name.split('.')[0]))
        except ValueError:
            # Not a '<epoch>.<ext>' name; ignore it.
            continue

    # Callers in this module take max() of the result, so never return an
    # empty list.
    return epochs if epochs else [0]


def get_last_epoch():
    r'''
    Return the largest epoch reported by :func:`get_all_epoch`.
    '''
    epochs = get_all_epoch()
    return max(epochs)


def load_ckpt(model, optimizer=None, scheduler=None):
    r'''
    Restore training state from a checkpoint file.

    The epoch to load is ``cfg.train.epoch_resume``; a negative value selects
    the epoch returned by :func:`get_last_epoch` instead.

    Args:
        model (torch.nn.Module): The model whose state is restored.
        optimizer (optional): If given, its state is restored from the
            checkpoint's ``'optimizer_state'`` entry.
        scheduler (optional): If given, its state is restored from the
            checkpoint's ``'scheduler_state'`` entry.

    Returns:
        int: ``0`` when the checkpoint file does not exist; otherwise the
        epoch stored in the checkpoint plus one.
    '''
    resume_from = cfg.train.epoch_resume
    if resume_from < 0:
        resume_from = get_last_epoch()

    path = f'{get_ckpt_dir()}/{resume_from}.ckpt'
    if not os.path.isfile(path):
        return 0

    state = torch.load(path)
    saved_epoch = state['epoch']
    model.load_state_dict(state['model_state'])

    # Optimizer first, then scheduler; each is optional.
    optional_parts = (
        (optimizer, 'optimizer_state'),
        (scheduler, 'scheduler_state'),
    )
    for target, key in optional_parts:
        if target is not None:
            target.load_state_dict(state[key])

    return saved_epoch + 1
def save_ckpt(model, optimizer, scheduler, epoch):
    r'''
    Write a checkpoint for the given epoch.

    The file is stored as ``'<epoch>.ckpt'`` inside the directory returned
    by :func:`get_ckpt_dir`, which is created if it does not exist yet.

    Args:
        model (torch.nn.Module): The model whose state is saved.
        optimizer: The optimizer whose state is saved.
        scheduler: The scheduler whose state is saved.
        epoch (int): The epoch recorded in the checkpoint and used as the
            file name.
    '''
    # Collect all state first, so nothing touches the disk if this fails.
    state = {
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'scheduler_state': scheduler.state_dict(),
    }

    ckpt_dir = get_ckpt_dir()
    os.makedirs(ckpt_dir, exist_ok=True)

    path = f'{ckpt_dir}/{epoch}.ckpt'
    torch.save(state, path)
    logging.info('Check point saved: {}'.format(path))
def clean_ckpt():
    r'''
    Delete every checkpoint file except the one with the highest epoch.

    The epochs come from :func:`get_all_epoch`; for each epoch other than
    the maximum, ``'<epoch>.ckpt'`` in the checkpoint directory is removed.
    '''
    epochs = get_all_epoch()
    newest = max(epochs)
    ckpt_dir = get_ckpt_dir()
    for epoch in epochs:
        if epoch == newest:
            continue
        os.remove(f'{ckpt_dir}/{epoch}.ckpt')