Source code for torch_geometric.loader.shadow

import copy
from typing import Optional

import torch
from torch import Tensor
from torch_sparse import SparseTensor

from torch_geometric.data import Batch, Data


[docs]class ShaDowKHopSampler(torch.utils.data.DataLoader): r"""The ShaDow :math:`k`-hop sampler from the `"Decoupling the Depth and Scope of Graph Neural Networks" <https://arxiv.org/abs/2201.07858>`_ paper. Given a graph in a :obj:`data` object, the sampler will create shallow, localized subgraphs. A deep GNN on this local graph then smooths the informative local signals. .. note:: For an example of using :class:`ShaDowKHopSampler`, see `examples/shadow.py <https://github.com/pyg-team/ pytorch_geometric/blob/master/examples/shadow.py>`_. Args: data (torch_geometric.data.Data): The graph data object. depth (int): The depth/number of hops of the localized subgraph. num_neighbors (int): The number of neighbors to sample for each node in each hop. node_idx (LongTensor or BoolTensor, optional): The nodes that should be considered for creating mini-batches. If set to :obj:`None`, all nodes will be considered. replace (bool, optional): If set to :obj:`True`, will sample neighbors with replacement. (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size` or :obj:`num_workers`. """ def __init__(self, data: Data, depth: int, num_neighbors: int, node_idx: Optional[Tensor] = None, replace: bool = False, **kwargs): self.data = copy.copy(data) self.depth = depth self.num_neighbors = num_neighbors self.replace = replace if data.edge_index is not None: self.is_sparse_tensor = False row, col = data.edge_index.cpu() self.adj_t = SparseTensor( row=row, col=col, value=torch.arange(col.size(0)), sparse_sizes=(data.num_nodes, data.num_nodes)).t() else: self.is_sparse_tensor = True self.adj_t = data.adj_t.cpu() if node_idx is None: node_idx = torch.arange(self.adj_t.sparse_size(0)) elif node_idx.dtype == torch.bool: node_idx = node_idx.nonzero(as_tuple=False).view(-1) self.node_idx = node_idx super().__init__(node_idx.tolist(), collate_fn=self.__collate__, **kwargs) def __collate__(self, n_id): n_id = torch.tensor(n_id) rowptr, col, value = self.adj_t.csr() out = torch.ops.torch_sparse.ego_k_hop_sample_adj( rowptr, col, n_id, self.depth, self.num_neighbors, self.replace) rowptr, col, n_id, e_id, ptr, root_n_id = out adj_t = SparseTensor(rowptr=rowptr, col=col, value=value[e_id] if value is not None else None, sparse_sizes=(n_id.numel(), n_id.numel()), is_sorted=True) batch = Batch(batch=torch.ops.torch_sparse.ptr2ind(ptr, n_id.numel()), ptr=ptr) batch.root_n_id = root_n_id if self.is_sparse_tensor: batch.adj_t = adj_t else: row, col, e_id = adj_t.t().coo() batch.edge_index = torch.stack([row, col], dim=0) for k, v in self.data: if k in ['edge_index', 'adj_t', 'num_nodes']: continue if k == 'y' and v.size(0) == self.data.num_nodes: batch[k] = v[n_id][root_n_id] elif isinstance(v, Tensor) and v.size(0) == self.data.num_nodes: batch[k] = v[n_id] elif isinstance(v, Tensor) and v.size(0) == self.data.num_edges: batch[k] = v[e_id] else: batch[k] = v return batch