Source code for torch_geometric.loader.temporal_dataloader

from typing import List

import torch

from torch_geometric.data import TemporalData

class TemporalDataLoader(torch.utils.data.DataLoader):
    r"""A data loader which merges successive events of a
    :class:`~torch_geometric.data.TemporalData` object to a mini-batch.

    The underlying :class:`torch.utils.data.DataLoader` iterates over the
    start offsets ``0, batch_size, 2 * batch_size, ...``. Each offset is
    turned into the contiguous slice ``data[start:start + batch_size]`` by
    :meth:`__call__`, which is installed as the ``collate_fn``. Events are
    always served in their stored order: a user-supplied ``shuffle`` or
    ``collate_fn`` argument is discarded.

    Args:
        data (TemporalData): The :obj:`~torch_geometric.data.TemporalData`
            from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`.
    """
    def __init__(self, data: TemporalData, batch_size: int = 1, **kwargs):
        # Ordering and collation are fixed by this class, so any values the
        # caller passed for them are dropped rather than forwarded.
        kwargs.pop('collate_fn', None)
        kwargs.pop('shuffle', None)

        self.data = data
        self.events_per_batch = batch_size

        num_events = len(data)
        if kwargs.get('drop_last', False) and num_events % batch_size != 0:
            # Stop early so that the trailing, incomplete batch is never
            # generated.
            stop = num_events - batch_size
        else:
            stop = num_events
        arange = range(0, stop, batch_size)

        # The base loader runs with a batch size of 1: every "sample" it
        # draws is one start offset, which `__call__` expands into a slice.
        super().__init__(arange, 1, shuffle=False, collate_fn=self, **kwargs)

    def __call__(self, arange: List[int]) -> TemporalData:
        """Return the slice of events that starts at offset ``arange[0]``.

        Args:
            arange (List[int]): The start offsets handed over by the base
                loader; only the first entry is used.

        Returns:
            ``self.data[start:start + self.events_per_batch]`` with
            ``start = arange[0]``.
        """
        start = arange[0]
        return self.data[start:start + self.events_per_batch]