Source code for torch_geometric.data.graph_saint

import copy
import os.path as osp

import torch
from tqdm import tqdm
from torch.multiprocessing import Queue, Process
from torch_sparse import SparseTensor


class GraphSAINTSampler(object):
    r"""The GraphSAINT sampler base class from the `"GraphSAINT: Graph
    Sampling Based Inductive Learning Method"
    <https://arxiv.org/abs/1907.04931>`_ paper.
    Given a graph in a :obj:`data` object, this class samples nodes and
    constructs subgraphs that can be processed in a mini-batch fashion.
    Normalization coefficients for each mini-batch are given via
    :obj:`node_norm` and :obj:`edge_norm` data attributes.

    .. note::

        See :class:`torch_geometric.data.GraphSAINTNodeSampler`,
        :class:`torch_geometric.data.GraphSAINTEdgeSampler` and
        :class:`torch_geometric.data.GraphSAINTRandomWalkSampler` for
        currently supported samplers.
        For an example of using GraphSAINT sampling, see
        `examples/graph_saint.py <https://github.com/rusty1s/pytorch_geometric/
        blob/master/examples/graph_saint.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        batch_size (int): The approximate number of samples per batch to load.
        num_steps (int, optional): The number of iterations.
            (default: :obj:`1`)
        sample_coverage (int): How many samples per node should be used to
            compute normalization statistics. (default: :obj:`50`)
        save_dir (string, optional): If set, will save normalization
            statistics to the :obj:`save_dir` directory for faster re-use.
            (default: :obj:`None`)
        num_workers (int, optional): How many subprocesses to use for data
            sampling. :obj:`0` means that the data will be sampled in the main
            process. (default: :obj:`0`)
        log (bool, optional): If set to :obj:`False`, will not log any
            progress. (default: :obj:`True`)
    """
    def __init__(self, data, batch_size, num_steps=1, sample_coverage=50,
                 save_dir=None, num_workers=0, log=True):
        assert data.edge_index is not None
        assert 'node_norm' not in data
        assert 'edge_norm' not in data

        self.N = N = data.num_nodes
        self.E = data.num_edges

        self.adj = SparseTensor(row=data.edge_index[0],
                                col=data.edge_index[1],
                                value=data.edge_attr, sparse_sizes=(N, N))

        # Work on a shallow copy so that clearing the edge attributes below
        # does not modify the caller's object; connectivity now lives in
        # `self.adj`.
        self.data = copy.copy(data)
        self.data.edge_index = None
        self.data.edge_attr = None

        self.batch_size = batch_size
        self.num_steps = num_steps
        self.sample_coverage = sample_coverage
        self.num_workers = num_workers
        self.log = log
        self.__count__ = 0

        # The sample workers have to be running before `__compute_norm__` is
        # called, because it reads from the sample queue when
        # `num_workers > 0`.
        if self.num_workers > 0:
            self.__sample_queue__ = Queue()
            self.__sample_workers__ = []
            for _ in range(self.num_workers):
                worker = Process(target=self.__put_sample__,
                                 args=(self.__sample_queue__, ))
                worker.daemon = True
                worker.start()
                self.__sample_workers__.append(worker)

        # Re-use cached normalization statistics when available, otherwise
        # compute them (and cache them if a directory was given).
        path = osp.join(save_dir or '', self.__filename__)
        if save_dir is not None and osp.exists(path):  # pragma: no cover
            self.node_norm, self.edge_norm = torch.load(path)
        else:
            self.node_norm, self.edge_norm = self.__compute_norm__()
            if save_dir is not None:  # pragma: no cover
                torch.save((self.node_norm, self.edge_norm), path)

        # The data workers are started only after the normalization tensors
        # exist -- presumably so that the worker processes see
        # `self.node_norm` / `self.edge_norm` (TODO confirm).
        if self.num_workers > 0:
            self.__data_queue__ = Queue()
            self.__data_workers__ = []
            for _ in range(self.num_workers):
                worker = Process(target=self.__put_data__,
                                 args=(self.__data_queue__, ))
                worker.daemon = True
                worker.start()
                self.__data_workers__.append(worker)

    @property
    def __filename__(self):
        """File name under which the normalization statistics are cached."""
        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'

    def __sample_nodes__(self, num_examples):
        """Return :obj:`num_examples` node index tensors, one per subgraph.

        Has to be implemented by subclasses.
        """
        raise NotImplementedError

    def __sample__(self, num_examples):
        """Sample :obj:`num_examples` subgraphs.

        Returns a list of :obj:`(node_idx, edge_idx, adj)` tuples, where
        :obj:`node_idx` holds the de-duplicated sampled nodes and
        :obj:`edge_idx` / :obj:`adj` are what
        :obj:`self.adj.saint_subgraph(node_idx)` returns for them.
        """
        node_samples = self.__sample_nodes__(num_examples)

        samples = []
        for node_idx in node_samples:
            node_idx = node_idx.unique()
            adj, edge_idx = self.adj.saint_subgraph(node_idx)
            samples.append((node_idx, edge_idx, adj))
        return samples

    def __compute_norm__(self):
        """Estimate the node and edge normalization coefficients.

        Subgraphs are sampled until the total number of sampled nodes reaches
        :obj:`self.N * self.sample_coverage`, counting how often each node and
        each edge appears.

        Returns a :obj:`(node_norm, edge_norm)` tuple of float tensors with
        :obj:`self.N` and :obj:`self.E` entries, respectively.
        """
        # Number of subgraphs consumed per loop iteration.
        samples_per_round = 200

        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        if self.log:  # pragma: no cover
            pbar = tqdm(total=self.N * self.sample_coverage)
            pbar.set_description('Compute GraphSAINT normalization')

        num_samples = total_sampled_nodes = 0
        while total_sampled_nodes < self.N * self.sample_coverage:
            if self.num_workers > 0:
                # Pull pre-computed samples from the background workers.
                samples = (self.__sample_queue__.get()
                           for _ in range(samples_per_round))
            else:
                samples = self.__sample__(samples_per_round)

            num_sampled_nodes = 0
            for node_idx, edge_idx, _ in samples:
                node_count[node_idx] += 1
                edge_count[edge_idx] += 1
                num_sampled_nodes += node_idx.size(0)

            total_sampled_nodes += num_sampled_nodes
            num_samples += samples_per_round

            if self.log:  # pragma: no cover
                pbar.update(num_sampled_nodes)

        if self.log:  # pragma: no cover
            pbar.close()

        row, col, _ = self.adj.coo()

        edge_norm = (node_count[col] / edge_count).clamp_(0, 1e4)
        # 0 / 0 (neither the edge nor its `col` node was ever sampled) gives
        # NaN; replace it by a small constant.
        edge_norm[torch.isnan(edge_norm)] = 0.1

        # Avoid a division by zero for nodes that were never sampled.
        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm

    def __get_data_from_sample__(self, sample):
        """Build the mini-batch data object for one sampled subgraph.

        Args:
            sample (tuple): A :obj:`(node_idx, edge_idx, adj)` tuple as
                produced by :meth:`__sample__`.

        Returns an object of the same class as :obj:`self.data` holding the
        subgraph connectivity, the sliced attributes of :obj:`self.data` and
        the :obj:`node_norm` / :obj:`edge_norm` coefficients of the sampled
        nodes and edges.
        """
        node_idx, edge_idx, adj = sample

        data = self.data.__class__()
        data.num_nodes = node_idx.size(0)
        row, col, value = adj.coo()
        data.edge_index = torch.stack([row, col], dim=0)
        data.edge_attr = value

        # NOTE(review): `edge_idx` comes from `self.adj.saint_subgraph`, i.e.
        # it presumably follows the SparseTensor's internal edge ordering --
        # confirm this matches the ordering of edge-level attributes stored
        # in `self.data`.
        for key, item in self.data:
            # Only tensors with at least one dimension can be sliced along
            # their first dimension; everything else (strings, numbers,
            # 0-dim tensors, ...) is copied as-is instead of raising on
            # `item.size(0)`.
            sliceable = isinstance(item, torch.Tensor) and item.dim() > 0
            if sliceable and item.size(0) == self.N:
                data[key] = item[node_idx]
            elif sliceable and item.size(0) == self.E:
                data[key] = item[edge_idx]
            else:
                data[key] = item

        data.node_norm = self.node_norm[node_idx]
        data.edge_norm = self.edge_norm[edge_idx]

        return data

    def __put_sample__(self, queue):
        """Worker loop: keep pushing freshly sampled subgraphs to `queue`."""
        while True:
            sample = self.__sample__(1)[0]
            queue.put(sample)

    def __put_data__(self, queue):
        """Worker loop: turn queued samples into data objects on `queue`."""
        while True:
            sample = self.__sample_queue__.get()
            data = self.__get_data_from_sample__(sample)
            queue.put(data)

    def __next__(self):
        """Return the next mini-batch, or stop after `len(self)` batches."""
        if self.__count__ >= len(self):
            raise StopIteration

        self.__count__ += 1
        if self.num_workers > 0:
            return self.__data_queue__.get()

        sample = self.__sample__(1)[0]
        return self.__get_data_from_sample__(sample)

    def __len__(self):
        """Number of mini-batches per iteration (:obj:`num_steps`)."""
        return self.num_steps

    def __iter__(self):
        """Reset the batch counter and return the sampler itself."""
        self.__count__ = 0
        return self
