# Source code for torch_geometric.data.lightning.datamodule

import copy
import inspect
import warnings
from typing import Any, Dict, Optional, Tuple, Type, Union

import torch

from torch_geometric.data import Data, Dataset, HeteroData
from torch_geometric.loader import DataLoader, LinkLoader, NodeLoader
from torch_geometric.sampler import BaseSampler, NeighborSampler
from torch_geometric.typing import InputEdges, InputNodes, OptTensor

# `pytorch_lightning` is an optional dependency. If it cannot be imported, fall
# back to `object` as the base class so that this module can still be imported,
# and record the failure in `no_pytorch_lightning` so that
# `LightningDataModule.__init__` can raise a descriptive error instead.
try:
    from pytorch_lightning import LightningDataModule as PLLightningDataModule
    no_pytorch_lightning = False
except ImportError:
    PLLightningDataModule = object  # type: ignore
    no_pytorch_lightning = True


class LightningDataModule(PLLightningDataModule):
    r"""Common base class of the data modules defined in this file.

    Fills in default data loader options, stores the resulting keyword
    arguments in :obj:`self.kwargs`, and disables the validation/test data
    loader hooks for splits that are not available.

    Args:
        has_val (bool): Whether a validation split is available.
        has_test (bool): Whether a test split is available.
        **kwargs (optional): Data loader keyword arguments. A given
            :obj:`shuffle` option is dropped with a warning.

    Raises:
        ModuleNotFoundError: If :obj:`pytorch_lightning` is not installed.
    """
    def __init__(self, has_val: bool, has_test: bool, **kwargs: Any) -> None:
        super().__init__()

        if no_pytorch_lightning:
            raise ModuleNotFoundError(
                "No module named 'pytorch_lightning' found on this machine. "
                "Run 'pip install pytorch_lightning' to install the library.")

        # Shadow the hooks of missing splits with `None` on the instance:
        if not has_val:
            self.val_dataloader = None  # type: ignore
        if not has_test:
            self.test_dataloader = None  # type: ignore

        # Fill in defaults without overriding user-provided values. The
        # insertion order is kept since it determines the `__repr__` output:
        for option, default in (('batch_size', 1), ('num_workers', 0),
                                ('pin_memory', True)):
            kwargs.setdefault(option, default)
        has_workers = kwargs['num_workers'] > 0
        kwargs.setdefault('persistent_workers', has_workers)

        # Shuffling is not user-configurable, so drop the option:
        if 'shuffle' in kwargs:
            cls_name = self.__class__.__name__
            warnings.warn(f"The 'shuffle={kwargs['shuffle']}' option is "
                          f"ignored in '{cls_name}'. Remove it "
                          f"from the argument list to disable this warning")
            kwargs.pop('shuffle')

        self.kwargs = kwargs

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        arguments = kwargs_repr(**self.kwargs)
        return f'{cls_name}({arguments})'


