from typing import Callable, Optional, Tuple, Union
from torch_geometric.data import Data, HeteroData
from torch_geometric.data.feature_store import FeatureStore
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.loader.link_loader import LinkLoader
from torch_geometric.loader.utils import get_edge_label_index
from torch_geometric.sampler import NeighborSampler
from torch_geometric.sampler.base import NegativeSamplingConfig
from torch_geometric.typing import InputEdges, NumNeighbors, OptTensor
class LinkNeighborLoader(LinkLoader):
    r"""A link-based data loader derived as an extension of the node-based
    :class:`torch_geometric.loader.NeighborLoader`.
    This loader allows for mini-batch training of GNNs on large-scale graphs
    where full-batch training is not feasible.

    More specifically, this loader first selects a sample of edges from the
    set of input edges :obj:`edge_label_index` (which may or may not be edges
    in the original graph) and then constructs a subgraph from all the nodes
    present in this list by sampling :obj:`num_neighbors` neighbors in each
    iteration.

    .. code-block:: python

        from torch_geometric.datasets import Planetoid
        from torch_geometric.loader import LinkNeighborLoader

        data = Planetoid(path, name='Cora')[0]

        loader = LinkNeighborLoader(
            data,
            # Sample 30 neighbors for each node for 2 iterations
            num_neighbors=[30] * 2,
            # Use a batch size of 128 for sampling training nodes
            batch_size=128,
            edge_label_index=data.edge_index,
        )

        sampled_data = next(iter(loader))
        print(sampled_data)
        >>> Data(x=[1368, 1433], edge_index=[2, 3103], y=[1368],
                 train_mask=[1368], val_mask=[1368], test_mask=[1368],
                 edge_label_index=[2, 128])

    It is additionally possible to provide edge labels for sampled edges, which
    are then added to the batch:

    .. code-block:: python

        loader = LinkNeighborLoader(
            data,
            num_neighbors=[30] * 2,
            batch_size=128,
            edge_label_index=data.edge_index,
            edge_label=torch.ones(data.edge_index.size(1))
        )

        sampled_data = next(iter(loader))
        print(sampled_data)
        >>> Data(x=[1368, 1433], edge_index=[2, 3103], y=[1368],
                 train_mask=[1368], val_mask=[1368], test_mask=[1368],
                 edge_label_index=[2, 128], edge_label=[128])

    The rest of the functionality mirrors that of
    :class:`~torch_geometric.loader.NeighborLoader`, including support for
    heterogeneous graphs.

    .. note::
        Negative sampling is currently implemented in an approximate
        way, *i.e.* negative edges may contain false negatives.

    Args:
        data (torch_geometric.data.Data or torch_geometric.data.HeteroData or
            Tuple[FeatureStore, GraphStore]):
            The :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` graph object, or a
            tuple holding a feature store and a graph store.
        num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]): The
            number of neighbors to sample for each node in each iteration.
            In heterogeneous graphs, may also take in a dictionary denoting
            the amount of neighbors to sample for each individual edge type.
            If an entry is set to :obj:`-1`, all neighbors will be included.
        edge_label_index (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
            The edge indices for which neighbors are sampled to create
            mini-batches.
            If set to :obj:`None`, all edges will be considered.
            In heterogeneous graphs, needs to be passed as a tuple that holds
            the edge type and corresponding edge indices.
            (default: :obj:`None`)
        edge_label (Tensor, optional): The labels of edge indices for
            which neighbors are sampled. Must be the same length as
            the :obj:`edge_label_index`. If set to :obj:`None`, it is set to
            :obj:`torch.zeros(...)` internally. (default: :obj:`None`)
        edge_label_time (Tensor, optional): The timestamps for edge indices
            for which neighbors are sampled. Must be the same length as
            :obj:`edge_label_index`. If set, temporal sampling will be
            used such that neighbors are guaranteed to fulfill temporal
            constraints, *i.e.*, neighbors have an earlier timestamp than
            the output edge. The :obj:`time_attr` needs to be set for this
            to work. (default: :obj:`None`)
        replace (bool, optional): If set to :obj:`True`, will sample with
            replacement. (default: :obj:`False`)
        directed (bool, optional): If set to :obj:`False`, will include all
            edges between all sampled nodes. (default: :obj:`True`)
        disjoint (bool, optional): If set to :obj:`True`, each seed node will
            create its own disjoint subgraph.
            If set to :obj:`True`, mini-batch outputs will have a :obj:`batch`
            vector holding the mapping of nodes to their respective subgraph.
            Will get automatically set to :obj:`True` in case of temporal
            sampling. (default: :obj:`False`)
        temporal_strategy (string, optional): The sampling strategy when using
            temporal sampling (:obj:`"uniform"`, :obj:`"last"`).
            If set to :obj:`"uniform"`, will sample uniformly across neighbors
            that fulfill temporal constraints.
            If set to :obj:`"last"`, will sample the last :obj:`num_neighbors`
            that fulfill temporal constraints.
            (default: :obj:`"uniform"`)
        neg_sampling (NegativeSamplingConfig, optional): The negative sampling
            strategy. Can be either :obj:`"binary"` or :obj:`"triplet"`, and
            can be further customized by an additional :obj:`amount` argument
            to control the ratio of sampled negatives to positive edges.
            If set to :obj:`"binary"`, will randomly sample negative links
            from the graph.
            In case :obj:`edge_label` does not exist, it will be automatically
            created and represents a binary classification task (:obj:`0` =
            negative edge, :obj:`1` = positive edge).
            In case :obj:`edge_label` does exist, it has to be a categorical
            label from :obj:`0` to :obj:`num_classes - 1`.
            After negative sampling, label :obj:`0` represents negative edges,
            and labels :obj:`1` to :obj:`num_classes` represent the labels of
            positive edges.
            Note that returned labels are of type :obj:`torch.float` for binary
            classification (to facilitate the ease-of-use of
            :meth:`F.binary_cross_entropy`) and of type
            :obj:`torch.long` for multi-class classification (to facilitate the
            ease-of-use of :meth:`F.cross_entropy`).
            If set to :obj:`"triplet"`, will randomly sample negative
            destination nodes for each positive source node.
            Samples can be accessed via the attributes :obj:`src_index`,
            :obj:`dst_pos_index` and :obj:`dst_neg_index` in the respective
            node types of the returned mini-batch.
            :obj:`edge_label` needs to be :obj:`None` for
            :obj:`"triplet"`-based negative sampling.
            If set to :obj:`None`, no negative sampling strategy is applied.
            (default: :obj:`None`)
        neg_sampling_ratio (int or float, optional): The ratio of sampled
            negative edges to the number of positive edges.
            Deprecated in favor of the :obj:`neg_sampling` argument.
            (default: :obj:`None`)
        time_attr (str, optional): The name of the attribute that denotes
            timestamps for the nodes in the graph. Must be set if and only if
            :obj:`edge_label_time` is set. (default: :obj:`None`)
        transform (Callable, optional): A function/transform that takes in
            a sampled mini-batch and returns a transformed version.
            (default: :obj:`None`)
        is_sorted (bool, optional): If set to :obj:`True`, assumes that
            :obj:`edge_index` is sorted by column.
            If :obj:`time_attr` is set, additionally requires that rows are
            sorted according to time within individual neighborhoods.
            This avoids internal re-sorting of the data and can improve
            runtime and memory efficiency. (default: :obj:`False`)
        filter_per_worker (bool, optional): If set to :obj:`True`, will filter
            the returning data in each worker's subprocess rather than in the
            main process.
            Setting this to :obj:`True` is generally not recommended:
            (1) it may result in too many open file handles,
            (2) it may slow down data loading,
            (3) it requires operating on CPU tensors.
            (default: :obj:`False`)
        neighbor_sampler (NeighborSampler, optional): A pre-built
            :class:`~torch_geometric.sampler.NeighborSampler` to use for
            sampling. If set to :obj:`None`, a sampler is constructed from
            :obj:`num_neighbors`, :obj:`replace`, :obj:`directed`,
            :obj:`disjoint`, :obj:`temporal_strategy`, :obj:`time_attr` and
            :obj:`is_sorted`. If given, it is used as-is and those arguments
            are not used to build a sampler. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.

    Raises:
        ValueError: If exactly one of :obj:`edge_label_time` and
            :obj:`time_attr` is set.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
        num_neighbors: NumNeighbors,
        edge_label_index: InputEdges = None,
        edge_label: OptTensor = None,
        edge_label_time: OptTensor = None,
        replace: bool = False,
        directed: bool = True,
        disjoint: bool = False,
        temporal_strategy: str = 'uniform',
        neg_sampling: Optional[NegativeSamplingConfig] = None,
        neg_sampling_ratio: Optional[Union[int, float]] = None,
        time_attr: Optional[str] = None,
        transform: Optional[Callable] = None,
        is_sorted: bool = False,
        filter_per_worker: bool = False,
        neighbor_sampler: Optional[NeighborSampler] = None,
        **kwargs,
    ):
        # Only the edge type is needed here; the resolved edge indices are
        # recomputed by the parent loader.
        # TODO(manan): Avoid duplicated computation (here and in NodeLoader):
        edge_type, _ = get_edge_label_index(data, edge_label_index)

        # Temporal sampling needs both the per-edge timestamps and the name of
        # the timestamp attribute; reject the case where only one is given.
        if (edge_label_time is not None) != (time_attr is not None):
            raise ValueError(
                f"Received conflicting 'edge_label_time' and 'time_attr' "
                f"arguments: 'edge_label_time' is "
                f"{'set' if edge_label_time is not None else 'not set'} "
                f"while 'time_attr' is "
                f"{'set' if time_attr is not None else 'not set'}.")

        # Build a default sampler unless the caller supplied one.
        if neighbor_sampler is None:
            neighbor_sampler = NeighborSampler(
                data,
                num_neighbors=num_neighbors,
                replace=replace,
                directed=directed,
                disjoint=disjoint,
                temporal_strategy=temporal_strategy,
                input_type=edge_type,
                time_attr=time_attr,
                is_sorted=is_sorted,
                # Shared memory is requested only when worker subprocesses
                # are used (`num_workers > 0`).
                share_memory=kwargs.get('num_workers', 0) > 0,
            )

        super().__init__(
            data=data,
            link_sampler=neighbor_sampler,
            edge_label_index=edge_label_index,
            edge_label=edge_label,
            edge_label_time=edge_label_time,
            neg_sampling=neg_sampling,
            neg_sampling_ratio=neg_sampling_ratio,
            transform=transform,
            filter_per_worker=filter_per_worker,
            **kwargs,
        )