# Source code for the GraphSAINT sampler classes.

import copy
import os.path as osp
from typing import Optional

import torch
from tqdm import tqdm
from torch_sparse import SparseTensor

class GraphSAINTSampler(torch.utils.data.DataLoader):
    r"""The GraphSAINT sampler base class from the `"GraphSAINT: Graph
    Sampling Based Inductive Learning Method"
    <https://arxiv.org/abs/1907.04931>`_ paper.
    Given a graph in a :obj:`data` object, this class samples nodes and
    constructs subgraphs that can be processed in a mini-batch fashion.
    Normalization coefficients for each mini-batch are given via
    :obj:`node_norm` and :obj:`edge_norm` data attributes.

    .. note::

        See :class:`GraphSAINTNodeSampler`, :class:`GraphSAINTEdgeSampler`
        and :class:`GraphSAINTRandomWalkSampler` for currently supported
        samplers.
        For an example of using GraphSAINT sampling, see the ``examples/``
        directory.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The approximate number of samples per batch.
        num_steps (int, optional): The number of iterations per epoch.
            (default: :obj:`1`)
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`0`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        log (bool, optional): If set to :obj:`False`, will not log any
            pre-processing progress. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`
            or :obj:`num_workers`.
    """

    # NOTE(review): this block was rebuilt from a whitespace-collapsed page
    # in which several tokens were lost.  The base class, ``self.data``,
    # ``torch.save`` and the ``storage`` accessors below are reconstructions;
    # confirm them against the upstream module before relying on them.

    def __init__(self, data, batch_size: int, num_steps: int = 1,
                 sample_coverage: int = 0, save_dir: Optional[str] = None,
                 log: bool = True, **kwargs):
        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.num_steps = num_steps
        # The sampler's own batch size; the DataLoader-level ``batch_size``
        # is fixed to 1 below because every item is already a full subgraph.
        self.__batch_size__ = batch_size
        self.sample_coverage = sample_coverage
        self.log = log

        self.N = N = data.num_nodes
        self.E = data.num_edges

        # Transposed adjacency whose values are the original edge ids, so
        # that a sampled subgraph can be mapped back to edge attributes.
        self.adj_t = SparseTensor(
            row=data.edge_index[0], col=data.edge_index[1],
            value=torch.arange(self.E, device=data.edge_index.device),
            sparse_sizes=(N, N)).t()

        # Shallow copy: connectivity now lives in ``adj_t``.
        self.data = copy.copy(data)
        self.data.edge_index = None

        # The loader iterates over itself: ``__len__``/``__getitem__`` below
        # make this object its own dataset.
        super(GraphSAINTSampler, self).__init__(
            self, batch_size=1, collate_fn=self.__collate__, **kwargs)

        if self.sample_coverage > 0:
            path = osp.join(save_dir or '', self.__filename__)
            if save_dir is not None and osp.exists(path):  # pragma: no cover
                self.node_norm, self.edge_norm = torch.load(path)
            else:
                self.node_norm, self.edge_norm = self.__compute_norm__()
                if save_dir is not None:  # pragma: no cover
                    torch.save((self.node_norm, self.edge_norm), path)

    @property
    def __filename__(self):
        """File name used to cache the normalization statistics."""
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __len__(self):
        """Number of sampled subgraphs per epoch."""
        return self.num_steps

    def __sample_nodes__(self, batch_size):
        """Return a tensor of sampled node indices; set by subclasses."""
        raise NotImplementedError

    def __getitem__(self, idx):
        """Sample one subgraph.

        Returns:
            A tuple ``(node_idx, edge_idx, adj_t)`` holding the unique
            sampled node ids, the ids of the induced edges and the induced
            (transposed) adjacency.
        """
        node_idx = self.__sample_nodes__(self.__batch_size__).unique()
        adj_t, edge_idx = self.adj_t.saint_subgraph(node_idx)
        # NOTE(review): the indexed tensor was lost in extraction; assumed to
        # be the stored edge ids, which maps storage positions back to
        # original edge ids -- confirm.
        edge_idx = self.adj_t.storage.value()[edge_idx]
        return node_idx, edge_idx, adj_t

    def __collate__(self, data_list):
        """Build the mini-batch data object from a single sampled subgraph."""
        assert len(data_list) == 1
        node_idx, edge_idx, adj_t = data_list[0]

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        # ``adj_t`` is transposed, hence the swapped unpacking order.  Its
        # values are the original edge ids of the induced edges.
        col, row, edge_idx = adj_t.coo()
        data.edge_index = torch.stack([row, col], dim=0)

        for key, item in self.data:
            # Attributes that are not tensors have no ``size`` and are
            # copied through unchanged instead of raising.
            if torch.is_tensor(item) and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif torch.is_tensor(item) and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        if self.sample_coverage > 0:
            data.node_norm = self.node_norm[node_idx]
            data.edge_norm = self.edge_norm[edge_idx]

        return data

    def __compute_norm__(self):
        """Estimate node and edge normalization coefficients by sampling.

        Subgraphs are drawn until at least ``N * sample_coverage`` nodes
        have been sampled in total.

        Returns:
            A tuple ``(node_norm, edge_norm)`` of float tensors with
            ``N`` and ``E`` entries, respectively.
        """
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        # A plain loader over ``self`` yields lists of raw samples
        # (identity collate) and honours ``num_workers``.
        loader = torch.utils.data.DataLoader(
            self, batch_size=200, collate_fn=lambda x: x,
            num_workers=self.num_workers)

        if self.log:  # pragma: no cover
            pbar = tqdm(total=self.N * self.sample_coverage)
            pbar.set_description('Compute GraphSAINT normalization')

        num_samples = total_sampled_nodes = 0
        while total_sampled_nodes < self.N * self.sample_coverage:
            for data in loader:
                for node_idx, edge_idx, _ in data:
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    total_sampled_nodes += node_idx.size(0)

                    if self.log:  # pragma: no cover
                        pbar.update(node_idx.size(0))
                # Count the subgraphs actually drawn; a fixed increment of
                # 200 is only right when ``num_steps`` is a multiple of the
                # loader batch size.
                num_samples += len(data)

        if self.log:  # pragma: no cover
            pbar.close()

        row, _, edge_idx = self.adj_t.coo()
        # ``row`` is in storage order while ``edge_count`` is indexed by
        # edge id, so scatter the per-edge node counts into edge-id order
        # before dividing.
        count_per_edge = torch.empty_like(edge_count).scatter_(
            0, edge_idx, node_count[row])
        edge_norm = (count_per_edge / edge_count).clamp_(0, 1e4)
        edge_norm[torch.isnan(edge_norm)] = 0.1

        # Avoid division by zero for nodes that were never sampled.
        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm
