Source code for torch_geometric.graphgym.logger

import logging
import math
import os
import sys
import time
from typing import Any, Dict, Optional

import torch

from torch_geometric.graphgym import register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.imports import Callback, pl
from torch_geometric.graphgym.utils.device import get_current_gpu_usage
from torch_geometric.graphgym.utils.io import dict_to_json, dict_to_tb


def set_printing():
    """Set up printing options."""
    logging.root.handlers = []
    logging_cfg = {'level': logging.INFO, 'format': '%(message)s'}
    os.makedirs(cfg.run_dir, exist_ok=True)
    h_file = logging.FileHandler(f'{cfg.run_dir}/logging.log')
    h_stdout = logging.StreamHandler(sys.stdout)
    if cfg.print == 'file':
        logging_cfg['handlers'] = [h_file]
    elif cfg.print == 'stdout':
        logging_cfg['handlers'] = [h_stdout]
    elif cfg.print == 'both':
        logging_cfg['handlers'] = [h_file, h_stdout]
    else:
        raise ValueError('Print option not supported')
    logging.basicConfig(**logging_cfg)
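Usage sketch (not part of the module source): set_printing() reads cfg.run_dir and cfg.print, so both need to be populated first; normally this happens through the GraphGym config pipeline, but the directory below is set by hand purely for illustration.

    from torch_geometric.graphgym.config import cfg
    from torch_geometric.graphgym.logger import set_printing

    cfg.run_dir = 'results/example-run'  # hypothetical; usually set by set_run_dir()
    cfg.print = 'both'                   # log to both file and stdout

    set_printing()  # routes logging.info(...) according to cfg.print
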
class Logger:
    def __init__(self, name='train', task_type=None):
        self.name = name
        self.task_type = task_type
        self._epoch_total = cfg.optim.max_epoch
        self._time_total = 0  # won't be reset

        self.out_dir = f'{cfg.run_dir}/{name}'
        os.makedirs(self.out_dir, exist_ok=True)
        if cfg.tensorboard_each_run:
            from tensorboardX import SummaryWriter
            self.tb_writer = SummaryWriter(self.out_dir)

        self.reset()

    def __getitem__(self, key):
        return getattr(self, key, None)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def reset(self):
        self._iter = 0
        self._size_current = 0
        self._loss = 0
        self._lr = 0
        self._params = 0
        self._time_used = 0
        self._true = []
        self._pred = []
        self._custom_stats = {}

    # basic properties
    def basic(self):
        stats = {
            'loss': round(self._loss / self._size_current, cfg.round),
            'lr': round(self._lr, cfg.round),
            'params': self._params,
            'time_iter': round(self.time_iter(), cfg.round),
        }
        gpu_memory = get_current_gpu_usage()
        if gpu_memory > 0:
            stats['gpu_memory'] = gpu_memory
        return stats

    # customized input properties
    def custom(self):
        if len(self._custom_stats) == 0:
            return {}
        out = {}
        for key, val in self._custom_stats.items():
            out[key] = val / self._size_current
        return out

    def _get_pred_int(self, pred_score):
        if len(pred_score.shape) == 1 or pred_score.shape[1] == 1:
            return (pred_score > cfg.model.thresh).long()
        else:
            return pred_score.max(dim=1)[1]

    # task properties
    def classification_binary(self):
        from sklearn.metrics import (
            accuracy_score,
            f1_score,
            precision_score,
            recall_score,
            roc_auc_score,
        )

        true, pred_score = torch.cat(self._true), torch.cat(self._pred)
        pred_int = self._get_pred_int(pred_score)
        try:
            r_a_score = roc_auc_score(true, pred_score)
        except ValueError:
            r_a_score = 0.0
        return {
            'accuracy': round(accuracy_score(true, pred_int), cfg.round),
            'precision': round(precision_score(true, pred_int), cfg.round),
            'recall': round(recall_score(true, pred_int), cfg.round),
            'f1': round(f1_score(true, pred_int), cfg.round),
            'auc': round(r_a_score, cfg.round),
        }

    def classification_multi(self):
        from sklearn.metrics import accuracy_score

        true, pred_score = torch.cat(self._true), torch.cat(self._pred)
        pred_int = self._get_pred_int(pred_score)
        return {'accuracy': round(accuracy_score(true, pred_int), cfg.round)}

    def regression(self):
        from sklearn.metrics import mean_absolute_error, mean_squared_error

        true, pred = torch.cat(self._true), torch.cat(self._pred)
        return {
            'mae': float(round(mean_absolute_error(true, pred), cfg.round)),
            'mse': float(round(mean_squared_error(true, pred), cfg.round)),
            'rmse': float(round(math.sqrt(mean_squared_error(true, pred)),
                                cfg.round)),
        }

    def time_iter(self):
        return self._time_used / self._iter

    def eta(self, epoch_current):
        epoch_current += 1  # since counter starts from 0
        time_per_epoch = self._time_total / epoch_current
        return time_per_epoch * (self._epoch_total - epoch_current)

    def update_stats(self, true, pred, loss, lr, time_used, params, **kwargs):
        assert true.shape[0] == pred.shape[0]
        self._iter += 1
        self._true.append(true)
        self._pred.append(pred)
        batch_size = true.shape[0]
        self._size_current += batch_size
        self._loss += loss * batch_size
        self._lr = lr
        self._params = params
        self._time_used += time_used
        self._time_total += time_used
        for key, val in kwargs.items():
            if key not in self._custom_stats:
                self._custom_stats[key] = val * batch_size
            else:
                self._custom_stats[key] += val * batch_size

    def write_iter(self):
        raise NotImplementedError

    def write_epoch(self, cur_epoch):
        basic_stats = self.basic()

        # Try to load customized metrics
        task_stats = {}
        for custom_metric in cfg.custom_metrics:
            func = register.metric_dict.get(custom_metric)
            if not func:
                raise ValueError(
                    f'Unknown custom metric function name: {custom_metric}')
            custom_metric_score = func(self._true, self._pred, self.task_type)
            task_stats[custom_metric] = custom_metric_score

        if not task_stats:  # use default metrics if no matching custom metric
            if self.task_type == 'regression':
                task_stats = self.regression()
            elif self.task_type == 'classification_binary':
                task_stats = self.classification_binary()
            elif self.task_type == 'classification_multi':
                task_stats = self.classification_multi()
            else:
                raise ValueError(
                    'Task has to be regression or classification')

        epoch_stats = {'epoch': cur_epoch}
        eta_stats = {'eta': round(self.eta(cur_epoch), cfg.round)}
        custom_stats = self.custom()

        if self.name == 'train':
            stats = {
                **epoch_stats,
                **eta_stats,
                **basic_stats,
                **task_stats,
                **custom_stats,
            }
        else:
            stats = {
                **epoch_stats,
                **basic_stats,
                **task_stats,
                **custom_stats,
            }

        # print
        logging.info(f'{self.name}: {stats}')
        # json
        dict_to_json(stats, f'{self.out_dir}/stats.json')
        # tensorboard
        if cfg.tensorboard_each_run:
            dict_to_tb(stats, self.tb_writer, cur_epoch)

        self.reset()

    def close(self):
        if cfg.tensorboard_each_run:
            self.tb_writer.close()


def infer_task():
    num_label = cfg.share.dim_out
    if cfg.dataset.task_type == 'classification':
        if num_label <= 2:
            task_type = 'classification_binary'
        else:
            task_type = 'classification_multi'
    else:
        task_type = cfg.dataset.task_type
    return task_type
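Usage sketch (not part of the module source): a Logger can also be driven by hand, calling update_stats() once per batch and write_epoch() once per epoch. The batch tensors, loss, learning rate, and parameter count below are placeholders; in practice they come from the training step, and cfg.run_dir is normally set by set_run_dir().

    import time

    import torch

    from torch_geometric.graphgym.config import cfg
    from torch_geometric.graphgym.logger import Logger, infer_task

    cfg.run_dir = 'results/example-run'  # hypothetical output directory
    logger = Logger(name='train', task_type=infer_task())

    start = time.time()
    true = torch.tensor([0, 1, 1, 0])            # placeholder labels
    pred_score = torch.tensor([0.2, 0.8, 0.6, 0.4])  # placeholder scores
    logger.update_stats(true=true, pred=pred_score, loss=0.5, lr=0.01,
                        time_used=time.time() - start, params=1000)
    logger.write_epoch(cur_epoch=0)  # logs, writes stats.json, then resets
    logger.close()
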
def create_logger():
    r"""Create logger for the experiment."""
    loggers = []
    names = ['train', 'val', 'test']
    for i, dataset in enumerate(range(cfg.share.num_splits)):
        loggers.append(Logger(name=names[i], task_type=infer_task()))
    return loggers
class LoggerCallback(Callback):
    def __init__(self):
        self._logger = create_logger()
        self._train_epoch_start_time = None
        self._val_epoch_start_time = None
        self._test_epoch_start_time = None

    @property
    def train_logger(self) -> Any:
        return self._logger[0]

    @property
    def val_logger(self) -> Any:
        return self._logger[1]

    @property
    def test_logger(self) -> Any:
        return self._logger[2]

    def close(self):
        for logger in self._logger:
            logger.close()

    def _get_stats(
        self,
        epoch_start_time: int,
        outputs: Dict[str, Any],
        trainer: 'pl.Trainer',
    ) -> Dict:
        return dict(
            true=outputs['true'].detach().cpu(),
            pred=outputs['pred_score'].detach().cpu(),
            loss=float(outputs['loss']),
            lr=trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0],
            time_used=time.time() - epoch_start_time,
            params=cfg.params,
        )

    def on_train_epoch_start(
        self,
        trainer: 'pl.Trainer',
        pl_module: 'pl.LightningModule',
    ):
        self._train_epoch_start_time = time.time()

    def on_validation_epoch_start(
        self,
        trainer: 'pl.Trainer',
        pl_module: 'pl.LightningModule',
    ):
        self._val_epoch_start_time = time.time()

    def on_test_epoch_start(
        self,
        trainer: 'pl.Trainer',
        pl_module: 'pl.LightningModule',
    ):
        self._test_epoch_start_time = time.time()

    def on_train_batch_end(
        self,
        trainer: 'pl.Trainer',
        pl_module: 'pl.LightningModule',
        outputs: Dict[str, Any],
        batch: Any,
        batch_idx: int,
        unused: int = 0,
    ):
        stats = self._get_stats(self._train_epoch_start_time, outputs,
                                trainer)
        self.train_logger.update_stats(**stats)

    def on_validation_batch_end(
        self,
        trainer: 'pl.Trainer',
        pl_module: 'pl.LightningModule',
        outputs: Optional[Dict[str, Any]],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        stats = self._get_stats(self._val_epoch_start_time, outputs, trainer)
        self.val_logger.update_stats(**stats)

    def on_test_batch_end(
        self,
        trainer: 'pl.Trainer',
        pl_module: 'pl.LightningModule',
        outputs: Optional[Dict[str, Any]],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        stats = self._get_stats(self._test_epoch_start_time, outputs, trainer)
        self.test_logger.update_stats(**stats)

    def on_train_epoch_end(
        self,
        trainer: 'pl.Trainer',
        pl_module: 'pl.LightningModule',
    ):
        self.train_logger.write_epoch(trainer.current_epoch)

    def on_validation_epoch_end(
        self,
        trainer: 'pl.Trainer',
        pl_module: 'pl.LightningModule',
    ):
        self.val_logger.write_epoch(trainer.current_epoch)

    def on_test_epoch_end(
        self,
        trainer: 'pl.Trainer',
        pl_module: 'pl.LightningModule',
    ):
        self.test_logger.write_epoch(trainer.current_epoch)

    def on_fit_end(self, trainer, pl_module):
        self.close()
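Usage sketch (not part of the module source): the callback is intended to be passed to a PyTorch Lightning Trainer alongside a GraphGym model and data module, which is roughly what torch_geometric.graphgym.train does internally. The sketch below assumes cfg has already been fully populated (e.g. via load_cfg from a YAML config) and that the create_model / GraphGymDataModule helpers are available in your PyG version; exact import paths may differ between releases.

    import pytorch_lightning as pl

    from torch_geometric.graphgym.config import cfg
    from torch_geometric.graphgym.logger import LoggerCallback
    from torch_geometric.graphgym.model_builder import create_model
    from torch_geometric.graphgym.train import GraphGymDataModule

    datamodule = GraphGymDataModule()  # builds loaders from cfg
    model = create_model()             # builds a GraphGymModule from cfg

    trainer = pl.Trainer(
        max_epochs=cfg.optim.max_epoch,
        callbacks=[LoggerCallback()],  # writes per-split stats each epoch
    )
    trainer.fit(model, datamodule=datamodule)
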