class LightningData(LightningDataModule):
    r"""Shared base class of the data modules that operate on a single graph
    object.

    Validates the :obj:`loader` option, normalizes the data loader keyword
    arguments, and sets up the graph sampler(s) that are re-used by the
    training and evaluation data loaders of its subclasses.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData],
        has_val: bool,
        has_test: bool,
        loader: str = 'neighbor',
        graph_sampler: Optional[BaseSampler] = None,
        eval_loader_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        r"""Initializes the sampler and data loader configuration.

        Args:
            data (Data or HeteroData): The graph object.
            has_val (bool): Forwarded to :class:`LightningDataModule`.
            has_test (bool): Forwarded to :class:`LightningDataModule`.
            loader (str): One of :obj:`"full"`, :obj:`"neighbor"`,
                :obj:`"link_neighbor"` or :obj:`"custom"`.
                (default: :obj:`"neighbor"`)
            graph_sampler (BaseSampler, optional): A custom sampler. If
                given, :obj:`loader` is overridden with :obj:`"custom"`.
                (default: :obj:`None`)
            eval_loader_kwargs (Dict[str, Any], optional): Keyword arguments
                that override the training configuration during evaluation.
                (default: :obj:`None`)
            **kwargs (optional): Sampler and data loader keyword arguments.

        Raises:
            ValueError: If :obj:`loader` is not one of the supported options.
        """
        kwargs.setdefault('batch_size', 1)
        kwargs.setdefault('num_workers', 0)

        # A custom sampler always takes precedence over the `loader` option:
        if graph_sampler is not None:
            loader = 'custom'

        if loader not in ['full', 'neighbor', 'link_neighbor', 'custom']:
            raise ValueError(f"Undefined 'loader' option (got '{loader}')")

        # For full-batch training, we use reasonable defaults for a lot of
        # data-loading options (each override is announced via a warning):
        if loader == 'full' and kwargs['batch_size'] != 1:
            warnings.warn(f"Re-setting 'batch_size' to 1 in "
                          f"'{self.__class__.__name__}' for loader='full' "
                          f"(got '{kwargs['batch_size']}')")
            kwargs['batch_size'] = 1

        if loader == 'full' and kwargs['num_workers'] != 0:
            warnings.warn(f"Re-setting 'num_workers' to 0 in "
                          f"'{self.__class__.__name__}' for loader='full' "
                          f"(got '{kwargs['num_workers']}')")
            kwargs['num_workers'] = 0

        if loader == 'full' and kwargs.get('sampler') is not None:
            warnings.warn("'sampler' option is not supported for "
                          "loader='full'")
            kwargs.pop('sampler', None)

        if loader == 'full' and kwargs.get('batch_sampler') is not None:
            warnings.warn("'batch_sampler' option is not supported for "
                          "loader='full'")
            kwargs.pop('batch_sampler', None)

        super().__init__(has_val, has_test, **kwargs)

        if loader == 'full':
            # The local `kwargs` is not touched by the `super().__init__` call
            # above (it receives a copy via `**kwargs`), so this only warns
            # if the caller explicitly requested pinned memory. `pin_memory`
            # is disabled for full-batch loading either way.
            if kwargs.get('pin_memory', False):
                warnings.warn(f"Re-setting 'pin_memory' to 'False' in "
                              f"'{self.__class__.__name__}' for loader='full' "
                              f"(got 'True')")
            self.kwargs['pin_memory'] = False

        self.data = data
        self.loader = loader

        # Determine sampler and loader arguments ##############################

        if loader in ['neighbor', 'link_neighbor']:

            # Define a new `NeighborSampler` to be re-used across data loaders:
            sampler_kwargs, self.loader_kwargs = split_kwargs(
                self.kwargs,
                NeighborSampler,
            )
            sampler_kwargs.setdefault('share_memory',
                                      self.kwargs['num_workers'] > 0)
            self.graph_sampler: BaseSampler = NeighborSampler(
                data, **sampler_kwargs)

        elif graph_sampler is not None:
            # A custom sampler is already fully configured, so any argument
            # that matches its constructor signature cannot be applied:
            sampler_kwargs, self.loader_kwargs = split_kwargs(
                self.kwargs,
                graph_sampler.__class__,
            )
            if len(sampler_kwargs) > 0:
                warnings.warn(f"Ignoring the arguments "
                              f"{list(sampler_kwargs.keys())} in "
                              f"'{self.__class__.__name__}' since a custom "
                              f"'graph_sampler' was passed")
            self.graph_sampler = graph_sampler

        else:
            # Full-batch loading does not use a graph sampler. Note that
            # `loader_kwargs` is the same dictionary object as `self.kwargs`:
            assert loader == 'full'
            self.loader_kwargs = self.kwargs

        # Determine validation sampler and loader arguments ###################

        # Evaluation starts from a shallow copy of the training loader
        # arguments, so that the updates below leave `loader_kwargs` untouched:
        self.eval_loader_kwargs = copy.copy(self.loader_kwargs)
        if eval_loader_kwargs is not None:
            # If the user wants to override certain values during evaluation,
            # we shallow-copy the graph sampler and update its attributes.
            if hasattr(self, 'graph_sampler'):
                self.eval_graph_sampler = copy.copy(self.graph_sampler)

                eval_sampler_kwargs, eval_loader_kwargs = split_kwargs(
                    eval_loader_kwargs,
                    self.graph_sampler.__class__,
                )
                for key, value in eval_sampler_kwargs.items():
                    setattr(self.eval_graph_sampler, key, value)

            self.eval_loader_kwargs.update(eval_loader_kwargs)

        elif hasattr(self, 'graph_sampler'):
            # Without overrides, training and evaluation share one sampler:
            self.eval_graph_sampler = self.graph_sampler

        # Custom (batch) samplers are never used for evaluation:
        self.eval_loader_kwargs.pop('sampler', None)
        self.eval_loader_kwargs.pop('batch_sampler', None)

        # Drop `batch_size` from the training arguments whenever a
        # `batch_sampler` entry is present:
        if 'batch_sampler' in self.loader_kwargs:
            self.loader_kwargs.pop('batch_size', None)

    @property
    def train_shuffle(self) -> bool:
        r"""Whether the training data loader should shuffle, which is the
        case only if neither a :obj:`sampler` nor a :obj:`batch_sampler` is
        set in :obj:`self.loader_kwargs`.
        """
        shuffle = self.loader_kwargs.get('sampler', None) is None
        shuffle &= self.loader_kwargs.get('batch_sampler', None) is None
        return shuffle

    def prepare_data(self) -> None:
        r"""Ensures that full-batch loading is only used on a single device.

        Raises:
            ValueError: If :obj:`loader="full"` and the attached trainer
                reports more than one device.
        """
        if self.loader == 'full':
            assert self.trainer is not None
            try:
                num_devices = self.trainer.num_devices
            except AttributeError:
                # PyTorch Lightning < 1.6 backward compatibility:
                num_devices = self.trainer.num_processes  # type: ignore
                num_gpus = self.trainer.num_gpus  # type: ignore
                num_devices = max(num_devices, num_gpus)

            if num_devices > 1:
                raise ValueError(
                    f"'{self.__class__.__name__}' with loader='full' requires "
                    f"training on a single device")
        super().prepare_data()

    def full_dataloader(self, **kwargs: Any) -> torch.utils.data.DataLoader:
        r"""Returns a data loader over the single-element list
        :obj:`[self.data]` whose collate function returns that element as is.

        As a side effect, installs :obj:`"ignore"` warning filters for
        messages matching the two patterns below.
        """
        warnings.filterwarnings('ignore', '.*does not have many workers.*')
        warnings.filterwarnings('ignore', '.*data loading bottlenecks.*')

        return torch.utils.data.DataLoader(
            [self.data],  # type: ignore
            # Unwrap the one-element batch to yield the graph object itself:
            collate_fn=lambda xs: xs[0],
            **kwargs,
        )

    def __repr__(self) -> str:
        r"""Returns the class name followed by :obj:`data`, :obj:`loader` and
        the entries of :obj:`self.kwargs`, formatted via :meth:`kwargs_repr`.
        """
        kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs)
        return f'{self.__class__.__name__}({kwargs})'


class LightningDataset(LightningDataModule):
    r"""Converts a set of :class:`~torch_geometric.data.Dataset` objects into a
    :class:`pytorch_lightning.LightningDataModule` variant. It can then be
    automatically used as a :obj:`datamodule` for multi-GPU graph-level
    training via :lightning:`null`
    `PyTorch Lightning <https://www.pytorchlightning.ai>`__.
    :class:`LightningDataset` will take care of providing mini-batches via
    :class:`~torch_geometric.loader.DataLoader`.

    .. note::

        Currently only the
        :class:`pytorch_lightning.strategies.SingleDeviceStrategy` and
        :class:`pytorch_lightning.strategies.DDPStrategy` training
        strategies of :lightning:`null` `PyTorch Lightning
        <https://pytorch-lightning.readthedocs.io/en/latest/guides/
        speed.html>`__ are supported in order to correctly share data across
        all devices/processes:

        .. code-block:: python

            import pytorch_lightning as pl
            trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                                 devices=4)
            trainer.fit(model, datamodule)

    Args:
        train_dataset (Dataset): The training dataset.
        val_dataset (Dataset, optional): The validation dataset.
            (default: :obj:`None`)
        test_dataset (Dataset, optional): The test dataset.
            (default: :obj:`None`)
        pred_dataset (Dataset, optional): The prediction dataset.
            (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.loader.DataLoader`.
    """
    def __init__(
        self,
        train_dataset: Dataset,
        val_dataset: Optional[Dataset] = None,
        test_dataset: Optional[Dataset] = None,
        pred_dataset: Optional[Dataset] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            has_val=val_dataset is not None,
            has_test=test_dataset is not None,
            **kwargs,
        )

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.pred_dataset = pred_dataset

    def dataloader(self, dataset: Dataset, **kwargs: Any) -> DataLoader:
        r"""Creates a :class:`~torch_geometric.loader.DataLoader` over
        :obj:`dataset` with the given keyword arguments.
        """
        return DataLoader(dataset, **kwargs)

    def train_dataloader(self) -> DataLoader:
        r"""Returns the training data loader. Shuffling is enabled unless the
        dataset is iterable-style or a custom (batch) sampler is set.
        """
        from torch.utils.data import IterableDataset

        shuffle = not isinstance(self.train_dataset, IterableDataset)
        shuffle &= self.kwargs.get('sampler', None) is None
        shuffle &= self.kwargs.get('batch_sampler', None) is None

        return self.dataloader(
            self.train_dataset,
            shuffle=shuffle,
            **self.kwargs,
        )

    def val_dataloader(self) -> DataLoader:
        r"""Returns the validation data loader (unshuffled, without custom
        samplers).
        """
        assert self.val_dataset is not None

        # Work on a copy so that the training arguments stay untouched:
        kwargs = copy.copy(self.kwargs)
        kwargs.pop('sampler', None)
        kwargs.pop('batch_sampler', None)

        return self.dataloader(self.val_dataset, shuffle=False, **kwargs)

    def test_dataloader(self) -> DataLoader:
        r"""Returns the test data loader (unshuffled, without custom
        samplers).
        """
        assert self.test_dataset is not None

        # Work on a copy so that the training arguments stay untouched:
        kwargs = copy.copy(self.kwargs)
        kwargs.pop('sampler', None)
        kwargs.pop('batch_sampler', None)

        return self.dataloader(self.test_dataset, shuffle=False, **kwargs)

    def predict_dataloader(self) -> DataLoader:
        r"""Returns the prediction data loader (unshuffled, without custom
        samplers).
        """
        assert self.pred_dataset is not None

        # Work on a copy so that the training arguments stay untouched:
        kwargs = copy.copy(self.kwargs)
        kwargs.pop('sampler', None)
        kwargs.pop('batch_sampler', None)

        return self.dataloader(self.pred_dataset, shuffle=False, **kwargs)

    def __repr__(self) -> str:
        kwargs = kwargs_repr(
            train_dataset=self.train_dataset,
            val_dataset=self.val_dataset,
            test_dataset=self.test_dataset,
            pred_dataset=self.pred_dataset,
            **self.kwargs,
        )
        return f'{self.__class__.__name__}({kwargs})'
class LightningNodeData(LightningData):
    r"""Converts a :class:`~torch_geometric.data.Data` or
    :class:`~torch_geometric.data.HeteroData` object into a
    :class:`pytorch_lightning.LightningDataModule` variant. It can then be
    automatically used as a :obj:`datamodule` for multi-GPU node-level
    training via :lightning:`null`
    `PyTorch Lightning <https://www.pytorchlightning.ai>`__.
    :class:`LightningNodeData` will take care of providing mini-batches via
    :class:`~torch_geometric.loader.NeighborLoader`.

    .. note::

        Currently only the
        :class:`pytorch_lightning.strategies.SingleDeviceStrategy` and
        :class:`pytorch_lightning.strategies.DDPStrategy` training
        strategies of :lightning:`null` `PyTorch Lightning
        <https://pytorch-lightning.readthedocs.io/en/latest/guides/
        speed.html>`__ are supported in order to correctly share data across
        all devices/processes:

        .. code-block:: python

            import pytorch_lightning as pl
            trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                                 devices=4)
            trainer.fit(model, datamodule)

    Args:
        data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` graph object.
        input_train_nodes (torch.Tensor or str or (str, torch.Tensor)): The
            indices of training nodes.
            If not given, will try to automatically infer them from the
            :obj:`data` object by searching for :obj:`train_mask`,
            :obj:`train_idx`, or :obj:`train_index` attributes.
            (default: :obj:`None`)
        input_train_time (torch.Tensor, optional): The timestamp
            of training nodes. (default: :obj:`None`)
        input_val_nodes (torch.Tensor or str or (str, torch.Tensor)): The
            indices of validation nodes.
            If not given, will try to automatically infer them from the
            :obj:`data` object by searching for :obj:`val_mask`,
            :obj:`valid_mask`, :obj:`val_idx`, :obj:`valid_idx`,
            :obj:`val_index`, or :obj:`valid_index` attributes.
            (default: :obj:`None`)
        input_val_time (torch.Tensor, optional): The timestamp
            of validation nodes. (default: :obj:`None`)
        input_test_nodes (torch.Tensor or str or (str, torch.Tensor)): The
            indices of test nodes.
            If not given, will try to automatically infer them from the
            :obj:`data` object by searching for :obj:`test_mask`,
            :obj:`test_idx`, or :obj:`test_index` attributes.
            (default: :obj:`None`)
        input_test_time (torch.Tensor, optional): The timestamp
            of test nodes. (default: :obj:`None`)
        input_pred_nodes (torch.Tensor or str or (str, torch.Tensor)): The
            indices of prediction nodes.
            If not given, will try to automatically infer them from the
            :obj:`data` object by searching for :obj:`pred_mask`,
            :obj:`pred_idx`, or :obj:`pred_index` attributes.
            (default: :obj:`None`)
        input_pred_time (torch.Tensor, optional): The timestamp
            of prediction nodes. (default: :obj:`None`)
        loader (str): The scalability technique to use (:obj:`"full"`,
            :obj:`"neighbor"`). (default: :obj:`"neighbor"`)
        node_sampler (BaseSampler, optional): A custom sampler object to
            generate mini-batches. If set, will ignore the :obj:`loader`
            option. (default: :obj:`None`)
        eval_loader_kwargs (Dict[str, Any], optional): Custom keyword arguments
            that override the :class:`torch_geometric.loader.NeighborLoader`
            configuration during evaluation. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.loader.NeighborLoader`.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData],
        input_train_nodes: InputNodes = None,
        input_train_time: OptTensor = None,
        input_val_nodes: InputNodes = None,
        input_val_time: OptTensor = None,
        input_test_nodes: InputNodes = None,
        input_test_time: OptTensor = None,
        input_pred_nodes: InputNodes = None,
        input_pred_time: OptTensor = None,
        loader: str = 'neighbor',
        node_sampler: Optional[BaseSampler] = None,
        eval_loader_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        # Fall back to the split attributes stored in `data` for every split
        # that was not given explicitly:
        if input_train_nodes is None:
            input_train_nodes = infer_input_nodes(data, split='train')

        # The validation split may be stored under a `val` or `valid` prefix:
        if input_val_nodes is None:
            input_val_nodes = infer_input_nodes(data, split='val')
            if input_val_nodes is None:
                input_val_nodes = infer_input_nodes(data, split='valid')

        if input_test_nodes is None:
            input_test_nodes = infer_input_nodes(data, split='test')

        if input_pred_nodes is None:
            input_pred_nodes = infer_input_nodes(data, split='pred')

        super().__init__(
            data=data,
            has_val=input_val_nodes is not None,
            has_test=input_test_nodes is not None,
            loader=loader,
            graph_sampler=node_sampler,
            eval_loader_kwargs=eval_loader_kwargs,
            **kwargs,
        )

        self.input_train_nodes = input_train_nodes
        self.input_train_time = input_train_time
        self.input_train_id: OptTensor = None

        self.input_val_nodes = input_val_nodes
        self.input_val_time = input_val_time
        self.input_val_id: OptTensor = None

        self.input_test_nodes = input_test_nodes
        self.input_test_time = input_test_time
        self.input_test_id: OptTensor = None

        self.input_pred_nodes = input_pred_nodes
        self.input_pred_time = input_pred_time
        self.input_pred_id: OptTensor = None

    def dataloader(
        self,
        input_nodes: InputNodes,
        input_time: OptTensor = None,
        input_id: OptTensor = None,
        node_sampler: Optional[BaseSampler] = None,
        **kwargs: Any,
    ) -> torch.utils.data.DataLoader:
        r"""Creates the data loader for the given input nodes: the full-batch
        loader if :obj:`loader="full"`, and a
        :class:`~torch_geometric.loader.NodeLoader` driven by
        :obj:`node_sampler` otherwise.
        """
        if self.loader == 'full':
            return self.full_dataloader(**kwargs)

        assert node_sampler is not None

        return NodeLoader(
            self.data,
            node_sampler=node_sampler,
            input_nodes=input_nodes,
            input_time=input_time,
            input_id=input_id,
            **kwargs,
        )

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Returns the data loader over the training nodes."""
        return self.dataloader(
            self.input_train_nodes,
            self.input_train_time,
            self.input_train_id,
            node_sampler=getattr(self, 'graph_sampler', None),
            shuffle=self.train_shuffle,
            **self.loader_kwargs,
        )

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Returns the unshuffled data loader over the validation nodes."""
        return self.dataloader(
            self.input_val_nodes,
            self.input_val_time,
            self.input_val_id,
            node_sampler=getattr(self, 'eval_graph_sampler', None),
            shuffle=False,
            **self.eval_loader_kwargs,
        )

    def test_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Returns the unshuffled data loader over the test nodes."""
        return self.dataloader(
            self.input_test_nodes,
            self.input_test_time,
            self.input_test_id,
            node_sampler=getattr(self, 'eval_graph_sampler', None),
            shuffle=False,
            **self.eval_loader_kwargs,
        )

    def predict_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Returns the unshuffled data loader over the prediction nodes."""
        return self.dataloader(
            self.input_pred_nodes,
            self.input_pred_time,
            self.input_pred_id,
            node_sampler=getattr(self, 'eval_graph_sampler', None),
            shuffle=False,
            **self.eval_loader_kwargs,
        )
