from typing import Union, List, Dict, Callable, Optional, Any
from torch_geometric.typing import EdgeType, InputNodes
from collections.abc import Sequence
import torch
from torch import Tensor
from torch_geometric.data import Data, HeteroData
from torch_geometric.loader.base import BaseDataLoader
from torch_geometric.loader.utils import edge_type_to_str
from torch_geometric.loader.utils import to_csc, to_hetero_csc
from torch_geometric.loader.utils import filter_data, filter_hetero_data
# Number of neighbors to sample in each iteration: either a single list that
# is shared by all edge types, or a dictionary that holds a separate list for
# each individual edge type.
NumNeighbors = Union[List[int], Dict[EdgeType, List[int]]]
class NeighborSampler:
    r"""A callable that samples neighbors for a mini-batch of seed nodes.

    The graph is converted once (on the CPU) into the format expected by the
    :obj:`torch_sparse` sampling operators; calling the instance with a batch
    of node indices then runs the sampling operator on that batch.

    Args:
        data (torch_geometric.data.Data or torch_geometric.data.HeteroData):
            The graph object to sample from.
        num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]): The
            number of neighbors to sample in each iteration.
            For :class:`~torch_geometric.data.HeteroData`, a single list is
            replicated for every edge type.
        replace (bool, optional): Forwarded to the sampling operator.
            (default: :obj:`False`)
        directed (bool, optional): Forwarded to the sampling operator.
            (default: :obj:`True`)
        input_node_type (str, optional): The node type of the seed nodes.
            Required (must be a :obj:`str`) for
            :class:`~torch_geometric.data.HeteroData` and ignored otherwise.
            (default: :obj:`None`)

    Raises:
        TypeError: If :obj:`data` is neither a
            :class:`~torch_geometric.data.Data` nor a
            :class:`~torch_geometric.data.HeteroData` object.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData],
        num_neighbors: NumNeighbors,
        replace: bool = False,
        directed: bool = True,
        input_node_type: Optional[str] = None,
    ):
        # Only the class is stored (not the data itself); it is used in
        # `__call__` to pick the matching sampling operator.
        self.data_cls = data.__class__
        self.num_neighbors = num_neighbors
        self.replace = replace
        self.directed = directed

        if isinstance(data, Data):
            # Convert the graph data into a suitable format for sampling.
            self.colptr, self.row, self.perm = to_csc(data, device='cpu')
            assert isinstance(num_neighbors, (list, tuple))

        elif isinstance(data, HeteroData):
            # Convert the graph data into a suitable format for sampling.
            # NOTE: Since C++ cannot take dictionaries with tuples as key as
            # input, edge type triplets are converted into single strings.
            out = to_hetero_csc(data, device='cpu')
            self.colptr_dict, self.row_dict, self.perm_dict = out

            self.node_types, self.edge_types = data.metadata()

            # A plain list applies to every edge type alike.
            if isinstance(num_neighbors, (list, tuple)):
                num_neighbors = {key: num_neighbors for key in self.edge_types}
            assert isinstance(num_neighbors, dict)
            self.num_neighbors = {
                edge_type_to_str(key): value
                for key, value in num_neighbors.items()
            }

            # Edge types may specify lists of different lengths; the longest
            # one determines the number of sampling iterations.
            self.num_hops = max(len(v) for v in self.num_neighbors.values())

            assert isinstance(input_node_type, str)
            self.input_node_type = input_node_type

        else:
            raise TypeError(f'NeighborLoader found invalid type: {type(data)}')

    def __call__(self, indices: Union[List[int], Tensor]):
        r"""Samples neighbors for the seed nodes given by :obj:`indices`.

        Args:
            indices (List[int] or torch.Tensor): The seed node indices. Must
                be (or convert to) a tensor of dtype :obj:`torch.int64`.

        Returns:
            The tuple :obj:`(node, row, col, edge, batch_size)` for
            homogeneous graphs, or the tuple
            :obj:`(node_dict, row_dict, col_dict, edge_dict, batch_size)` for
            heterogeneous graphs, where :obj:`batch_size` is the number of
            seed nodes.
        """
        # Always bind `index`: tensors are used as-is, anything else is
        # converted. (Previously a tensor input left `index` undefined.)
        if isinstance(indices, Tensor):
            index = indices
        else:
            index = torch.tensor(indices)
        assert index.dtype == torch.int64

        if issubclass(self.data_cls, Data):
            sample_fn = torch.ops.torch_sparse.neighbor_sample
            node, row, col, edge = sample_fn(
                self.colptr,
                self.row,
                index,
                self.num_neighbors,
                self.replace,
                self.directed,
            )
            return node, row, col, edge, index.numel()

        elif issubclass(self.data_cls, HeteroData):
            sample_fn = torch.ops.torch_sparse.hetero_neighbor_sample
            node_dict, row_dict, col_dict, edge_dict = sample_fn(
                self.node_types,
                self.edge_types,
                self.colptr_dict,
                self.row_dict,
                {self.input_node_type: index},
                self.num_neighbors,
                self.num_hops,
                self.replace,
                self.directed,
            )
            return node_dict, row_dict, col_dict, edge_dict, index.numel()
class NeighborLoader(BaseDataLoader):
    r"""A data loader that performs neighbor sampling as introduced in the
    `"Inductive Representation Learning on Large Graphs"
    <https://arxiv.org/abs/1706.02216>`_ paper.
    This loader allows for mini-batch training of GNNs on large-scale graphs
    where full-batch training is not feasible.

    More specifically, :obj:`num_neighbors` denotes how many neighbors are
    sampled for each node in each iteration.
    :class:`~torch_geometric.loader.NeighborLoader` takes in this list of
    :obj:`num_neighbors` and iteratively samples :obj:`num_neighbors[i]` for
    each node involved in iteration :obj:`i - 1`.

    Sampled nodes are sorted based on the order in which they were sampled.
    In particular, the first :obj:`batch_size` nodes represent the set of
    original mini-batch nodes.

    .. code-block:: python

        from torch_geometric.datasets import Planetoid
        from torch_geometric.loader import NeighborLoader

        data = Planetoid(path, name='Cora')[0]

        loader = NeighborLoader(
            data,
            # Sample 30 neighbors for each node for 2 iterations
            num_neighbors=[30] * 2,
            # Use a batch size of 128 for sampling training nodes
            batch_size=128,
            input_nodes=data.train_mask,
        )

        sampled_data = next(iter(loader))
        print(sampled_data.batch_size)
        >>> 128

    By default, the data loader will only include the edges that were
    originally sampled (:obj:`directed = True`).
    This option should only be used in case the number of hops is equivalent to
    the number of GNN layers.
    In case the number of GNN layers is greater than the number of hops,
    consider setting :obj:`directed = False`, which will include all edges
    between all sampled nodes (but is slightly slower as a result).

    Furthermore, :class:`~torch_geometric.loader.NeighborLoader` works for both
    **homogeneous** graphs stored via :class:`~torch_geometric.data.Data` as
    well as **heterogeneous** graphs stored via
    :class:`~torch_geometric.data.HeteroData`.
    When operating in heterogeneous graphs, more fine-grained control over
    the amount of sampled neighbors of individual edge types is possible, but
    not necessary:

    .. code-block:: python

        from torch_geometric.datasets import OGB_MAG
        from torch_geometric.loader import NeighborLoader

        hetero_data = OGB_MAG(path)[0]

        loader = NeighborLoader(
            hetero_data,
            # Sample 30 neighbors for each node and edge type for 2 iterations
            num_neighbors={key: [30] * 2 for key in hetero_data.edge_types},
            # Use a batch size of 128 for sampling training nodes of type paper
            batch_size=128,
            input_nodes=('paper', hetero_data['paper'].train_mask),
        )

        sampled_hetero_data = next(iter(loader))
        print(sampled_hetero_data['paper'].batch_size)
        >>> 128

    .. note::

        For an example of using
        :class:`~torch_geometric.loader.NeighborLoader`, see
        `examples/hetero/to_hetero_mag.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/hetero/to_hetero_mag.py>`_.

    Args:
        data (torch_geometric.data.Data or torch_geometric.data.HeteroData):
            The :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` graph object.
        num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]): The
            number of neighbors to sample for each node in each iteration.
            In heterogeneous graphs, may also take in a dictionary denoting
            the amount of neighbors to sample for each individual edge type.
            If an entry is set to :obj:`-1`, all neighbors will be included.
        input_nodes (torch.Tensor or str or Tuple[str, torch.Tensor]): The
            indices of nodes for which neighbors are sampled to create
            mini-batches.
            Needs to be either given as a :obj:`torch.LongTensor` or
            :obj:`torch.BoolTensor`.
            If set to :obj:`None`, all nodes will be considered.
            In heterogeneous graphs, needs to be passed as a tuple that holds
            the node type and node indices. (default: :obj:`None`)
        replace (bool, optional): If set to :obj:`True`, will sample with
            replacement. (default: :obj:`False`)
        directed (bool, optional): If set to :obj:`False`, will include all
            edges between all sampled nodes. (default: :obj:`True`)
        transform (Callable, optional): A function/transform that takes in
            a sampled mini-batch and returns a transformed version.
            (default: :obj:`None`)
        neighbor_sampler (NeighborSampler, optional): An already constructed
            sampler to use as the collate function.
            If set to :obj:`None`, a new :class:`NeighborSampler` is built
            from :obj:`data`, :obj:`num_neighbors`, :obj:`replace`,
            :obj:`directed` and the node type given in :obj:`input_nodes`.
            (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData],
        num_neighbors: NumNeighbors,
        input_nodes: InputNodes = None,
        replace: bool = False,
        directed: bool = True,
        transform: Optional[Callable] = None,
        neighbor_sampler: Optional[NeighborSampler] = None,
        **kwargs,
    ):
        # `collate_fn` (and presumably `dataset`, as the first positional
        # argument) are supplied by this class in the `super().__init__` call
        # below, so any caller-provided values are discarded.
        kwargs.pop('dataset', None)
        kwargs.pop('collate_fn', None)

        # Save for PyTorch Lightning:
        self.data = data
        self.num_neighbors = num_neighbors
        self.input_nodes = input_nodes
        self.replace = replace
        self.directed = directed
        self.transform = transform
        self.neighbor_sampler = neighbor_sampler

        if neighbor_sampler is None:
            input_node_type = get_input_node_type(input_nodes)
            self.neighbor_sampler = NeighborSampler(data, num_neighbors,
                                                    replace, directed,
                                                    input_node_type)

        # The "dataset" that gets batched is the sequence of seed node
        # indices; the sampler turns each batch of indices into a subgraph.
        super().__init__(get_input_node_indices(self.data, input_nodes),
                         collate_fn=self.neighbor_sampler, **kwargs)

    def transform_fn(self, out: Any) -> Union[Data, HeteroData]:
        r"""Converts the output of the neighbor sampler into a mini-batch.

        Args:
            out (Any): The 5-tuple returned by :obj:`self.neighbor_sampler`.

        Returns:
            The filtered :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object with its
            :obj:`batch_size` attribute set, passed through
            :obj:`self.transform` if one was given.

        Raises:
            TypeError: If :obj:`self.data` is neither a
                :class:`~torch_geometric.data.Data` nor a
                :class:`~torch_geometric.data.HeteroData` object.
        """
        if isinstance(self.data, Data):
            node, row, col, edge, batch_size = out
            data = filter_data(self.data, node, row, col, edge,
                               self.neighbor_sampler.perm)
            data.batch_size = batch_size

        elif isinstance(self.data, HeteroData):
            node_dict, row_dict, col_dict, edge_dict, batch_size = out
            data = filter_hetero_data(self.data, node_dict, row_dict, col_dict,
                                      edge_dict,
                                      self.neighbor_sampler.perm_dict)
            # Only the seed node type carries a batch size.
            data[self.neighbor_sampler.input_node_type].batch_size = batch_size

        else:
            # Fail explicitly rather than with an unbound `data` below.
            raise TypeError(
                f'NeighborLoader found invalid type: {type(self.data)}')

        return data if self.transform is None else self.transform(data)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
###############################################################################
def get_input_node_type(input_nodes: InputNodes) -> Optional[str]:
    r"""Extracts the node type from an :obj:`input_nodes` specification.

    Args:
        input_nodes: Either a node type string, a list/tuple whose first
            entry is the node type string, or anything else.

    Returns:
        The node type if :obj:`input_nodes` is a string or a list/tuple
        starting with a string, and :obj:`None` for any other input.
    """
    if isinstance(input_nodes, (list, tuple)):
        # The node type is expected in the first slot of the pair.
        node_type = input_nodes[0]
        assert isinstance(node_type, str)
        return node_type

    return input_nodes if isinstance(input_nodes, str) else None
def get_input_node_indices(data: Union[Data, HeteroData],
                           input_nodes: InputNodes) -> Sequence:
    r"""Resolves :obj:`input_nodes` into a sequence of seed node indices.

    Args:
        data (torch_geometric.data.Data or torch_geometric.data.HeteroData):
            The graph object the indices refer to.
        input_nodes: For :class:`~torch_geometric.data.Data`, either
            :obj:`None` (all nodes) or the node indices/mask.
            For :class:`~torch_geometric.data.HeteroData`, either a node type
            string (all nodes of that type) or a :obj:`(node_type, indices)`
            pair, where :obj:`indices` may be :obj:`None`.

    Returns:
        A :obj:`range` over all nodes if no explicit indices were given,
        otherwise the indices as a sequence (tensors are converted to lists,
        and boolean masks to the positions of their non-zero entries).
    """
    if isinstance(data, Data) and input_nodes is None:
        return range(data.num_nodes)

    indices = input_nodes

    if isinstance(data, HeteroData):
        # A bare node type is shorthand for "all nodes of that type".
        if isinstance(indices, str):
            indices = (indices, None)
        assert isinstance(indices, (list, tuple))
        assert len(indices) == 2
        node_type, node_index = indices
        assert isinstance(node_type, str)
        if node_index is None:
            return range(data[node_type].num_nodes)
        indices = node_index

    if isinstance(indices, Tensor):
        # Boolean masks are turned into the indices of their `True` entries.
        if indices.dtype == torch.bool:
            indices = indices.nonzero(as_tuple=False).view(-1)
        indices = indices.tolist()

    assert isinstance(indices, Sequence)
    return indices