class GraphSAINTNodeSampler(GraphSAINTSampler):
    r"""The GraphSAINT node sampler class (see
    :class:`GraphSAINTSampler`).
    """

    def __sample_nodes__(self, batch_size):
        """Sample nodes by drawing random edges and taking one endpoint.

        Args:
            batch_size (int): Number of rows of edge indices to draw.

        Returns:
            A tensor of node indices with one entry per drawn edge.
        """
        # ``self.batch_size`` is the DataLoader-level batch size, not the
        # sampler's ``batch_size`` argument.
        edge_sample = torch.randint(0, self.E, (batch_size, self.batch_size),
                                    dtype=torch.long)

        # NOTE(review): the indexed tensor was lost in extraction; assumed to
        # be the row indices of the stored adjacency -- confirm.
        return self.adj_t.storage.row()[edge_sample]
class GraphSAINTEdgeSampler(GraphSAINTSampler):
    r"""The GraphSAINT edge sampler class (see
    :class:`GraphSAINTSampler`).
    """

    def __sample_nodes__(self, batch_size):
        """Sample edges with degree-dependent weights; return endpoints.

        Args:
            batch_size (int): Number of independent edge draws (rows).

        Returns:
            A tensor of node indices in which the source and target nodes
            of the sampled edges are concatenated along the last dimension.
        """
        row, col, _ = self.adj_t.coo()

        # NOTE(review): both degree expressions were lost in extraction.
        # They are assumed to be the per-node column / row counts of the
        # stored adjacency; which count belongs to ``deg_in`` and which to
        # ``deg_out`` could not be determined -- confirm.
        deg_in = 1. / self.adj_t.storage.colcount()
        deg_out = 1. / self.adj_t.storage.rowcount()
        prob = (1. / deg_in[row]) + (1. / deg_out[col])

        # Parallel multinomial sampling (without replacement): every edge
        # gets the key ``log(U) / prob`` and the largest keys are kept.
        # (The original comment carried a reference link that was lost.)
        rand = torch.rand(batch_size, self.E).log() / (prob + 1e-10)
        edge_sample = rand.topk(self.batch_size, dim=-1).indices

        # ``adj_t`` is transposed: its columns are sources, rows targets.
        source_node_sample = col[edge_sample]
        target_node_sample = row[edge_sample]

        return torch.cat([source_node_sample, target_node_sample], -1)
class GraphSAINTRandomWalkSampler(GraphSAINTSampler):
    r"""The GraphSAINT random walk sampler class (see
    :class:`GraphSAINTSampler`).

    Args:
        walk_length (int): The length of each random walk.
    """

    def __init__(self, data, batch_size: int, walk_length: int,
                 num_steps: int = 1, sample_coverage: int = 0,
                 save_dir: Optional[str] = None, log: bool = True, **kwargs):
        # Must be set before the base constructor runs: that constructor may
        # already sample subgraphs (and read ``__filename__``) in order to
        # compute normalization statistics.
        self.walk_length = walk_length
        super(GraphSAINTRandomWalkSampler,
              self).__init__(data, batch_size, num_steps, sample_coverage,
                             save_dir, log, **kwargs)

    @property
    def __filename__(self):
        """Cache file name; includes the walk length besides the coverage."""
        return (f'{self.__class__.__name__.lower()}_{self.walk_length}_'
                f'{self.sample_coverage}.pt')

    def __sample_nodes__(self, batch_size):
        """Run random walks from random roots and return all visited nodes.

        Args:
            batch_size (int): Number of random-walk start nodes.

        Returns:
            A flat tensor with the node indices visited by all walks.
        """
        start = torch.randint(0, self.N, (batch_size, ), dtype=torch.long)
        node_idx = self.adj_t.random_walk(start.flatten(), self.walk_length)
        return node_idx.view(-1)