# Source code for torch_geometric.loader.temporal_dataloader

from typing import List

import torch

from torch_geometric.data import TemporalData

class TemporalDataLoader(torch.utils.data.DataLoader):
    r"""A data loader which merges successive events of a
    :class:`torch_geometric.data.TemporalData` to a mini-batch.

    Mini-batches are formed from consecutive events in their stored order
    (shuffling is always disabled).

    Args:
        data (TemporalData): The :obj:`~torch_geometric.data.TemporalData`
            from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`. Any :obj:`dataset`,
            :obj:`collate_fn` or :obj:`shuffle` entries are discarded.

    Raises:
        ValueError: If :obj:`batch_size` is smaller than 1.
    """
    def __init__(self, data: TemporalData, batch_size: int = 1, **kwargs):
        # Fail early with a clear message instead of silently producing an
        # empty loader (negative step) or an opaque `range()` /
        # modulo-by-zero error (zero step).
        if batch_size < 1:
            raise ValueError(f"'batch_size' must be a positive integer "
                             f"(got {batch_size})")

        # Remove for PyTorch Lightning:
        # NOTE(review): presumably Lightning re-creates the loader passing
        # these arguments back in; they are fixed by this class below.
        kwargs.pop('dataset', None)
        kwargs.pop('collate_fn', None)
        kwargs.pop('shuffle', None)

        self.data = data
        self.events_per_batch = batch_size

        # The parent loader iterates over batch *start offsets*. With
        # `drop_last` and a trailing partial batch, stop before the last
        # offset so only full batches are produced.
        if kwargs.get('drop_last', False) and len(data) % batch_size != 0:
            arange = range(0, len(data) - batch_size, batch_size)
        else:
            arange = range(0, len(data), batch_size)

        # The parent uses batch_size=1, so each collated list holds exactly
        # one start offset, which `__call__` expands into a slice of events.
        super().__init__(arange, 1, shuffle=False, collate_fn=self, **kwargs)

    def __call__(self, arange: List[int]) -> TemporalData:
        """Collate a single-element list of start offsets into a batch.

        Returns ``self.data`` sliced from ``arange[0]`` to
        ``arange[0] + self.events_per_batch`` (assumes
        :class:`TemporalData` slicing selects events — TODO confirm).
        """
        return self.data[arange[0]:arange[0] + self.events_per_batch]