from torch_geometric.graphgym.config import cfg
import torch
import os
import logging
def get_ckpt_dir():
return '{}/ckpt'.format(cfg.run_dir)
def get_all_epoch():
d = get_ckpt_dir()
names = os.listdir(d) if os.path.exists(d) else []
if len(names) == 0:
return [0]
epochs = [int(name.split('.')[0]) for name in names]
return epochs
def get_last_epoch():
return max(get_all_epoch())
[docs]def load_ckpt(model, optimizer=None, scheduler=None):
r'''
Load latest model checkpoint
Args:
model (torch.nn.Module): The model that will be loaded
optimizer (torch.optim, optional): The optimizer that will be loaded
scheduler (torch.optim, optional): The schduler that will be loaded
Returns:
Epoch count after loading the model
'''
if cfg.train.epoch_resume < 0:
epoch_resume = get_last_epoch()
else:
epoch_resume = cfg.train.epoch_resume
ckpt_name = '{}/{}.ckpt'.format(get_ckpt_dir(), epoch_resume)
if not os.path.isfile(ckpt_name):
return 0
ckpt = torch.load(ckpt_name)
epoch = ckpt['epoch']
model.load_state_dict(ckpt['model_state'])
if optimizer is not None:
optimizer.load_state_dict(ckpt['optimizer_state'])
if scheduler is not None:
scheduler.load_state_dict(ckpt['scheduler_state'])
return epoch + 1
[docs]def save_ckpt(model, optimizer, scheduler, epoch):
r'''
Save model checkpoint at given epoch
Args:
model (torch.nn.Module): The model that will be saved
optimizer (torch.optim): The optimizer that will be saved
scheduler (torch.optim): The schduler that will be saved
epoch (int): The epoch when the model is saved
'''
ckpt = {
'epoch': epoch,
'model_state': model.state_dict(),
'optimizer_state': optimizer.state_dict(),
'scheduler_state': scheduler.state_dict()
}
os.makedirs(get_ckpt_dir(), exist_ok=True)
ckpt_name = '{}/{}.ckpt'.format(get_ckpt_dir(), epoch)
torch.save(ckpt, ckpt_name)
logging.info('Check point saved: {}'.format(ckpt_name))
[docs]def clean_ckpt():
r'''
Only keep the latest model checkpoint, remove all the older checkpoints
'''
epochs = get_all_epoch()
epoch_last = max(epochs)
for epoch in epochs:
if epoch != epoch_last:
ckpt_name = '{}/{}.ckpt'.format(get_ckpt_dir(), epoch)
os.remove(ckpt_name)