Source code for torch_geometric.loader.link_neighbor_loader

from typing import Callable, Optional, Tuple, Union

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.feature_store import FeatureStore
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.loader.link_loader import LinkLoader
from torch_geometric.loader.utils import get_edge_label_index
from torch_geometric.sampler import NeighborSampler
from torch_geometric.typing import InputEdges, NumNeighbors, OptTensor

class LinkNeighborLoader(LinkLoader):
    r"""A link-based data loader derived as an extension of the node-based
    :class:`torch_geometric.loader.NeighborLoader`.
    This loader allows for mini-batch training of GNNs on large-scale graphs
    where full-batch training is not feasible.

    More specifically, this loader first selects a sample of edges from the
    set of input edges :obj:`edge_label_index` (which may or may not be edges
    in the original graph) and then constructs a subgraph from all the nodes
    present in this list by sampling :obj:`num_neighbors` neighbors in each
    iteration.

    .. code-block:: python

        from torch_geometric.datasets import Planetoid
        from torch_geometric.loader import LinkNeighborLoader

        data = Planetoid(path, name='Cora')[0]

        loader = LinkNeighborLoader(
            data,
            # Sample 30 neighbors for each node for 2 iterations
            num_neighbors=[30] * 2,
            # Use a batch size of 128 for sampling training nodes
            batch_size=128,
            edge_label_index=data.edge_index,
        )

        sampled_data = next(iter(loader))
        print(sampled_data)
        >>> Data(x=[1368, 1433], edge_index=[2, 3103], y=[1368],
                 train_mask=[1368], val_mask=[1368], test_mask=[1368],
                 edge_label_index=[2, 128])

    It is additionally possible to provide edge labels for sampled edges,
    which are then added to the batch:

    .. code-block:: python

        loader = LinkNeighborLoader(
            data,
            num_neighbors=[30] * 2,
            batch_size=128,
            edge_label_index=data.edge_index,
            edge_label=torch.ones(data.edge_index.size(1))
        )

        sampled_data = next(iter(loader))
        print(sampled_data)
        >>> Data(x=[1368, 1433], edge_index=[2, 3103], y=[1368],
                 train_mask=[1368], val_mask=[1368], test_mask=[1368],
                 edge_label_index=[2, 128], edge_label=[128])

    The rest of the functionality mirrors that of
    :class:`~torch_geometric.loader.NeighborLoader`, including support for
    heterogeneous graphs.

    .. note::
        :obj:`neg_sampling_ratio` is currently implemented in an approximate
        way, *i.e.* negative edges may contain false negatives.

    Args:
        data (torch_geometric.data.Data or torch_geometric.data.HeteroData or
            Tuple[FeatureStore, GraphStore]):
            The :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` graph object, or a
            :obj:`(FeatureStore, GraphStore)` tuple.
        num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]):
            The number of neighbors to sample for each node in each
            iteration.
            In heterogeneous graphs, may also take in a dictionary denoting
            the amount of neighbors to sample for each individual edge type.
            If an entry is set to :obj:`-1`, all neighbors will be included.
        edge_label_index (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
            The edge indices for which neighbors are sampled to create
            mini-batches.
            If set to :obj:`None`, all edges will be considered.
            In heterogeneous graphs, needs to be passed as a tuple that holds
            the edge type and corresponding edge indices.
            (default: :obj:`None`)
        edge_label (Tensor, optional): The labels of edge indices for
            which neighbors are sampled. Must be the same length as
            the :obj:`edge_label_index`. If set to :obj:`None`, it is set to
            `torch.zeros(...)` internally. (default: :obj:`None`)
        edge_label_time (Tensor, optional): The timestamps for edge indices
            for which neighbors are sampled. Must be the same length as
            :obj:`edge_label_index`. If set, temporal sampling will be
            used such that neighbors are guaranteed to fulfill temporal
            constraints, *i.e.*, neighbors have an earlier timestamp than
            the output edge. The :obj:`time_attr` needs to be set for this
            to work. (default: :obj:`None`)
        replace (bool, optional): If set to :obj:`True`, will sample with
            replacement. (default: :obj:`False`)
        directed (bool, optional): If set to :obj:`False`, will include all
            edges between all sampled nodes. (default: :obj:`True`)
        neg_sampling_ratio (float, optional): The ratio of sampled negative
            edges to the number of positive edges.
            If :obj:`neg_sampling_ratio > 0` and in case :obj:`edge_label`
            does not exist, it will be automatically created and represents a
            binary classification task (:obj:`1` = edge, :obj:`0` = no edge).
            If :obj:`neg_sampling_ratio > 0` and in case :obj:`edge_label`
            exists, it has to be a categorical label from :obj:`0` to
            :obj:`num_classes - 1`.
            After negative sampling, label :obj:`0` represents negative edges,
            and labels :obj:`1` to :obj:`num_classes` represent the labels of
            positive edges.
            Note that returned labels are of type :obj:`torch.float` for
            binary classification (to facilitate the ease-of-use of
            :meth:`F.binary_cross_entropy`) and of type
            :obj:`torch.long` for multi-class classification (to facilitate
            the ease-of-use of :meth:`F.cross_entropy`).
            (default: :obj:`0.0`).
        time_attr (str, optional): The name of the attribute that denotes
            timestamps for the nodes in the graph. Must be set if and only
            if :obj:`edge_label_time` is set. (default: :obj:`None`)
        transform (Callable, optional): A function/transform that takes in
            a sampled mini-batch and returns a transformed version.
            (default: :obj:`None`)
        is_sorted (bool, optional): If set to :obj:`True`, assumes that
            :obj:`edge_index` is sorted by column. This avoids internal
            re-sorting of the data and can improve runtime and memory
            efficiency. (default: :obj:`False`)
        filter_per_worker (bool, optional): If set to :obj:`True`, will
            filter the returning data in each worker's subprocess rather than
            in the main process.
            Setting this to :obj:`True` is generally not recommended:
            (1) it may result in too many open file handles,
            (2) it may slow down data loading,
            (3) it requires operating on CPU tensors.
            (default: :obj:`False`)
        neighbor_sampler (NeighborSampler, optional): An already constructed
            sampler to use for neighbor sampling. If set to :obj:`None`, a
            :class:`~torch_geometric.sampler.NeighborSampler` is built from
            :obj:`num_neighbors`, :obj:`replace`, :obj:`directed`,
            :obj:`time_attr` and :obj:`is_sorted`; if given, those arguments
            are not used to build a sampler. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.

    Raises:
        ValueError: If exactly one of :obj:`time_attr` and
            :obj:`edge_label_time` is set.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
        num_neighbors: NumNeighbors,
        edge_label_index: InputEdges = None,
        edge_label: OptTensor = None,
        edge_label_time: OptTensor = None,
        replace: bool = False,
        directed: bool = True,
        neg_sampling_ratio: float = 0.0,
        time_attr: Optional[str] = None,
        transform: Optional[Callable] = None,
        is_sorted: bool = False,
        filter_per_worker: bool = False,
        neighbor_sampler: Optional[NeighborSampler] = None,
        **kwargs,
    ):
        # Get input type:
        # TODO(manan): this computation is required twice, once here and once
        # in LinkLoader:
        edge_type, _ = get_edge_label_index(data, edge_label_index)

        # Temporal sampling needs both the node timestamp attribute and the
        # per-edge timestamps; reject the case where only one is supplied.
        has_time_attr = time_attr is not None
        has_edge_label_time = edge_label_time is not None
        if has_edge_label_time != has_time_attr:
            raise ValueError(
                f"Received conflicting 'time_attr' and 'edge_label_time' "
                f"arguments: 'time_attr' was "
                f"{'set' if has_time_attr else 'not set'} and "
                f"'edge_label_time' was "
                f"{'set' if has_edge_label_time else 'not set'}. Please "
                f"resolve these conflicting arguments.")

        # Only build a default sampler when the caller did not inject one.
        if neighbor_sampler is None:
            neighbor_sampler = NeighborSampler(
                data,
                num_neighbors=num_neighbors,
                replace=replace,
                directed=directed,
                input_type=edge_type,
                time_attr=time_attr,
                is_sorted=is_sorted,
                # Shared memory is requested only when `num_workers > 0`.
                share_memory=kwargs.get('num_workers', 0) > 0,
            )

        super().__init__(
            data=data,
            link_sampler=neighbor_sampler,
            edge_label_index=edge_label_index,
            edge_label=edge_label,
            edge_label_time=edge_label_time,
            neg_sampling_ratio=neg_sampling_ratio,
            transform=transform,
            filter_per_worker=filter_per_worker,
            **kwargs,
        )