Source code for torch_geometric.loader.link_neighbor_loader

import copy
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch_scatter import scatter_min

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.feature_store import FeatureStore
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.loader.base import DataLoaderIterator
from torch_geometric.loader.neighbor_loader import NeighborSampler
from torch_geometric.loader.utils import (
    filter_custom_store,
    filter_data,
    filter_hetero_data,
)
from torch_geometric.typing import InputEdges, NumNeighbors, OptTensor


class LinkNeighborSampler(NeighborSampler):
    def __init__(
        self,
        data,
        *args,
        neg_sampling_ratio: float = 0.0,
        num_src_nodes: Optional[int] = None,
        num_dst_nodes: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(data, *args, **kwargs)
        self.neg_sampling_ratio = neg_sampling_ratio

        # TODO if self.edge_time is not None and
        # `src` or `dst` nodes don't have time attribute
        # i.e node_time_dict[input_type[0/-1]] doesn't exist
        # set it to largest representabel torch.long.

        if self.data_cls == 'custom':
            _, graph_store = data
            edge_attrs = graph_store.get_all_edge_attrs()
            edge_types = [attr.edge_type for attr in edge_attrs]

            # Edge label index is part of the graph:
            if self.input_type in edge_types:
                self.num_src_nodes, self.num_dst_nodes = edge_attrs[
                    edge_types.index(self.input_type)].size

            elif num_src_nodes is None or num_dst_nodes is None:
                raise ValueError(
                    f"Use of the feature store and graph store abstraction "
                    f"with {self.__class__.__name__} requires the "
                    f"specification of source and destination nodes, since "
                    f"the edge label index {self.input_type} is not part "
                    f"of the specified graph. ")

            else:
                self.num_src_nodes = num_src_nodes
                self.num_dst_nodes = num_dst_nodes

        elif issubclass(self.data_cls, Data):
            self.num_src_nodes = self.num_dst_nodes = data.num_nodes
        else:  # issubclass(self.data_cls, HeteroData):
            self.num_src_nodes = data[self.input_type[0]].num_nodes
            self.num_dst_nodes = data[self.input_type[-1]].num_nodes

    def _add_negative_samples(self, edge_label_index, edge_label,
                              edge_label_time):
        """Add negative samples and their `edge_label` and `edge_time`
        if `self.neg_sampling_ration>0`"""
        num_pos_edges = edge_label_index.size(1)
        num_neg_edges = int(num_pos_edges * self.neg_sampling_ratio)

        if num_neg_edges == 0:
            return edge_label_index, edge_label, edge_label_time

        neg_row = torch.randint(self.num_src_nodes, (num_neg_edges, ))
        neg_col = torch.randint(self.num_dst_nodes, (num_neg_edges, ))
        neg_edge_label_index = torch.stack([neg_row, neg_col], dim=0)

        if edge_label_time is not None:
            perm = torch.randperm(num_pos_edges)
            edge_label_time = torch.cat(
                [edge_label_time, edge_label_time[perm[:num_neg_edges]]])

        edge_label_index = torch.cat([
            edge_label_index,
            neg_edge_label_index,
        ], dim=1)

        pos_edge_label = edge_label + 1
        neg_edge_label = edge_label.new_zeros((num_neg_edges, ) +
                                              edge_label.size()[1:])

        edge_label = torch.cat([pos_edge_label, neg_edge_label], dim=0)

        return edge_label_index, edge_label, edge_label_time

    def _get_batch_node_time_dict(self, edge_label_index, edge_label_time):
        """For edges in a batch replace `src` and `dst` node times by the min
        across all edge times."""
        def update_time_(node_time_dict, index, node_type, num_nodes):
            node_time_dict[node_type] = node_time_dict[node_type].clone()
            node_time, _ = scatter_min(edge_label_time, index, dim=0,
                                       dim_size=num_nodes)
            # NOTE We assume that node_time is always less than edge_time.
            index_unique = index.unique()
            node_time_dict[node_type][index_unique] = node_time[index_unique]

        node_time_dict = copy.copy(self.node_time_dict)
        update_time_(node_time_dict, edge_label_index[0], self.input_type[0],
                     self.num_src_nodes)
        update_time_(node_time_dict, edge_label_index[1], self.input_type[-1],
                     self.num_dst_nodes)
        return node_time_dict

    def __call__(self, query: List[Tuple[Tensor]]):
        query = [torch.stack(s, dim=0) for s in zip(*query)]
        edge_label_index = torch.stack(query[:2], dim=0)
        edge_label = query[2]
        edge_label_time = query[3] if len(query) == 4 else None

        out = self._add_negative_samples(edge_label_index, edge_label,
                                         edge_label_time)
        edge_label_index, edge_label, edge_label_time = out

        orig_edge_label_index = edge_label_index
        if (self.data_cls == 'custom'
                or issubclass(self.data_cls, HeteroData)):
            if self.input_type[0] != self.input_type[-1]:
                query_src = edge_label_index[0]
                query_src, reverse_src = query_src.unique(return_inverse=True)
                query_dst = edge_label_index[1]
                query_dst, reverse_dst = query_dst.unique(return_inverse=True)
                edge_label_index = torch.stack([reverse_src, reverse_dst], 0)
                query_node_dict = {
                    self.input_type[0]: query_src,
                    self.input_type[-1]: query_dst,
                }
            else:  # Merge both source and destination node indices:
                query_nodes = edge_label_index.view(-1)
                query_nodes, reverse = query_nodes.unique(return_inverse=True)
                edge_label_index = reverse.view(2, -1)
                query_node_dict = {self.input_type[0]: query_nodes}
            node_time_dict = self.node_time_dict
            if edge_label_time is not None:
                node_time_dict = self._get_batch_node_time_dict(
                    orig_edge_label_index, edge_label_time)
            out = self._hetero_sparse_neighbor_sample(
                query_node_dict, node_time_dict=node_time_dict) + (
                    edge_label_index, edge_label, edge_label_time)
            return out

        elif issubclass(self.data_cls, Data):
            query_nodes = edge_label_index.view(-1)
            query_nodes, reverse = query_nodes.unique(return_inverse=True)
            edge_label_index = reverse.view(2, -1)
            return self._sparse_neighbor_sample(query_nodes) + (
                edge_label_index, edge_label)


[docs]class LinkNeighborLoader(torch.utils.data.DataLoader): r"""A link-based data loader derived as an extension of the node-based :class:`torch_geometric.loader.NeighborLoader`. This loader allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible. More specifically, this loader first selects a sample of edges from the set of input edges :obj:`edge_label_index` (which may or not be edges in the original graph) and then constructs a subgraph from all the nodes present in this list by sampling :obj:`num_neighbors` neighbors in each iteration. .. code-block:: python from torch_geometric.datasets import Planetoid from torch_geometric.loader import LinkNeighborLoader data = Planetoid(path, name='Cora')[0] loader = LinkNeighborLoader( data, # Sample 30 neighbors for each node for 2 iterations num_neighbors=[30] * 2, # Use a batch size of 128 for sampling training nodes batch_size=128, edge_label_index=data.edge_index, ) sampled_data = next(iter(loader)) print(sampled_data) >>> Data(x=[1368, 1433], edge_index=[2, 3103], y=[1368], train_mask=[1368], val_mask=[1368], test_mask=[1368], edge_label_index=[2, 128]) It is additionally possible to provide edge labels for sampled edges, which are then added to the batch: .. code-block:: python loader = LinkNeighborLoader( data, num_neighbors=[30] * 2, batch_size=128, edge_label_index=data.edge_index, edge_label=torch.ones(data.edge_index.size(1)) ) sampled_data = next(iter(loader)) print(sampled_data) >>> Data(x=[1368, 1433], edge_index=[2, 3103], y=[1368], train_mask=[1368], val_mask=[1368], test_mask=[1368], edge_label_index=[2, 128], edge_label=[128]) The rest of the functionality mirrors that of :class:`~torch_geometric.loader.NeighborLoader`, including support for heterogenous graphs. .. note:: :obj:`neg_sampling_ratio` is currently implemented in an approximate way, *i.e.* negative edges may contain false negatives. Args: data (torch_geometric.data.Data or torch_geometric.data.HeteroData): The :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` graph object. num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]): The number of neighbors to sample for each node in each iteration. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. If an entry is set to :obj:`-1`, all neighbors will be included. edge_label_index (Tensor or EdgeType or Tuple[EdgeType, Tensor]): The edge indices for which neighbors are sampled to create mini-batches. If set to :obj:`None`, all edges will be considered. In heterogeneous graphs, needs to be passed as a tuple that holds the edge type and corresponding edge indices. (default: :obj:`None`) edge_label (Tensor, optional): The labels of edge indices for which neighbors are sampled. Must be the same length as the :obj:`edge_label_index`. If set to :obj:`None` its set to `torch.zeros(...)` internally. (default: :obj:`None`) edge_label_time (Tensor, optional): The timestamps for edge indices for which neighbors are sampled. Must be the same length as :obj:`edge_label_index`. If set, temporal sampling will be used such that neighbors are guaranteed to fulfill temporal constraints, *i.e.*, neighbors have an earlier timestamp than the ouput edge. The :obj:`time_attr` needs to be set for this to work. (default: :obj:`None`) num_src_nodes (optional, int): The number of source nodes in the edge label index. Inferred if not provided. num_dst_nodes (optional, int): The number of destination nodes in the edge label index. Inferred if not provided. replace (bool, optional): If set to :obj:`True`, will sample with replacement. (default: :obj:`False`) directed (bool, optional): If set to :obj:`False`, will include all edges between all sampled nodes. (default: :obj:`True`) neg_sampling_ratio (float, optional): The ratio of sampled negative edges to the number of positive edges. If :obj:`neg_sampling_ratio > 0` and in case :obj:`edge_label` does not exist, it will be automatically created and represents a binary classification task (:obj:`1` = edge, :obj:`0` = no edge). If :obj:`neg_sampling_ratio > 0` and in case :obj:`edge_label` exists, it has to be a categorical label from :obj:`0` to :obj:`num_classes - 1`. After negative sampling, label :obj:`0` represents negative edges, and labels :obj:`1` to :obj:`num_classes` represent the labels of positive edges. Note that returned labels are of type :obj:`torch.float` for binary classification (to facilitate the ease-of-use of :meth:`F.binary_cross_entropy`) and of type :obj:`torch.long` for multi-class classification (to facilitate the ease-of-use of :meth:`F.cross_entropy`). (default: :obj:`0.0`). time_attr (str, optional): The name of the attribute that denotes timestamps for the nodes in the graph. Only used if :obj:`edge_label_time` is set. (default: :obj:`None`) transform (Callable, optional): A function/transform that takes in a sampled mini-batch and returns a transformed version. (default: :obj:`None`) is_sorted (bool, optional): If set to :obj:`True`, assumes that :obj:`edge_index` is sorted by column. This avoids internal re-sorting of the data and can improve runtime and memory efficiency. (default: :obj:`False`) filter_per_worker (bool, optional): If set to :obj:`True`, will filter the returning data in each worker's subprocess rather than in the main process. Setting this to :obj:`True` is generally not recommended: (1) it may result in too many open file handles, (2) it may slown down data loading, (3) it requires operating on CPU tensors. (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. """ def __init__( self, data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], num_neighbors: NumNeighbors, edge_label_index: InputEdges = None, edge_label: OptTensor = None, edge_label_time: OptTensor = None, num_src_nodes: Optional[int] = None, num_dst_nodes: Optional[int] = None, replace: bool = False, directed: bool = True, neg_sampling_ratio: float = 0.0, time_attr: Optional[str] = None, transform: Callable = None, is_sorted: bool = False, filter_per_worker: bool = False, neighbor_sampler: Optional[LinkNeighborSampler] = None, **kwargs, ): # Remove for PyTorch Lightning: if 'dataset' in kwargs: del kwargs['dataset'] if 'collate_fn' in kwargs: del kwargs['collate_fn'] self.data = data # Save for PyTorch Lightning < 1.6: self.num_neighbors = num_neighbors self.edge_label = edge_label self.edge_label_index = edge_label_index self.edge_label_time = edge_label_time self.replace = replace self.directed = directed self.neg_sampling_ratio = neg_sampling_ratio self.transform = transform self.filter_per_worker = filter_per_worker self.neighbor_sampler = neighbor_sampler edge_type, edge_label_index = get_edge_label_index( data, edge_label_index) if edge_label is None: edge_label = torch.zeros(edge_label_index.size(1), device=edge_label_index.device) if (edge_label_time is None) != (time_attr is None): raise ValueError("`edge_label_time` is specified but `time_attr` " "is `None` or vice-versa. Both arguments need to " "be specified for temporal sampling") if neighbor_sampler is None: self.neighbor_sampler = LinkNeighborSampler( data, num_neighbors, replace, directed, input_type=edge_type, is_sorted=is_sorted, neg_sampling_ratio=self.neg_sampling_ratio, num_src_nodes=num_src_nodes, num_dst_nodes=num_dst_nodes, time_attr=time_attr, share_memory=kwargs.get('num_workers', 0) > 0, ) super().__init__( Dataset(edge_label_index, edge_label, edge_label_time), collate_fn=self.collate_fn, **kwargs) def filter_fn(self, out: Any) -> Union[Data, HeteroData]: if isinstance(self.data, Data): (node, row, col, edge, edge_label_index, edge_label) = out data = filter_data(self.data, node, row, col, edge, self.neighbor_sampler.perm) data.edge_label_index = edge_label_index data.edge_label = edge_label elif isinstance(self.data, HeteroData): (node_dict, row_dict, col_dict, edge_dict, edge_label_index, edge_label, edge_label_time) = out data = filter_hetero_data(self.data, node_dict, row_dict, col_dict, edge_dict, self.neighbor_sampler.perm_dict) edge_type = self.neighbor_sampler.input_type data[edge_type].edge_label_index = edge_label_index data[edge_type].edge_label = edge_label if edge_label_time is not None: data[edge_type].edge_label_time = edge_label_time else: (node_dict, row_dict, col_dict, edge_dict, edge_label_index, edge_label, edge_label_time) = out feature_store, graph_store = self.data data = filter_custom_store(feature_store, graph_store, node_dict, row_dict, col_dict, edge_dict) edge_type = self.neighbor_sampler.input_type data[edge_type].edge_label_index = edge_label_index data[edge_type].edge_label = edge_label if edge_label_time is None: data[edge_type].edge_label_time = edge_label_time return data if self.transform is None else self.transform(data) def collate_fn(self, index: Union[List[int], Tensor]) -> Any: out = self.neighbor_sampler(index) if self.filter_per_worker: # We execute `filter_fn` in the worker process. out = self.filter_fn(out) return out def _get_iterator(self) -> Iterator: if self.filter_per_worker: return super()._get_iterator() # We execute `filter_fn` in the main process. return DataLoaderIterator(super()._get_iterator(), self.filter_fn) def __repr__(self) -> str: return f'{self.__class__.__name__}()'
############################################################################### class Dataset(torch.utils.data.Dataset): def __init__(self, edge_label_index: Tensor, edge_label: Tensor, edge_label_time: OptTensor = None): self.edge_label_index = edge_label_index self.edge_label = edge_label self.edge_label_time = edge_label_time def __getitem__(self, idx: int) -> Tuple[int]: if self.edge_label_time is None: return ( self.edge_label_index[0, idx], self.edge_label_index[1, idx], self.edge_label[idx], ) else: return ( self.edge_label_index[0, idx], self.edge_label_index[1, idx], self.edge_label[idx], self.edge_label_time[idx], ) def __len__(self) -> int: return self.edge_label_index.size(1) def get_edge_label_index( data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], edge_label_index: InputEdges, ) -> Tuple[Optional[str], Tensor]: edge_type = None if isinstance(data, Data): if edge_label_index is None: return None, data.edge_index return None, edge_label_index assert edge_label_index is not None assert isinstance(edge_label_index, (list, tuple)) if isinstance(data, HeteroData): if isinstance(edge_label_index[0], str): edge_type = edge_label_index edge_type = data._to_canonical(*edge_type) assert edge_type in data.edge_types return edge_type, data[edge_type].edge_index assert len(edge_label_index) == 2 edge_type, edge_label_index = edge_label_index edge_type = data._to_canonical(*edge_type) if edge_label_index is None: return edge_type, data[edge_type].edge_index return edge_type, edge_label_index else: # Tuple[FeatureStore, GraphStore] _, graph_store = data # Need the edge index in COO for LinkNeighborLoader: def _get_edge_index(edge_type): row_dict, col_dict, _ = graph_store.coo([edge_type]) row = list(row_dict.values())[0] col = list(col_dict.values())[0] return torch.stack((row, col), dim=0) if isinstance(edge_label_index[0], str): edge_type = edge_label_index return edge_type, _get_edge_index(edge_type) assert len(edge_label_index) == 2 edge_type, edge_label_index = edge_label_index if edge_label_index is None: return edge_type, _get_edge_index(edge_type) return edge_type, edge_label_index