torch_geometric.data.lightning.LightningDataset

class LightningDataset(train_dataset: Dataset, val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, pred_dataset: Optional[Dataset] = None, **kwargs)

Bases: LightningDataModule

Converts a set of Dataset objects into a pytorch_lightning.LightningDataModule variant that can be used directly as a datamodule for multi-GPU graph-level training via PyTorch Lightning. LightningDataset takes care of providing mini-batches via DataLoader.

Note

Currently, only the pytorch_lightning.strategies.SingleDeviceStrategy and pytorch_lightning.strategies.DDPSpawnStrategy training strategies of PyTorch Lightning are supported, so that data is shared correctly across all devices/processes:

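import pytorch_lightning as pl

# `model` (a pl.LightningModule) and `datamodule` (a LightningDataset)
# are assumed to be defined elsewhere.
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                     devices=4)
trainer.fit(model, datamodule)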
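The strategy restriction in the note can be made explicit when building the trainer arguments. The following is a minimal sketch, not part of the library: `trainer_kwargs` is a hypothetical helper that requests "ddp_spawn" only when more than one device is used and otherwise leaves the strategy unset, on the assumption that PyTorch Lightning's default selection for one device is the single-device strategy.

```python
from typing import Any, Dict


def trainer_kwargs(num_devices: int, accelerator: str = "gpu") -> Dict[str, Any]:
    """Hypothetical helper (not part of PyG or Lightning).

    Builds keyword arguments for `pl.Trainer` that stay within the two
    strategies named in the note above.
    """
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")

    kwargs: Dict[str, Any] = {"accelerator": accelerator, "devices": num_devices}
    if num_devices > 1:
        # Multiple devices: explicitly request the DDP spawn strategy.
        kwargs["strategy"] = "ddp_spawn"
    # One device: no strategy is set; this assumes Lightning's default
    # selection resolves to the single-device strategy.
    return kwargs


# Intended usage, with `model` and `datamodule` defined elsewhere:
#   trainer = pl.Trainer(**trainer_kwargs(4))
#   trainer.fit(model, datamodule)
```

Keeping the choice in one place avoids accidentally passing an unsupported strategy string when the device count changes between runs.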
Parameters