from typing import List
import torch
from torch_geometric.data import TemporalData
class TemporalDataLoader(torch.utils.data.DataLoader):
    r"""A data loader which merges successive events of a
    :class:`torch_geometric.data.TemporalData` to a mini-batch.

    Events are always emitted in their stored (temporal) order: each
    mini-batch is the contiguous slice ``data[i:i + batch_size]``. Any
    ``shuffle``, ``dataset`` or ``collate_fn`` passed via :obj:`**kwargs` is
    ignored, since this loader sets those itself.

    Args:
        data (TemporalData): The :obj:`~torch_geometric.data.TemporalData`
            from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            Must be a positive integer. (default: :obj:`1`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`. If ``drop_last=True``, a
            trailing batch with fewer than :obj:`batch_size` events is
            skipped.

    Raises:
        ValueError: If :obj:`batch_size` is smaller than 1.
    """
    def __init__(self, data: TemporalData, batch_size: int = 1, **kwargs):
        if batch_size < 1:
            # Previously 0 crashed inside ``range`` and negative values
            # silently produced an empty loader; fail clearly instead.
            raise ValueError(f"'batch_size' must be a positive integer "
                             f"(got {batch_size})")

        # Remove for PyTorch Lightning: these arguments are set by this
        # class itself in the ``super().__init__`` call below.
        kwargs.pop('dataset', None)
        kwargs.pop('collate_fn', None)
        kwargs.pop('shuffle', None)

        self.data = data
        self.events_per_batch = batch_size

        # The underlying "dataset" is the list of start offsets of each
        # mini-batch; ``__call__`` turns an offset into an event slice.
        if kwargs.get('drop_last', False):
            # Only start a batch where a full ``batch_size`` window fits.
            arange = range(0, len(data) - batch_size + 1, batch_size)
        else:
            arange = range(0, len(data), batch_size)

        # The parent loader yields one start offset per step (batch size 1);
        # ``collate_fn=self`` routes that offset to ``__call__``.
        super().__init__(arange, 1, shuffle=False, collate_fn=self, **kwargs)

    def __call__(self, arange: List[int]) -> TemporalData:
        """Collate a single start offset into a contiguous event slice.

        Args:
            arange (List[int]): The list handed over by the parent
                :class:`~torch.utils.data.DataLoader`. It holds exactly one
                start offset, because the parent batch size is 1.

        Returns:
            The events ``data[start:start + batch_size]``. The final slice
            may be shorter when ``drop_last`` is not set.
        """
        start = arange[0]
        return self.data[start:start + self.events_per_batch]