# Source code for torch_geometric.loader.hgt_loader

from typing import Callable, Dict, List, Optional, Tuple, Union

from torch import Tensor

from torch_geometric.data import FeatureStore, GraphStore, HeteroData
from torch_geometric.loader import NodeLoader
from torch_geometric.sampler import HGTSampler
from torch_geometric.typing import NodeType

class HGTLoader(NodeLoader):
    r"""The Heterogeneous Graph Sampler from the `"Heterogeneous Graph
    Transformer" <https://arxiv.org/abs/2003.01332>`_ paper.
    This loader allows for mini-batch training of GNNs on large-scale graphs
    where full-batch training is not feasible.

    :class:`~torch_geometric.loader.HGTLoader` tries to (1) keep a similar
    number of nodes and edges for each type and (2) keep the sampled sub-graph
    dense to minimize the information loss and reduce the sample variance.

    Methodically, :class:`~torch_geometric.loader.HGTLoader` keeps track of a
    node budget for each node type, which is then used to determine the
    sampling probability of a node.
    In particular, the probability of sampling a node is determined by the
    number of connections to already sampled nodes and their node degrees.
    With this, :class:`~torch_geometric.loader.HGTLoader` will sample a fixed
    amount of neighbors for each node type in each iteration, as given by the
    :obj:`num_samples` argument.

    Sampled nodes are sorted based on the order in which they were sampled.
    In particular, the first :obj:`batch_size` nodes represent the set of
    original mini-batch nodes.

    .. note::

        For an example of using :class:`~torch_geometric.loader.HGTLoader`,
        see the scripts in `examples/hetero/
        <https://github.com/pyg-team/pytorch_geometric/tree/master/examples/hetero>`_.

    .. code-block:: python

        from torch_geometric.loader import HGTLoader
        from torch_geometric.datasets import OGB_MAG

        hetero_data = OGB_MAG(path)[0]

        loader = HGTLoader(
            hetero_data,
            # Sample 512 nodes per type and per iteration for 4 iterations
            num_samples={key: [512] * 4 for key in hetero_data.node_types},
            # Use a batch size of 128 for sampling training nodes of type paper
            batch_size=128,
            input_nodes=('paper', hetero_data['paper'].train_mask),
        )

        sampled_hetero_data = next(iter(loader))
        print(sampled_hetero_data.batch_size)
        >>> 128

    Args:
        data (HeteroData or Tuple[FeatureStore, GraphStore]): A
            :class:`~torch_geometric.data.HeteroData` object, or a
            (:class:`~torch_geometric.data.FeatureStore`,
            :class:`~torch_geometric.data.GraphStore`) tuple.
        num_samples (List[int] or Dict[str, List[int]]): The number of nodes
            to sample in each iteration and for each node type.
            If given as a list, will sample the same amount of nodes for each
            node type.
        input_nodes (str or Tuple[str, torch.Tensor]): The indices of nodes
            for which neighbors are sampled to create mini-batches.
            Passed either as a node type, or as a tuple that holds the node
            type and corresponding node indices.
            Node indices need to be either given as a
            :obj:`torch.LongTensor` or :obj:`torch.BoolTensor`.
            If node indices are set to :obj:`None`, all nodes of this
            specific type will be considered.
        is_sorted (bool, optional): If set to :obj:`True`, assumes that
            :obj:`edge_index` is sorted by column. This avoids internal
            re-sorting of the data and can improve runtime and memory
            efficiency. (default: :obj:`False`)
        transform (callable, optional): A function/transform that takes in a
            sampled mini-batch and returns a transformed version.
            (default: :obj:`None`)
        transform_sampler_output (callable, optional): A function/transform
            that takes in a :class:`torch_geometric.sampler.SamplerOutput` and
            returns a transformed version. (default: :obj:`None`)
        filter_per_worker (bool, optional): If set to :obj:`True`, will filter
            the returned data in each worker's subprocess.
            If set to :obj:`False`, will filter the returned data in the main
            process.
            If set to :obj:`None`, will automatically infer the decision based
            on whether data partially lives on the GPU
            (:obj:`filter_per_worker=True`) or entirely on the CPU
            (:obj:`filter_per_worker=False`).
            There exist different trade-offs for setting this option.
            Specifically, setting this option to :obj:`True` for in-memory
            datasets will move all features to shared memory, which may result
            in too many open file handles. (default: :obj:`None`)
        **kwargs (optional): Additional arguments forwarded to
            :class:`~torch_geometric.loader.NodeLoader`, such as
            :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or
            :obj:`num_workers`.
    """
    def __init__(
        self,
        data: Union[HeteroData, Tuple[FeatureStore, GraphStore]],
        num_samples: Union[List[int], Dict[NodeType, List[int]]],
        input_nodes: Union[NodeType, Tuple[NodeType, Optional[Tensor]]],
        is_sorted: bool = False,
        transform: Optional[Callable] = None,
        transform_sampler_output: Optional[Callable] = None,
        filter_per_worker: Optional[bool] = None,
        **kwargs,
    ):
        # Ask the sampler for shared memory only when worker subprocesses
        # will be spawned (``num_workers > 0``); presumably this lets workers
        # read the sampler's graph data without per-process copies.
        # ``kwargs.get`` (not ``pop``) keeps ``num_workers`` in ``kwargs`` so
        # it still reaches ``NodeLoader`` below.
        hgt_sampler = HGTSampler(
            data,
            num_samples=num_samples,
            is_sorted=is_sorted,
            share_memory=kwargs.get('num_workers', 0) > 0,
        )

        # All batching/iteration is delegated to ``NodeLoader``; this class
        # only supplies the HGT-specific node sampler.
        super().__init__(
            data=data,
            node_sampler=hgt_sampler,
            input_nodes=input_nodes,
            transform=transform,
            transform_sampler_output=transform_sampler_output,
            filter_per_worker=filter_per_worker,
            **kwargs,
        )