from collections.abc import Mapping, Sequence
from typing import List, Optional, Union

import torch.utils.data
from torch.utils.data.dataloader import default_collate
from torch_geometric.data import Data, HeteroData, Dataset, Batch
class Collater:
    """Callable that merges a list of samples into a single mini-batch.

    The merge strategy is picked from the type of the first sample in the
    batch; containers are collated recursively, element by element.

    Args:
        follow_batch: Passed through to ``Batch.from_data_list``.
        exclude_keys: Passed through to ``Batch.from_data_list``.
    """

    def __init__(self, follow_batch, exclude_keys):
        self.exclude_keys = exclude_keys
        self.follow_batch = follow_batch

    def collate(self, batch):
        """Merge ``batch`` according to the type of its first element.

        * ``Data`` / ``HeteroData`` -> ``Batch.from_data_list``
        * ``torch.Tensor`` -> ``default_collate``
        * ``float`` -> tensor with ``dtype=torch.float``
        * ``int`` -> ``torch.tensor(batch)``
        * ``str`` -> the batch itself, unchanged
        * mapping -> ``dict`` with each key collated across samples
        * namedtuple -> same namedtuple type, each field collated
        * other sequence -> ``list`` with each position collated

        Raises:
            TypeError: If the first element matches none of the above.
        """
        first = batch[0]

        if isinstance(first, (Data, HeteroData)):
            return Batch.from_data_list(
                batch,
                self.follow_batch,
                self.exclude_keys,
            )

        if isinstance(first, torch.Tensor):
            return default_collate(batch)

        if isinstance(first, float):
            return torch.tensor(batch, dtype=torch.float)

        if isinstance(first, int):
            return torch.tensor(batch)

        if isinstance(first, str):
            return batch

        if isinstance(first, Mapping):
            # Keys are taken from the first sample only.
            merged = {}
            for key in first:
                merged[key] = self.collate([sample[key] for sample in batch])
            return merged

        if isinstance(first, tuple) and hasattr(first, '_fields'):
            # Namedtuple: rebuild the same type from the collated fields.
            fields = (self.collate(column) for column in zip(*batch))
            return type(first)(*fields)

        # Strings have already returned above, so any remaining Sequence
        # is a list/tuple-like container. ``zip`` stops at the shortest
        # sample.
        if isinstance(first, Sequence):
            return [self.collate(column) for column in zip(*batch)]

        raise TypeError(
            'DataLoader found invalid type: {}'.format(type(first)))

    def __call__(self, batch):
        """Make the instance usable directly as a ``collate_fn``."""
        return self.collate(batch)
class DataLoader(torch.utils.data.DataLoader):
    r"""A data loader which merges data objects from a
    :class:`torch_geometric.data.Dataset` to a mini-batch.
    Data objects can be either of type :class:`~torch_geometric.data.Data` or
    :class:`~torch_geometric.data.HeteroData`.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`False`)
        follow_batch (List[str], optional): Creates assignment batch
            vectors for each key in the list. :obj:`None` is treated as an
            empty list. (default: :obj:`None`)
        exclude_keys (List[str], optional): Will exclude each key in the
            list. :obj:`None` is treated as an empty list.
            (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`. A :obj:`collate_fn` entry,
            if given, is discarded.
    """
    def __init__(
        self,
        dataset: Union[Dataset, List[Data], List[HeteroData]],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        # Collation is always done by ``Collater``; drop any caller-supplied
        # ``collate_fn`` so it cannot clash with the keyword passed below.
        kwargs.pop("collate_fn", None)

        # Use a fresh list per instance rather than a shared mutable default,
        # so mutating one loader's attribute cannot leak into another.
        if follow_batch is None:
            follow_batch = []
        if exclude_keys is None:
            exclude_keys = []

        # Save for PyTorch Lightning...
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=Collater(follow_batch, exclude_keys),
            **kwargs,
        )