Source code for torch_geometric.data.lightning_datamodule

import warnings
from typing import Optional, Union

import torch

from torch_geometric.data import Data, Dataset, HeteroData
from torch_geometric.loader.dataloader import DataLoader
from torch_geometric.loader.neighbor_loader import (
    NeighborLoader,
    NeighborSampler,
    get_input_node_type,
)
from torch_geometric.typing import InputNodes

try:
    from pytorch_lightning import LightningDataModule as PLLightningDataModule
    no_pytorch_lightning = False
except (ImportError, ModuleNotFoundError):
    PLLightningDataModule = object
    no_pytorch_lightning = True


class LightningDataModule(PLLightningDataModule):
    def __init__(self, has_val: bool, has_test: bool, **kwargs):
        super().__init__()

        if no_pytorch_lightning:
            raise ModuleNotFoundError(
                "No module named 'pytorch_lightning' found on this machine. "
                "Run 'pip install pytorch_lightning' to install the library")

        # Disable the dataloader hooks for splits that were not provided so
        # that PyTorch Lightning skips them:
        if not has_val:
            self.val_dataloader = None

        if not has_test:
            self.test_dataloader = None

        if 'shuffle' in kwargs:
            warnings.warn(f"The 'shuffle={kwargs['shuffle']}' option is "
                          f"ignored in '{self.__class__.__name__}'. Remove it "
                          f"from the argument list to disable this warning")
            del kwargs['shuffle']

        # Fill in loader defaults: pin host memory for faster device
        # transfers, and keep workers alive across epochs whenever worker
        # subprocesses are used:
        if 'pin_memory' not in kwargs:
            kwargs['pin_memory'] = True

        if 'persistent_workers' not in kwargs:
            kwargs['persistent_workers'] = kwargs.get('num_workers', 0) > 0

        self.kwargs = kwargs

    def prepare_data(self):
        from pytorch_lightning.plugins import SingleDevicePlugin
        from pytorch_lightning.plugins import DDPSpawnPlugin
        plugin = self.trainer.training_type_plugin
        if not isinstance(plugin, (SingleDevicePlugin, DDPSpawnPlugin)):
            raise NotImplementedError(
                f"'{self.__class__.__name__}' currently only supports "
                f"'{SingleDevicePlugin.__name__}' and "
                f"'{DDPSpawnPlugin.__name__}' training type plugins of "
                f"'pytorch_lightning' (got '{plugin.__class__.__name__}')")

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({kwargs_repr(**self.kwargs)})'

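
# Illustrative behaviour sketch (not part of the library source): the
# constructor above fills in loader defaults, so passing `num_workers=4`
# yields pinned memory and persistent workers in `self.kwargs` (assuming
# 'pytorch_lightning' is installed):
def _example_datamodule_defaults():
    dm = LightningDataModule(has_val=False, has_test=False, num_workers=4)
    assert dm.kwargs == {'num_workers': 4, 'pin_memory': True,
                         'persistent_workers': True}
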

class LightningDataset(LightningDataModule):
    r"""Converts a set of :class:`~torch_geometric.data.Dataset` objects into
    a :class:`pytorch_lightning.LightningDataModule` variant, which can be
    automatically used as a :obj:`datamodule` for multi-GPU graph-level
    training via `PyTorch Lightning <https://www.pytorchlightning.ai>`__.
    :class:`LightningDataset` will take care of providing mini-batches via
    :class:`~torch_geometric.loader.DataLoader`.

    .. note::

        Currently only the
        :class:`pytorch_lightning.plugins.SingleDevicePlugin` and
        :class:`pytorch_lightning.plugins.DDPSpawnPlugin` training type
        plugins of `PyTorch Lightning
        <https://pytorch-lightning.readthedocs.io/en/latest/guides/speed.html>`__
        are supported in order to correctly share data across all
        devices/processes:

        .. code-block::

            import pytorch_lightning as pl
            trainer = pl.Trainer(strategy="ddp_spawn", gpus=4)
            trainer.fit(model, datamodule)

    Args:
        train_dataset (Dataset): The training dataset.
        val_dataset (Dataset, optional): The validation dataset.
            (default: :obj:`None`)
        test_dataset (Dataset, optional): The test dataset.
            (default: :obj:`None`)
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        num_workers (int, optional): How many subprocesses to use for data
            loading. :obj:`0` means that the data will be loaded in the main
            process. (default: :obj:`0`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.loader.DataLoader`.
    """
    def __init__(
        self,
        train_dataset: Dataset,
        val_dataset: Optional[Dataset] = None,
        test_dataset: Optional[Dataset] = None,
        batch_size: int = 1,
        num_workers: int = 0,
        **kwargs,
    ):
        super().__init__(
            has_val=val_dataset is not None,
            has_test=test_dataset is not None,
            batch_size=batch_size,
            num_workers=num_workers,
            **kwargs,
        )

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset

    def dataloader(self, dataset: Dataset, shuffle: bool) -> DataLoader:
        return DataLoader(dataset, shuffle=shuffle, **self.kwargs)

    def train_dataloader(self) -> DataLoader:
        """"""
        from torch.utils.data import IterableDataset
        shuffle = not isinstance(self.train_dataset, IterableDataset)
        return self.dataloader(self.train_dataset, shuffle=shuffle)

    def val_dataloader(self) -> DataLoader:
        """"""
        return self.dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """"""
        return self.dataloader(self.test_dataset, shuffle=False)

    def __repr__(self) -> str:
        kwargs = kwargs_repr(train_dataset=self.train_dataset,
                             val_dataset=self.val_dataset,
                             test_dataset=self.test_dataset, **self.kwargs)
        return f'{self.__class__.__name__}({kwargs})'
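

# Illustrative usage sketch (not part of the library source): a minimal
# example of wiring `LightningDataset` into a training run. The TUDataset
# splits and the storage path are assumptions for illustration only:
def _example_lightning_dataset():
    from torch_geometric.datasets import TUDataset

    dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES').shuffle()
    datamodule = LightningDataset(
        train_dataset=dataset[:500],
        val_dataset=dataset[500:550],
        test_dataset=dataset[550:],
        batch_size=32,
        num_workers=4,
    )
    # A `pytorch_lightning.Trainer` would then consume the module via
    # `trainer.fit(model, datamodule)`, as shown in the docstring above.
    return datamodule

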
class LightningNodeData(LightningDataModule):
    r"""Converts a :class:`~torch_geometric.data.Data` or
    :class:`~torch_geometric.data.HeteroData` object into a
    :class:`pytorch_lightning.LightningDataModule` variant, which can be
    automatically used as a :obj:`datamodule` for multi-GPU node-level
    training via `PyTorch Lightning <https://www.pytorchlightning.ai>`_.
    :class:`LightningNodeData` will take care of providing mini-batches via
    :class:`~torch_geometric.loader.NeighborLoader`.

    .. note::

        Currently only the
        :class:`pytorch_lightning.plugins.SingleDevicePlugin` and
        :class:`pytorch_lightning.plugins.DDPSpawnPlugin` training type
        plugins of `PyTorch Lightning
        <https://pytorch-lightning.readthedocs.io/en/latest/guides/speed.html>`__
        are supported in order to correctly share data across all
        devices/processes:

        .. code-block::

            import pytorch_lightning as pl
            trainer = pl.Trainer(strategy="ddp_spawn", gpus=4)
            trainer.fit(model, datamodule)

    Args:
        data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` graph object.
        input_train_nodes (torch.Tensor or str or (str, torch.Tensor)): The
            indices of training nodes. If not given, will try to
            automatically infer them from the :obj:`data` object.
            (default: :obj:`None`)
        input_val_nodes (torch.Tensor or str or (str, torch.Tensor)): The
            indices of validation nodes. If not given, will try to
            automatically infer them from the :obj:`data` object.
            (default: :obj:`None`)
        input_test_nodes (torch.Tensor or str or (str, torch.Tensor)): The
            indices of test nodes. If not given, will try to automatically
            infer them from the :obj:`data` object. (default: :obj:`None`)
        loader (str): The scalability technique to use (:obj:`"full"`,
            :obj:`"neighbor"`). (default: :obj:`"neighbor"`)
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        num_workers (int, optional): How many subprocesses to use for data
            loading. :obj:`0` means that the data will be loaded in the main
            process. (default: :obj:`0`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.loader.NeighborLoader`.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData],
        input_train_nodes: InputNodes = None,
        input_val_nodes: InputNodes = None,
        input_test_nodes: InputNodes = None,
        loader: str = "neighbor",
        batch_size: int = 1,
        num_workers: int = 0,
        **kwargs,
    ):
        assert loader in ['full', 'neighbor']

        if input_train_nodes is None:
            input_train_nodes = infer_input_nodes(data, split='train')

        if input_val_nodes is None:
            input_val_nodes = infer_input_nodes(data, split='val')
        if input_val_nodes is None:
            # Also try the alternative 'valid' split naming:
            input_val_nodes = infer_input_nodes(data, split='valid')

        if input_test_nodes is None:
            input_test_nodes = infer_input_nodes(data, split='test')

        if loader == 'full' and batch_size != 1:
            warnings.warn(f"Re-setting 'batch_size' to 1 in "
                          f"'{self.__class__.__name__}' for loader='full' "
                          f"(got '{batch_size}')")
            batch_size = 1

        if loader == 'full' and num_workers != 0:
            warnings.warn(f"Re-setting 'num_workers' to 0 in "
                          f"'{self.__class__.__name__}' for loader='full' "
                          f"(got '{num_workers}')")
            num_workers = 0

        super().__init__(
            has_val=input_val_nodes is not None,
            has_test=input_test_nodes is not None,
            batch_size=batch_size,
            num_workers=num_workers,
            **kwargs,
        )

        if loader == 'full':
            if kwargs.get('pin_memory', False):
                warnings.warn(f"Re-setting 'pin_memory' to 'False' in "
                              f"'{self.__class__.__name__}' for loader='full' "
                              f"(got 'True')")
            self.kwargs['pin_memory'] = False

        self.data = data
        self.loader = loader

        if loader == 'neighbor':
            self.neighbor_sampler = NeighborSampler(
                data=data,
                num_neighbors=kwargs.get('num_neighbors', None),
                replace=kwargs.get('replace', False),
                directed=kwargs.get('directed', True),
                input_node_type=get_input_node_type(input_train_nodes),
            )

        self.input_train_nodes = input_train_nodes
        self.input_val_nodes = input_val_nodes
        self.input_test_nodes = input_test_nodes

    def prepare_data(self):
        """"""
        if self.loader == 'full':
            if self.trainer.num_processes != 1 or self.trainer.num_gpus != 1:
                raise ValueError(f"'{self.__class__.__name__}' with loader="
                                 f"'full' requires training on a single GPU")
        super().prepare_data()

    def dataloader(self, input_nodes: InputNodes) -> DataLoader:
        if self.loader == 'full':
            # 'full' mini-batching: serve the whole graph as a single batch:
            warnings.filterwarnings('ignore', '.*does not have many workers.*')
            warnings.filterwarnings('ignore', '.*data loading bottlenecks.*')
            return torch.utils.data.DataLoader([self.data], shuffle=False,
                                               collate_fn=lambda xs: xs[0],
                                               **self.kwargs)

        if self.loader == 'neighbor':
            # 'neighbor' mini-batching: sample fixed-size neighborhoods
            # around each seed node:
            warnings.filterwarnings('ignore', '.*has `shuffle=True`.*')
            return NeighborLoader(self.data, input_nodes=input_nodes,
                                  neighbor_sampler=self.neighbor_sampler,
                                  shuffle=True, **self.kwargs)

        raise NotImplementedError

    def train_dataloader(self) -> DataLoader:
        """"""
        return self.dataloader(self.input_train_nodes)

    def val_dataloader(self) -> DataLoader:
        """"""
        return self.dataloader(self.input_val_nodes)

    def test_dataloader(self) -> DataLoader:
        """"""
        return self.dataloader(self.input_test_nodes)

    def __repr__(self) -> str:
        kwargs = kwargs_repr(data=self.data, loader=self.loader,
                             **self.kwargs)
        return f'{self.__class__.__name__}({kwargs})'
###############################################################################


def infer_input_nodes(data: Union[Data, HeteroData], split: str) -> InputNodes:
    attr_name: Optional[str] = None
    if f'{split}_mask' in data:
        attr_name = f'{split}_mask'
    elif f'{split}_idx' in data:
        attr_name = f'{split}_idx'
    elif f'{split}_index' in data:
        attr_name = f'{split}_index'

    if attr_name is None:
        return None

    if isinstance(data, Data):
        return data[attr_name]
    if isinstance(data, HeteroData):
        input_nodes_dict = data.collect(attr_name)
        if len(input_nodes_dict) != 1:
            raise ValueError(f"Could not automatically determine the input "
                             f"nodes of {data} since there exist multiple "
                             f"node types with attribute '{attr_name}'")
        return list(input_nodes_dict.items())[0]
    return None


def kwargs_repr(**kwargs) -> str:
    return ', '.join([f'{k}={v}' for k, v in kwargs.items() if v is not None])
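

# Illustrative behaviour sketch (not part of the library source): for a
# homogeneous `Data` object, `infer_input_nodes` returns the first matching
# split attribute (`train_mask` here); for `HeteroData`, it returns a
# `(node_type, attribute)` pair instead, provided exactly one node type
# carries the attribute.
def _example_infer_input_nodes():
    data = Data(x=torch.randn(4, 16),
                train_mask=torch.tensor([True, True, False, False]))
    assert infer_input_nodes(data, split='train').sum() == 2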