torch_geometric.data.lightning.LightningDataset
class LightningDataset(train_dataset: Dataset, val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, pred_dataset: Optional[Dataset] = None, **kwargs: Any)
Bases: LightningDataModule
Converts a set of Dataset objects into a pytorch_lightning.LightningDataModule variant. It can then be automatically used as a datamodule for multi-GPU graph-level training via PyTorch Lightning. LightningDataset will take care of providing mini-batches via DataLoader.

Note
Currently only the pytorch_lightning.strategies.SingleDeviceStrategy and pytorch_lightning.strategies.DDPStrategy training strategies of PyTorch Lightning are supported in order to correctly share data across all devices/processes:

    import pytorch_lightning as pl

    trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu", devices=4)
    trainer.fit(model, datamodule)

Parameters:
- train_dataset (Dataset) – The training dataset.
- val_dataset (Dataset, optional) – The validation dataset. (default: None)
- test_dataset (Dataset, optional) – The test dataset. (default: None)
- pred_dataset (Dataset, optional) – The prediction dataset. (default: None)
- **kwargs (optional) – Additional arguments of torch_geometric.loader.DataLoader.