class GraphSAINTNodeSampler(GraphSAINTSampler):
    r"""The GraphSAINT node sampler class (see
    :class:`torch_geometric.data.GraphSAINTSampler`).

    Nodes are obtained by drawing edge indices uniformly at random (with
    replacement) and taking the row node of each drawn edge.

    Args:
        batch_size (int): The number of nodes to sample per batch.
    """
    def __sample_nodes__(self, num_examples):
        """Return a tuple of :obj:`num_examples` node index tensors, each
        holding :obj:`batch_size` (not necessarily distinct) nodes."""
        shape = (num_examples, self.batch_size)
        picked_edges = torch.randint(0, self.E, shape, dtype=torch.long)

        rows = self.adj.storage.row()
        return rows[picked_edges].unbind(dim=0)
class GraphSAINTEdgeSampler(GraphSAINTSampler):
    r"""The GraphSAINT edge sampler class (see
    :class:`torch_geometric.data.GraphSAINTSampler`).

    Edge indices are drawn uniformly at random (with replacement) and both
    end points of every drawn edge are added to the node sample.

    Args:
        batch_size (int): The number of edges to sample per batch.
    """
    def __sample_nodes__(self, num_examples):
        """Return a tuple of :obj:`num_examples` node index tensors, each
        holding the :obj:`2 * batch_size` end points of the drawn edges."""
        # This function corresponds to the `Edge2` sampler in the official
        # code repository that weights all edges as equally important.
        # This is the default configuration in the GraphSAINT implementation.
        shape = (num_examples, self.batch_size)
        picked_edges = torch.randint(0, self.E, shape, dtype=torch.long)

        storage = self.adj.storage
        sources = storage.row()[picked_edges]
        targets = storage.col()[picked_edges]

        # Source nodes first, then target nodes, along the last dimension.
        end_points = torch.cat([sources, targets], -1)
        return end_points.unbind(dim=0)
