torch_geometric.data.lightning.LightningNodeData

class LightningNodeData(data: Union[Data, HeteroData], input_train_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_train_time: Optional[Tensor] = None, input_val_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_val_time: Optional[Tensor] = None, input_test_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_test_time: Optional[Tensor] = None, input_pred_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_pred_time: Optional[Tensor] = None, loader: str = 'neighbor', node_sampler: Optional[BaseSampler] = None, eval_loader_kwargs: Optional[Dict[str, Any]] = None, **kwargs)

Bases: LightningData

Converts a Data or HeteroData object into a pytorch_lightning.LightningDataModule variant, which can be used directly as a datamodule for multi-GPU node-level training via PyTorch Lightning. LightningNodeData takes care of providing mini-batches via NeighborLoader (when using the default loader="neighbor").
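Conceptually, each mini-batch that NeighborLoader produces is a subgraph sampled outward, hop by hop, from a chunk of seed nodes (for training, the training nodes). The sketch below illustrates that idea with the standard library only. The helpers sample_neighbors and mini_batches and the toy adjacency list are hypothetical, and the sketch simplifies heavily: no tensors, no heterogeneous graphs, and no use of the input_*_time timestamps. It is not PyG's sampler implementation.

```python
import random
from typing import Dict, Iterator, List, Set, Tuple

# Hypothetical toy graph as an adjacency list: node -> neighbors.
Graph = Dict[int, List[int]]
Batch = Tuple[List[int], Set[int], List[Tuple[int, int]]]


def sample_neighbors(graph: Graph, seeds: List[int], num_neighbors: List[int],
                     rng: random.Random) -> Tuple[Set[int], List[Tuple[int, int]]]:
    """Sample a multi-hop subgraph around `seeds` (simplified, hypothetical).

    num_neighbors[k] caps how many neighbors are drawn per node in hop k.
    """
    nodes: Set[int] = set(seeds)
    edges: List[Tuple[int, int]] = []
    frontier = list(seeds)
    for fanout in num_neighbors:
        next_frontier: List[int] = []
        for v in frontier:
            nbrs = graph.get(v, [])
            # Draw at most `fanout` distinct neighbors of v.
            for u in rng.sample(nbrs, min(fanout, len(nbrs))):
                edges.append((u, v))  # edge u -> v: u sends a message to v
                if u not in nodes:
                    # Only newly reached nodes are expanded in the next hop.
                    nodes.add(u)
                    next_frontier.append(u)
        frontier = next_frontier
    return nodes, edges


def mini_batches(graph: Graph, seed_nodes: List[int], batch_size: int,
                 num_neighbors: List[int], seed: int = 0) -> Iterator[Batch]:
    """Shuffle the seed nodes and yield one sampled subgraph per chunk."""
    rng = random.Random(seed)
    order = list(seed_nodes)
    rng.shuffle(order)
    for start in range(0, len(order), batch_size):
        seeds = order[start:start + batch_size]
        nodes, edges = sample_neighbors(graph, seeds, num_neighbors, rng)
        yield seeds, nodes, edges


graph: Graph = {0: [1, 2], 1: [0, 3], 2: [0, 3, 4], 3: [1, 2], 4: [2]}
batches = list(mini_batches(graph, seed_nodes=[0, 1, 2, 3], batch_size=2,
                            num_neighbors=[2, 1]))
for seeds, nodes, edges in batches:
    print("seeds:", seeds, "nodes:", sorted(nodes), "edges:", edges)
```

Each entry of num_neighbors bounds the fan-out of one hop, so a two-entry list yields two-hop subgraphs, which matches the receptive field of a two-layer message-passing GNN.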

Note

Currently, only the pytorch_lightning.strategies.SingleDeviceStrategy and pytorch_lightning.strategies.DDPSpawnStrategy training strategies of PyTorch Lightning are supported, so that data is shared correctly across all devices/processes. For example:

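import pytorch_lightning as pl

# `model` (a pl.LightningModule) and `datamodule` (a LightningNodeData
# instance) are assumed to be defined elsewhere.
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                     devices=4)
trainer.fit(model, datamodule=datamodule)

Parameters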
  • data (Data or HeteroData) – The Data or HeteroData graph object.

  • input_train_nodes (torch.Tensor or str or (str, torch.Tensor), optional) – The indices of training nodes. For HeteroData, pass the node type as a str or a (node_type, indices) tuple. If not given, they are inferred automatically from the data object by searching for a train_mask, train_idx, or train_index attribute. (default: None)

  • input_train_time (torch.Tensor, optional) – The timestamps of training nodes. (default: None)

  • input_val_nodes (torch.Tensor or str or (str, torch.Tensor), optional) – The indices of validation nodes. If not given, they are inferred automatically from the data object by searching for a val_mask, valid_mask, val_idx, valid_idx, val_index, or valid_index attribute. (default: None)

  • input_val_time (torch.Tensor, optional) – The timestamps of validation nodes. (default: None)

  • input_test_nodes (torch.Tensor or str or (str, torch.Tensor), optional) – The indices of test nodes. If not given, they are inferred automatically from the data object by searching for a test_mask, test_idx, or test_index attribute. (default: None)

  • input_test_time (torch.Tensor, optional) – The timestamps of test nodes. (default: None)

  • input_pred_nodes (torch.Tensor or str or (str, torch.Tensor), optional) – The indices of prediction nodes. If not given, they are inferred automatically from the data object by searching for a pred_mask, pred_idx, or pred_index attribute. (default: None)

  • input_pred_time (torch.Tensor, optional) – The timestamps of prediction nodes. (default: None)

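  • loader (str) – The scalability technique to use: "full" (full-batch training on the entire graph) or "neighbor" (mini-batch training via torch_geometric.loader.NeighborLoader). (default: "neighbor")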

  • node_sampler (BaseSampler, optional) – A custom sampler object used to generate mini-batches. If set, the loader option is ignored. (default: None)

  • eval_loader_kwargs (Dict[str, Any], optional) – Custom keyword arguments that override the torch_geometric.loader.NeighborLoader configuration during evaluation. (default: None)

  • **kwargs (optional) – Additional keyword arguments passed to torch_geometric.loader.NeighborLoader, such as num_neighbors and batch_size.
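The parameter semantics above can be summarized in a small sketch. The helpers infer_input_nodes, resolve_loader, and merge_eval_kwargs are hypothetical and operate on plain Python dicts and lists rather than on Data objects and tensors. The attribute lookup order simply follows the lists above and may not match PyG's exact precedence, and the sketch converts masks to index lists, whereas PyG may pass masks through unchanged.

```python
from typing import Any, Dict, List, Optional, Sequence

# Attribute names searched for each split, as listed in the parameter docs.
SPLIT_ATTRS: Dict[str, List[str]] = {
    'train': ['train_mask', 'train_idx', 'train_index'],
    'val': ['val_mask', 'valid_mask', 'val_idx', 'valid_idx',
            'val_index', 'valid_index'],
    'test': ['test_mask', 'test_idx', 'test_index'],
    'pred': ['pred_mask', 'pred_idx', 'pred_index'],
}


def infer_input_nodes(attrs: Dict[str, Sequence[Any]],
                      split: str) -> Optional[List[int]]:
    """Hypothetical helper: node indices for `split`, or None if nothing matches."""
    for name in SPLIT_ATTRS[split]:
        if name in attrs:
            values = attrs[name]
            if name.endswith('_mask'):
                # Boolean mask -> indices of the True entries.
                return [i for i, flag in enumerate(values) if flag]
            return [int(i) for i in values]  # already an index list
    return None


def resolve_loader(loader: str, node_sampler: Optional[object]) -> str:
    """Hypothetical helper: a custom node_sampler takes precedence over `loader`."""
    if node_sampler is not None:
        return 'custom sampler'
    if loader not in ('full', 'neighbor'):
        raise ValueError(f"unsupported loader: {loader!r}")
    return loader


def merge_eval_kwargs(kwargs: Dict[str, Any],
                      eval_loader_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical helper: eval loaders reuse kwargs, with overrides on top."""
    return {**kwargs, **(eval_loader_kwargs or {})}


attrs = {
    'train_mask': [True, False, True, False],
    'valid_idx': [1],
    'test_index': [3],
}
print(infer_input_nodes(attrs, 'train'))  # [0, 2]
print(infer_input_nodes(attrs, 'val'))    # [1]
print(infer_input_nodes(attrs, 'pred'))   # None
print(resolve_loader('neighbor', node_sampler=None))  # neighbor
# {'batch_size': 1024, 'num_neighbors': [10, 10]}
print(merge_eval_kwargs({'batch_size': 128, 'num_neighbors': [10, 10]},
                        {'batch_size': 1024}))
```

Keeping the evaluation overrides in a separate dict lets, for example, validation use a larger batch_size while inheriting the training num_neighbors.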