# Source code for torch_geometric.graphgym.train

import warnings
from typing import Any, Dict, Optional

import torch
from torch.utils.data import DataLoader

from torch_geometric.data.lightning.datamodule import LightningDataModule
from torch_geometric.graphgym import create_loader
from torch_geometric.graphgym.checkpoint import get_ckpt_dir
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.imports import pl
from torch_geometric.graphgym.logger import LoggerCallback
from torch_geometric.graphgym.model_builder import GraphGymModule


class GraphGymDataModule(LightningDataModule):
    r"""Lightning data module that serves the GraphGym data loaders.

    Built on :class:`torch_geometric.data.lightning.LightningDataModule`.
    On construction it calls :func:`create_loader` once and stores the result
    in :attr:`loaders`. The training, validation and test hooks
    (:meth:`train_dataloader`, :meth:`val_dataloader`,
    :meth:`test_dataloader`) return the entries at positions 0, 1 and 2.
    """
    def __init__(self):
        # The loaders are created before the base class is initialised,
        # preserving the original construction order.
        built_loaders = create_loader()
        self.loaders = built_loaders
        super().__init__(has_val=True, has_test=True)

    def _loader_at(self, position: int) -> DataLoader:
        # Shared accessor used by the three split hooks below.
        return self.loaders[position]

    def train_dataloader(self) -> DataLoader:
        return self._loader_at(0)

    def val_dataloader(self) -> DataLoader:
        # A cleaner approach would be to evaluate after fitting, i.e. call
        # trainer.fit(...) first and trainer.test(...) afterwards.
        return self._loader_at(1)

    def test_dataloader(self) -> DataLoader:
        return self._loader_at(2)


def train(
    model: GraphGymModule,
    datamodule: GraphGymDataModule,
    logger: bool = True,
    trainer_config: Optional[Dict[str, Any]] = None,
) -> None:
    r"""Trains a GraphGym model using PyTorch Lightning.

    Builds a :class:`pytorch_lightning.Trainer` from the global GraphGym
    :obj:`cfg`, runs :meth:`fit` on :obj:`model` with :obj:`datamodule`, and
    then runs :meth:`test` with the same data module.

    Args:
        model (GraphGymModule): The GraphGym model.
        datamodule (GraphGymDataModule): The GraphGym data module.
        logger (bool, optional): Whether to enable logging during training.
            (default: :obj:`True`)
        trainer_config (dict, optional): Additional keyword arguments for
            :class:`pytorch_lightning.Trainer`. Entries override the defaults
            derived from :obj:`cfg`. A :obj:`"callbacks"` entry (a list or a
            single callback) is appended after the GraphGym callbacks instead
            of replacing them. The given dictionary is not modified.
            (default: :obj:`None`)
    """
    # Silence Lightning's notice about defaulting to a CSV logger.
    warnings.filterwarnings('ignore', '.*use `CSVLogger` as the default.*')

    # Work on a copy so popping 'callbacks' never mutates the caller's dict.
    extra_config: Dict[str, Any] = dict(trainer_config or {})

    callbacks = []
    if logger:
        callbacks.append(LoggerCallback())
    if cfg.train.enable_ckpt:
        ckpt_cbk = pl.callbacks.ModelCheckpoint(dirpath=get_ckpt_dir())
        callbacks.append(ckpt_cbk)

    # Caller-supplied callbacks extend the GraphGym ones. Passing them through
    # as a separate keyword would otherwise collide with `callbacks=` below.
    user_callbacks = extra_config.pop('callbacks', None)
    if user_callbacks is not None:
        if isinstance(user_callbacks, (list, tuple)):
            callbacks.extend(user_callbacks)
        else:
            callbacks.append(user_callbacks)

    trainer_kwargs: Dict[str, Any] = dict(
        enable_checkpointing=cfg.train.enable_ckpt,
        default_root_dir=cfg.out_dir,
        max_epochs=cfg.optim.max_epoch,
        accelerator=cfg.accelerator,
        # Without CUDA, let Lightning pick the device count itself.
        devices='auto' if not torch.cuda.is_available() else cfg.devices,
    )
    # Explicit caller options take precedence over the cfg-derived defaults
    # instead of raising "got multiple values for keyword argument".
    trainer_kwargs.update(extra_config)

    trainer = pl.Trainer(callbacks=callbacks, **trainer_kwargs)

    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)