class GraphSAINTRandomWalkSampler(GraphSAINTSampler):
    r"""The GraphSAINT random walk sampler class (see
    :class:`torch_geometric.data.GraphSAINTSampler`).

    Args:
        batch_size (int): The number of walks to sample per batch.
        walk_length (int): The length of each random walk.
    """
    def __init__(self, data, batch_size, walk_length, num_steps=1,
                 sample_coverage=50, save_dir=None, num_workers=0, log=True):
        # `walk_length` has to be set first: the base constructor already
        # uses `__filename__` and may call `__sample_nodes__`.
        self.walk_length = walk_length
        super().__init__(data, batch_size, num_steps, sample_coverage,
                         save_dir, num_workers, log)

    @property
    def __filename__(self):
        """Cache file name; additionally encodes the walk length."""
        cls_name = self.__class__.__name__.lower()
        return (f'{cls_name}_{self.walk_length}_'
                f'{self.sample_coverage}.pt')

    def __sample_nodes__(self, num_examples):
        """Return a tuple of :obj:`num_examples` node index tensors, each
        holding the nodes of :obj:`batch_size` random walks."""
        # One uniformly drawn start node per walk.
        seeds = torch.randint(0, self.N, (num_examples, self.batch_size),
                              dtype=torch.long)
        walks = self.adj.random_walk(seeds.flatten(), self.walk_length)

        # Group the walks of each example into a single row.
        nodes_per_example = self.batch_size * (self.walk_length + 1)
        grouped = walks.view(num_examples, nodes_per_example)
        return grouped.unbind(dim=0)