from typing import Callable, Dict, List, Optional, Tuple, Union

import torch

from torch_geometric.distributed import (
    DistContext,
    DistLoader,
    DistNeighborSampler,
    LocalFeatureStore,
    LocalGraphStore,
)
from torch_geometric.loader import NodeLoader
from torch_geometric.sampler.base import SubgraphType
from torch_geometric.typing import EdgeType, InputNodes, OptTensor


class DistNeighborLoader(NodeLoader, DistLoader):
r"""A distributed loader that performs sampling from nodes.
Args:
data (tuple): A (:class:`~torch_geometric.data.FeatureStore`,
:class:`~torch_geometric.data.GraphStore`) data object.
num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]):
The number of neighbors to sample for each node in each iteration.
If an entry is set to :obj:`-1`, all neighbors will be included.
In heterogeneous graphs, may also take in a dictionary denoting
the amount of neighbors to sample for each individual edge type.
master_addr (str): RPC address for distributed loader communication,
*i.e.* the IP address of the master node.
master_port (Union[int, str]): Open port for RPC communication with
the master node.
current_ctx (DistContext): Distributed context information of the
current process.
concurrency (int, optional): RPC concurrency used for defining the
maximum size of the asynchronous processing queue.
(default: :obj:`1`)
All other arguments follow the interface of
:class:`torch_geometric.loader.NeighborLoader`.
"""
    def __init__(
        self,
        data: Tuple[LocalFeatureStore, LocalGraphStore],
        num_neighbors: Union[List[int], Dict[EdgeType, List[int]]],
        master_addr: str,
        master_port: Union[int, str],
        current_ctx: DistContext,
        input_nodes: InputNodes = None,
        input_time: OptTensor = None,
        dist_sampler: Optional[DistNeighborSampler] = None,
        replace: bool = False,
        subgraph_type: Union[SubgraphType, str] = "directional",
        disjoint: bool = False,
        temporal_strategy: str = "uniform",
        time_attr: Optional[str] = None,
        transform: Optional[Callable] = None,
        concurrency: int = 1,
        num_rpc_threads: int = 16,
        filter_per_worker: Optional[bool] = False,
        async_sampling: bool = True,
        device: Optional[torch.device] = None,
        **kwargs,
    ):
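        # Validate the partitioned data stores and RPC settings up front,
        # before any RPC initialization takes place: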
        assert isinstance(data[0], LocalFeatureStore)
        assert isinstance(data[1], LocalGraphStore)
        assert concurrency >= 1, "RPC concurrency must be at least 1"

        if input_time is not None and time_attr is None:
            raise ValueError("Received conflicting 'input_time' and "
                             "'time_attr' arguments: 'input_time' is set "
                             "while 'time_attr' is not set.")
        channel = torch.multiprocessing.Queue() if async_sampling else None

        if dist_sampler is None:
            dist_sampler = DistNeighborSampler(
                data=data,
                current_ctx=current_ctx,
                num_neighbors=num_neighbors,
                replace=replace,
                subgraph_type=subgraph_type,
                disjoint=disjoint,
                temporal_strategy=temporal_strategy,
                time_attr=time_attr,
                device=device,
                channel=channel,
                concurrency=concurrency,
            )
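        # Set up the distributed (RPC) loader machinery first, then build
        # the standard `NodeLoader` on top of the distributed sampler.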
        DistLoader.__init__(
            self,
            channel=channel,
            master_addr=master_addr,
            master_port=master_port,
            current_ctx=current_ctx,
            dist_sampler=dist_sampler,
            num_rpc_threads=num_rpc_threads,
            **kwargs,
        )
        NodeLoader.__init__(
            self,
            data=data,
            node_sampler=dist_sampler,
            input_nodes=input_nodes,
            input_time=input_time,
            transform=transform,
            filter_per_worker=filter_per_worker,
            transform_sampler_output=self.channel_get if channel else None,
            worker_init_fn=self.worker_init_fn,
            **kwargs,
        )

    def __repr__(self) -> str:
        return DistLoader.__repr__(self)