class LightningLinkData(LightningData):
    r"""Converts a :class:`~torch_geometric.data.Data` or
    :class:`~torch_geometric.data.HeteroData` object into a
    :class:`pytorch_lightning.LightningDataModule` variant. It can then be
    automatically used as a :obj:`datamodule` for multi-GPU link-level
    training via :lightning:`null`
    `PyTorch Lightning <https://www.pytorchlightning.ai>`__.
    :class:`LightningLinkData` will take care of providing mini-batches via
    :class:`~torch_geometric.loader.LinkNeighborLoader`.

    .. note::

        Currently only the
        :class:`pytorch_lightning.strategies.SingleDeviceStrategy` and
        :class:`pytorch_lightning.strategies.DDPStrategy` training
        strategies of :lightning:`null` `PyTorch Lightning
        <https://pytorch-lightning.readthedocs.io/en/latest/guides/
        speed.html>`__ are supported in order to correctly share data across
        all devices/processes:

        .. code-block:: python

            import pytorch_lightning as pl
            trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                                 devices=4)
            trainer.fit(model, datamodule)

    Args:
        data (Data or HeteroData or Tuple[FeatureStore, GraphStore]): The
            :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` graph object, or a
            tuple of a :class:`~torch_geometric.data.FeatureStore` and
            :class:`~torch_geometric.data.GraphStore` objects.
        input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
            The training edges. (default: :obj:`None`)
        input_train_labels (torch.Tensor, optional):
            The labels of training edges. (default: :obj:`None`)
        input_train_time (torch.Tensor, optional): The timestamp
            of training edges. (default: :obj:`None`)
        input_val_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
            The validation edges. (default: :obj:`None`)
        input_val_labels (torch.Tensor, optional):
            The labels of validation edges. (default: :obj:`None`)
        input_val_time (torch.Tensor, optional): The timestamp
            of validation edges. (default: :obj:`None`)
        input_test_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
            The test edges. (default: :obj:`None`)
        input_test_labels (torch.Tensor, optional):
            The labels of test edges. (default: :obj:`None`)
        input_test_time (torch.Tensor, optional): The timestamp
            of test edges. (default: :obj:`None`)
        input_pred_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
            The prediction edges. (default: :obj:`None`)
        input_pred_labels (torch.Tensor, optional):
            The labels of prediction edges. (default: :obj:`None`)
        input_pred_time (torch.Tensor, optional): The timestamp
            of prediction edges. (default: :obj:`None`)
        loader (str): The scalability technique to use (:obj:`"full"`,
            :obj:`"neighbor"`). (default: :obj:`"neighbor"`)
        link_sampler (BaseSampler, optional): A custom sampler object to
            generate mini-batches. If set, will ignore the :obj:`loader`
            option. (default: :obj:`None`)
        eval_loader_kwargs (Dict[str, Any], optional): Custom keyword arguments
            that override the
            :class:`torch_geometric.loader.LinkNeighborLoader` configuration
            during evaluation. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.loader.LinkNeighborLoader`.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData],
        input_train_edges: InputEdges = None,
        input_train_labels: OptTensor = None,
        input_train_time: OptTensor = None,
        input_val_edges: InputEdges = None,
        input_val_labels: OptTensor = None,
        input_val_time: OptTensor = None,
        input_test_edges: InputEdges = None,
        input_test_labels: OptTensor = None,
        input_test_time: OptTensor = None,
        input_pred_edges: InputEdges = None,
        input_pred_labels: OptTensor = None,
        input_pred_time: OptTensor = None,
        loader: str = 'neighbor',
        link_sampler: Optional[BaseSampler] = None,
        eval_loader_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        # In contrast to node-level data, input edges are never inferred from
        # `data`; a split is available only if its edges are given explicitly.
        super().__init__(
            data=data,
            has_val=input_val_edges is not None,
            has_test=input_test_edges is not None,
            loader=loader,
            graph_sampler=link_sampler,
            eval_loader_kwargs=eval_loader_kwargs,
            **kwargs,
        )

        self.input_train_edges = input_train_edges
        self.input_train_labels = input_train_labels
        self.input_train_time = input_train_time
        self.input_train_id: OptTensor = None

        self.input_val_edges = input_val_edges
        self.input_val_labels = input_val_labels
        self.input_val_time = input_val_time
        self.input_val_id: OptTensor = None

        self.input_test_edges = input_test_edges
        self.input_test_labels = input_test_labels
        self.input_test_time = input_test_time
        self.input_test_id: OptTensor = None

        self.input_pred_edges = input_pred_edges
        self.input_pred_labels = input_pred_labels
        self.input_pred_time = input_pred_time
        self.input_pred_id: OptTensor = None

    def dataloader(
        self,
        input_edges: InputEdges,
        input_labels: OptTensor = None,
        input_time: OptTensor = None,
        input_id: OptTensor = None,
        link_sampler: Optional[BaseSampler] = None,
        **kwargs: Any,
    ) -> torch.utils.data.DataLoader:
        r"""Creates the data loader for the given input edges: the full-batch
        loader if :obj:`loader="full"`, and a
        :class:`~torch_geometric.loader.LinkLoader` driven by
        :obj:`link_sampler` otherwise.
        """
        if self.loader == 'full':
            return self.full_dataloader(**kwargs)

        assert link_sampler is not None

        return LinkLoader(
            self.data,
            link_sampler=link_sampler,
            edge_label_index=input_edges,
            edge_label=input_labels,
            edge_label_time=input_time,
            input_id=input_id,
            **kwargs,
        )

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Returns the data loader over the training edges."""
        return self.dataloader(
            self.input_train_edges,
            self.input_train_labels,
            self.input_train_time,
            self.input_train_id,
            link_sampler=getattr(self, 'graph_sampler', None),
            shuffle=self.train_shuffle,
            **self.loader_kwargs,
        )

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Returns the unshuffled data loader over the validation edges."""
        return self.dataloader(
            self.input_val_edges,
            self.input_val_labels,
            self.input_val_time,
            self.input_val_id,
            link_sampler=getattr(self, 'eval_graph_sampler', None),
            shuffle=False,
            **self.eval_loader_kwargs,
        )

    def test_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Returns the unshuffled data loader over the test edges."""
        return self.dataloader(
            self.input_test_edges,
            self.input_test_labels,
            self.input_test_time,
            self.input_test_id,
            link_sampler=getattr(self, 'eval_graph_sampler', None),
            shuffle=False,
            **self.eval_loader_kwargs,
        )

    def predict_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Returns the unshuffled data loader over the prediction edges."""
        return self.dataloader(
            self.input_pred_edges,
            self.input_pred_labels,
            self.input_pred_time,
            self.input_pred_id,
            link_sampler=getattr(self, 'eval_graph_sampler', None),
            shuffle=False,
            **self.eval_loader_kwargs,
        )
###############################################################################


# TODO Support Tuple[FeatureStore, GraphStore]
def infer_input_nodes(data: Union[Data, HeteroData], split: str) -> InputNodes:
    r"""Looks up the input nodes of a split in the attributes of :obj:`data`.

    Searches for :obj:`{split}_mask`, :obj:`{split}_idx` and
    :obj:`{split}_index` (in that order) and uses the first attribute found.

    Args:
        data (Data or HeteroData): The graph object to inspect.
        split (str): The split prefix, *e.g.*, :obj:`"train"`.

    Returns:
        The attribute value for :class:`~torch_geometric.data.Data` objects,
        a :obj:`(node_type, value)` tuple for
        :class:`~torch_geometric.data.HeteroData` objects, or :obj:`None` if
        no matching attribute exists.

    Raises:
        ValueError: If :obj:`data` is heterogeneous and the attribute is not
            held by exactly one node type.
    """
    attr_name: Optional[str] = None
    if f'{split}_mask' in data:
        attr_name = f'{split}_mask'
    elif f'{split}_idx' in data:
        attr_name = f'{split}_idx'
    elif f'{split}_index' in data:
        attr_name = f'{split}_index'

    if attr_name is None:
        return None

    if isinstance(data, Data):
        return data[attr_name]

    if isinstance(data, HeteroData):
        input_nodes_dict = {
            node_type: store[attr_name]
            for node_type, store in data.node_items() if attr_name in store
        }
        # The input node type is only unambiguous if exactly one node type
        # holds the attribute:
        if len(input_nodes_dict) != 1:
            raise ValueError(f"Could not automatically determine the input "
                             f"nodes of {data} since there exists multiple "
                             f"types with attribute '{attr_name}'")
        return list(input_nodes_dict.items())[0]

    return None


def kwargs_repr(**kwargs: Any) -> str:
    r"""Formats keyword arguments as a comma-separated :obj:`key=value`
    string, skipping all entries whose value is :obj:`None`.
    """
    return ', '.join([f'{k}={v}' for k, v in kwargs.items() if v is not None])


def split_kwargs(
    kwargs: Dict[str, Any],
    sampler_cls: Type,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    r"""Splits keyword arguments into sampler and loader arguments.

    Args:
        kwargs (Dict[str, Any]): The keyword arguments to split. The
            dictionary itself is not modified.
        sampler_cls (Type): The sampler class. Every key that names a
            parameter in its signature is treated as a sampler argument.

    Returns:
        A tuple :obj:`(sampler_kwargs, loader_kwargs)` of new dictionaries.
    """
    sampler_args = inspect.signature(sampler_cls).parameters

    sampler_kwargs: Dict[str, Any] = {}
    loader_kwargs: Dict[str, Any] = {}

    for key, value in kwargs.items():
        if key in sampler_args:
            sampler_kwargs[key] = value
        else:
            loader_kwargs[key] = value

    return sampler_kwargs, loader_kwargs