from typing import List
import torch
from torch_geometric.data import TemporalData
[docs]class TemporalDataLoader(torch.utils.data.DataLoader):
r"""A data loader which merges succesive events of a
:class:`torch_geometric.data.TemporalData` to a mini-batch.
Args:
data (TemporalData): The :obj:`~torch_geometric.data.TemporalData`
from which to load the data.
batch_size (int, optional): How many samples per batch to load.
(default: :obj:`1`)
**kwargs (optional): Additional arguments of
:class:`torch.utils.data.DataLoader`.
"""
def __init__(self, data: TemporalData, batch_size: int = 1, **kwargs):
if 'collate_fn' in kwargs:
del kwargs['collate_fn']
if 'shuffle' in kwargs:
del kwargs['shuffle']
self.data = data
self.events_per_batch = batch_size
if kwargs.get('drop_last', False) and len(data) % batch_size != 0:
arange = range(0, len(data) - batch_size, batch_size)
else:
arange = range(0, len(data), batch_size)
super().__init__(arange, 1, shuffle=False, collate_fn=self, **kwargs)
def __call__(self, arange: List[int]) -> TemporalData:
return self.data[arange[0]:arange[0] + self.events_per_batch]