torch_geometric.distributed
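DistContext | Context information of the current process.
LocalFeatureStore | Implements the FeatureStore interface to act as a local feature store for distributed training.
LocalGraphStore | Implements the GraphStore interface to act as a local graph store for distributed training.
Partitioner | Partitions the graph and its features of a Data or HeteroData object.
DistNeighborSampler | An implementation of a distributed and asynchronous neighbor sampler used by DistNeighborLoader and DistLinkNeighborLoader.
DistLoader | A base class for creating distributed data loading routines.
DistNeighborLoader | A distributed loader that performs sampling from nodes.
DistLinkNeighborLoader | A distributed loader that performs sampling from edges.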
- class DistContext(rank: int, global_rank: int, world_size: int, global_world_size: int, group_name: str, role: DistRole = DistRole.WORKER)[source]
Context information of the current process.
- class LocalFeatureStore[source]
Implements the FeatureStore interface to act as a local feature store for distributed training.
- lookup_features(index: Tensor, is_node_feat: bool = True, input_type: Optional[Union[str, Tuple[str, str, str]]] = None) → Future [source]
Lookup of local/remote features.
- classmethod from_data(node_id: Tensor, x: Optional[Tensor] = None, y: Optional[Tensor] = None, edge_id: Optional[Tensor] = None, edge_attr: Optional[Tensor] = None) → LocalFeatureStore [source]
Creates a local feature store from homogeneous PyG tensors.
- Parameters:
node_id (torch.Tensor) – The global identifier for every local node.
x (torch.Tensor, optional) – The node features. (default: None)
y (torch.Tensor, optional) – The node labels. (default: None)
edge_id (torch.Tensor, optional) – The global identifier for every local edge. (default: None)
edge_attr (torch.Tensor, optional) – The edge features. (default: None)
- classmethod from_hetero_data(node_id_dict: Dict[str, Tensor], x_dict: Optional[Dict[str, Tensor]] = None, y_dict: Optional[Dict[str, Tensor]] = None, edge_id_dict: Optional[Dict[Tuple[str, str, str], Tensor]] = None, edge_attr_dict: Optional[Dict[Tuple[str, str, str], Tensor]] = None) → LocalFeatureStore [source]
Creates a local feature store from heterogeneous PyG tensors.
- Parameters:
node_id_dict (Dict[NodeType, torch.Tensor]) – The global identifier for every local node of every node type.
x_dict (Dict[NodeType, torch.Tensor], optional) – The node features of every node type. (default: None)
y_dict (Dict[NodeType, torch.Tensor], optional) – The node labels of every node type. (default: None)
edge_id_dict (Dict[EdgeType, torch.Tensor], optional) – The global identifier for every local edge of every edge type. (default: None)
edge_attr_dict (Dict[EdgeType, torch.Tensor], optional) – The edge features of every edge type. (default: None)
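To make the role of node_id concrete, here is a minimal sketch that uses plain Python lists in place of tensors. It assumes that a store keeps one feature row per local node and translates a global node ID into its local row before reading features. The helper names build_global_to_local and lookup_rows are hypothetical and are not part of torch_geometric.

```python
from typing import Dict, List, Sequence


def build_global_to_local(node_id: Sequence[int]) -> Dict[int, int]:
    """Map each global node ID to its row position in the local store."""
    return {global_id: row for row, global_id in enumerate(node_id)}


def lookup_rows(x: Sequence[List[float]], mapping: Dict[int, int],
                index: Sequence[int]) -> List[List[float]]:
    """Return the feature rows for the requested global node IDs."""
    return [x[mapping[global_id]] for global_id in index]


# Illustrative values only: three local nodes with global IDs 10, 42 and 7.
node_id = [10, 42, 7]
x = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]

# Assumption: features are stored in the same order as node_id.
mapping = build_global_to_local(node_id)
rows = lookup_rows(x, mapping, index=[7, 10])
print(rows)
```

The same idea carries over to edge_id and edge_attr, where the global edge identifier plays the role of the key.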
- class LocalGraphStore[source]
Implements the GraphStore interface to act as a local graph store for distributed training.
- get_partition_ids_from_nids(ids: Tensor, node_type: Optional[str] = None) → Tensor [source]
Returns the partition IDs of node IDs for a specific node type.
- get_partition_ids_from_eids(eids: Tensor, edge_type: Optional[Tuple[str, str, str]] = None)[source]
Returns the partition IDs of edge IDs for a specific edge type.
- classmethod from_data(edge_id: Tensor, edge_index: Tensor, num_nodes: int, is_sorted: bool = False) → LocalGraphStore [source]
Creates a local graph store from a homogeneous PyG graph.
- Parameters:
edge_id (torch.Tensor) – The global identifier for every local edge.
edge_index (torch.Tensor) – The local edge indices.
num_nodes (int) – The number of nodes in the local graph.
is_sorted (bool) – Whether edges are sorted by column/destination nodes (CSC format). (default: False)
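The is_sorted flag refers to edges ordered by destination node. As a rough illustration in plain Python, the sketch below sorts an edge list by column, builds the CSC column-pointer array, and applies the same permutation to the global edge IDs so they stay aligned. The helper name sort_edges_by_dst is hypothetical, and this is not the library's internal routine.

```python
from typing import List, Tuple


def sort_edges_by_dst(rows: List[int], cols: List[int],
                      num_nodes: int) -> Tuple[List[int], List[int], List[int], List[int]]:
    """Sort edges by destination node and build CSC column pointers."""
    # Stable sort of edge positions by destination (column) index.
    perm = sorted(range(len(cols)), key=lambda e: cols[e])
    sorted_rows = [rows[e] for e in perm]
    sorted_cols = [cols[e] for e in perm]

    # colptr[v]:colptr[v + 1] delimits the incoming edges of node v.
    colptr = [0] * (num_nodes + 1)
    for c in cols:
        colptr[c + 1] += 1
    for v in range(num_nodes):
        colptr[v + 1] += colptr[v]
    return sorted_rows, sorted_cols, colptr, perm


# Illustrative edge list: row = source node, col = destination node.
rows = [0, 1, 2, 0]
cols = [2, 0, 1, 1]
edge_id = [100, 101, 102, 103]  # made-up global edge identifiers

sorted_rows, sorted_cols, colptr, perm = sort_edges_by_dst(rows, cols, num_nodes=3)
sorted_edge_id = [edge_id[e] for e in perm]
print(sorted_rows, sorted_cols, colptr, sorted_edge_id)
```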
- classmethod from_hetero_data(edge_id_dict: Dict[Tuple[str, str, str], Tensor], edge_index_dict: Dict[Tuple[str, str, str], Tensor], num_nodes_dict: Dict[str, int], is_sorted: bool = False) → LocalGraphStore [source]
Creates a local graph store from a heterogeneous PyG graph.
- Parameters:
edge_id_dict (Dict[EdgeType, torch.Tensor]) – The global identifier for every local edge of every edge type.
edge_index_dict (Dict[EdgeType, torch.Tensor]) – The local edge indices of every edge type.
num_nodes_dict (Dict[str, int]) – The number of nodes for every node type.
is_sorted (bool) – Whether edges are sorted by column/destination nodes (CSC format). (default: False)
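get_partition_ids_from_nids and get_partition_ids_from_eids answer the question of which partition owns a given global ID. One simple way to picture this is a lookup table indexed by global ID, as in the sketch below. The table name node_partition_book, its layout, and both helper functions are illustrative assumptions rather than the library's data structures.

```python
from typing import Dict, List, Sequence


def partition_ids(book: Sequence[int], ids: Sequence[int]) -> List[int]:
    """Return the owning partition for each global ID."""
    return [book[i] for i in ids]


def group_by_partition(book: Sequence[int], ids: Sequence[int]) -> Dict[int, List[int]]:
    """Bucket global IDs by the partition that owns them."""
    groups: Dict[int, List[int]] = {}
    for i in ids:
        groups.setdefault(book[i], []).append(i)
    return groups


# Assumed layout: entry i holds the partition ID of global node i.
node_partition_book = [0, 0, 1, 1, 0, 1]

owners = partition_ids(node_partition_book, [4, 2, 5])
groups = group_by_partition(node_partition_book, [4, 2, 5])
print(owners, groups)
```

Grouping IDs by owner is useful because each bucket can then be served by a single request to the machine holding that partition.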
- class Partitioner(data: Union[Data, HeteroData], num_parts: int, root: str, recursive: bool = False)[source]
Partitions the graph and its features of a Data or HeteroData object.
Partitioned data output will be structured as shown below.
Homogeneous graphs:
root/
|-- META.json
|-- node_map.pt
|-- edge_map.pt
|-- part0/
    |-- graph.pt
    |-- node_feats.pt
    |-- edge_feats.pt
|-- part1/
    |-- graph.pt
    |-- node_feats.pt
    |-- edge_feats.pt
Heterogeneous graphs:
root/
|-- META.json
|-- node_map/
    |-- ntype1.pt
    |-- ntype2.pt
|-- edge_map/
    |-- etype1.pt
    |-- etype2.pt
|-- part0/
    |-- graph.pt
    |-- node_feats.pt
    |-- edge_feats.pt
|-- part1/
    |-- graph.pt
    |-- node_feats.pt
    |-- edge_feats.pt
- Parameters:
data (Data or HeteroData) – The data object.
num_parts (int) – The number of partitions.
recursive (bool, optional) – If set to True, will use multilevel recursive bisection instead of multilevel k-way partitioning. (default: False)
root (str) – Root directory where the partitioned dataset should be saved.
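As a quick sanity check after partitioning, it can help to enumerate the files that the homogeneous layout shown above implies. The sketch below does this with the standard library only. The function name expected_homo_layout is hypothetical, and the partN directory naming for more than two partitions is an assumption extrapolated from part0 and part1.

```python
from pathlib import PurePosixPath
from typing import List


def expected_homo_layout(root: str, num_parts: int) -> List[str]:
    """List the file paths implied by the homogeneous partition layout."""
    base = PurePosixPath(root)
    # Top-level metadata and ID maps.
    paths = [base / "META.json", base / "node_map.pt", base / "edge_map.pt"]
    # One directory per partition (naming pattern assumed from part0, part1).
    for part in range(num_parts):
        part_dir = base / f"part{part}"
        for name in ("graph.pt", "node_feats.pt", "edge_feats.pt"):
            paths.append(part_dir / name)
    return [str(p) for p in paths]


layout = expected_homo_layout("root", num_parts=2)
for path in layout:
    print(path)
```

Comparing such a list against the directory contents is one way to detect an incomplete partitioning run before starting distributed training.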
- class DistNeighborSampler(current_ctx: DistContext, data: Tuple[LocalFeatureStore, LocalGraphStore], num_neighbors: Union[NumNeighbors, List[int], Dict[Tuple[str, str, str], List[int]]], channel: Optional[Queue] = None, replace: bool = False, subgraph_type: Union[SubgraphType, str] = 'directional', disjoint: bool = False, temporal_strategy: str = 'uniform', time_attr: Optional[str] = None, concurrency: int = 1, device: Optional[device] = None, **kwargs)[source]
An implementation of a distributed and asynchronous neighbor sampler used by DistNeighborLoader and DistLinkNeighborLoader.
- async node_sample(inputs: Union[NodeSamplerInput, DistEdgeHeteroSamplerInput]) → Union[SamplerOutput, HeteroSamplerOutput] [source]
Performs layer-by-layer distributed sampling from a NodeSamplerInput or DistEdgeHeteroSamplerInput and returns the output of the sampling procedure.
Note
In distributed training, the results must be synchronized between machines after each layer.
- async edge_sample(inputs: EdgeSamplerInput, sample_fn: Callable, num_nodes: Union[int, Dict[str, int]], disjoint: bool, node_time: Optional[Union[Tensor, Dict[str, Tensor]]] = None, neg_sampling: Optional[NegativeSampling] = None) → Union[SamplerOutput, HeteroSamplerOutput] [source]
Performs layer-by-layer distributed sampling from an EdgeSamplerInput and returns the output of the sampling procedure.
Note
In distributed training, the results must be synchronized between machines after each layer.
- async sample_one_hop(srcs: Tensor, one_hop_num: int, seed_time: Optional[Tensor] = None, src_batch: Optional[Tensor] = None, edge_type: Optional[Tuple[str, str, str]] = None) → SamplerOutput [source]
Samples one-hop neighbors for a set of seed nodes in srcs. If seed nodes are located on a local partition, evaluates the sampling function on the current machine. If seed nodes are from a remote partition, sends a request to a remote machine that contains this partition.
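The local versus remote routing described for sample_one_hop can be sketched with asyncio and toy data. Everything below is an illustration under stated assumptions: PARTITION_OF, NEIGHBORS, and all function names are hypothetical, the remote call is simulated in-process rather than sent over RPC, and neighbors are taken in order instead of being drawn at random so that the result is deterministic.

```python
import asyncio
from typing import Dict, List

# Toy data (assumed): global node ID -> owning partition, and adjacency lists.
PARTITION_OF = {0: 0, 1: 0, 2: 1, 3: 1}
NEIGHBORS = {0: [1, 2], 1: [0], 2: [3, 0], 3: [2]}


def sample_local(srcs: List[int], one_hop_num: int) -> Dict[int, List[int]]:
    """Take up to one_hop_num neighbors per seed (first-k, not random)."""
    return {s: NEIGHBORS[s][:one_hop_num] for s in srcs}


async def request_remote(partition: int, srcs: List[int],
                         one_hop_num: int) -> Dict[int, List[int]]:
    """Stand-in for a request to the machine that owns this partition."""
    await asyncio.sleep(0)  # yield control, as a real network call would
    return sample_local(srcs, one_hop_num)


async def sample_one_hop_sketch(srcs: List[int], one_hop_num: int,
                                local_partition: int) -> Dict[int, List[int]]:
    # Split seeds into those owned locally and those owned elsewhere.
    local = [s for s in srcs if PARTITION_OF[s] == local_partition]
    remote: Dict[int, List[int]] = {}
    for s in srcs:
        if PARTITION_OF[s] != local_partition:
            remote.setdefault(PARTITION_OF[s], []).append(s)

    # Local seeds are sampled directly; remote seeds go out as one request
    # per owning partition and are awaited together.
    result = sample_local(local, one_hop_num)
    pending = [request_remote(p, seeds, one_hop_num) for p, seeds in remote.items()]
    for partial in await asyncio.gather(*pending):
        result.update(partial)
    return result


out = asyncio.run(sample_one_hop_sketch([0, 2, 3], one_hop_num=1, local_partition=0))
print(out)
```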
- class DistLoader(current_ctx: DistContext, master_addr: Optional[str] = None, master_port: Optional[Union[int, str]] = None, channel: Optional[Queue] = None, num_rpc_threads: int = 16, rpc_timeout: int = 180, dist_sampler: Optional[DistNeighborSampler] = None, **kwargs)[source]
A base class for creating distributed data loading routines.
- Parameters:
current_ctx (DistContext) – Distributed context info of the current process.
master_addr (str, optional) – RPC address for distributed loader communication. Refers to the IP address of the master node. (default: None)
master_port (int or str, optional) – The open port for RPC communication with the master node. (default: None)
channel (mp.Queue, optional) – A communication channel for messages. (default: None)
num_rpc_threads (int, optional) – The number of threads in the thread-pool used by TensorPipeAgent to execute requests. (default: 16)
rpc_timeout (int, optional) – The default timeout in seconds for RPC requests. If the RPC has not completed in this timeframe, an exception will be raised. Callers can override this timeout for individual RPCs in rpc_sync() and rpc_async() if necessary. (default: 180)
- class DistNeighborLoader(data: Tuple[LocalFeatureStore, LocalGraphStore], num_neighbors: Union[List[int], Dict[Tuple[str, str, str], List[int]]], master_addr: str, master_port: Union[int, str], current_ctx: DistContext, input_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_time: Optional[Tensor] = None, dist_sampler: Optional[DistNeighborSampler] = None, replace: bool = False, subgraph_type: Union[SubgraphType, str] = 'directional', disjoint: bool = False, temporal_strategy: str = 'uniform', time_attr: Optional[str] = None, transform: Optional[Callable] = None, concurrency: int = 1, num_rpc_threads: int = 16, filter_per_worker: Optional[bool] = False, async_sampling: bool = True, device: Optional[device] = None, **kwargs)[source]
A distributed loader that performs sampling from nodes.
- Parameters:
data (tuple) – A (FeatureStore, GraphStore) data object.
num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]) – The number of neighbors to sample for each node in each iteration. If an entry is set to -1, all neighbors will be included. In heterogeneous graphs, may also take in a dictionary denoting the number of neighbors to sample for each individual edge type.
master_addr (str) – RPC address for distributed loader communication, i.e. the IP address of the master node.
master_port (Union[int, str]) – Open port for RPC communication with the master node.
current_ctx (DistContext) – Distributed context information of the current process.
concurrency (int, optional) – RPC concurrency used for defining the maximum size of the asynchronous processing queue. (default: 1)
All other arguments follow the interface of torch_geometric.loader.NeighborLoader.
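The meaning of num_neighbors, including the -1 entry, can be illustrated with a small single-machine sketch in plain Python. It assumes sampling without replacement and treats an entry of -1, or a fan-out larger than the neighbor count, as "take all neighbors". The function name sample_neighbors and the toy graph are hypothetical, and the sketch ignores everything distributed about the real loader.

```python
import random
from typing import Dict, List


def sample_neighbors(adj: Dict[int, List[int]], seeds: List[int],
                     num_neighbors: List[int], rng: random.Random) -> List[List[int]]:
    """Return the newly reached nodes for each hop, layer by layer."""
    frontier = list(seeds)
    visited = set(seeds)
    per_hop: List[List[int]] = []
    for fanout in num_neighbors:  # one entry per iteration (hop)
        next_frontier: List[int] = []
        for node in frontier:
            candidates = adj.get(node, [])
            if fanout == -1 or fanout >= len(candidates):
                picked = list(candidates)  # -1 means: include all neighbors
            else:
                picked = rng.sample(candidates, fanout)  # without replacement
            for nbr in picked:
                if nbr not in visited:
                    visited.add(nbr)
                    next_frontier.append(nbr)
        per_hop.append(next_frontier)
        frontier = next_frontier
    return per_hop


# Toy graph (assumed): node -> list of neighbors to sample from.
adj = {0: [1, 2, 3], 1: [4], 2: [4, 5], 3: [], 4: [], 5: []}
per_hop = sample_neighbors(adj, seeds=[0], num_neighbors=[-1, 1], rng=random.Random(0))
print(per_hop)
```

With num_neighbors=[-1, 1], the first hop keeps every neighbor of the seed, while the second hop keeps at most one neighbor per frontier node.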
- class DistLinkNeighborLoader(data: Tuple[LocalFeatureStore, LocalGraphStore], num_neighbors: Union[List[int], Dict[Tuple[str, str, str], List[int]]], master_addr: str, master_port: Union[int, str], current_ctx: DistContext, edge_label_index: Union[Tensor, None, Tuple[str, str, str], Tuple[Tuple[str, str, str], Optional[Tensor]]] = None, edge_label: Optional[Tensor] = None, edge_label_time: Optional[Tensor] = None, dist_sampler: Optional[DistNeighborSampler] = None, replace: bool = False, subgraph_type: Union[SubgraphType, str] = 'directional', disjoint: bool = False, temporal_strategy: str = 'uniform', neg_sampling: Optional[NegativeSampling] = None, neg_sampling_ratio: Optional[Union[int, float]] = None, time_attr: Optional[str] = None, transform: Optional[Callable] = None, concurrency: int = 1, num_rpc_threads: int = 16, filter_per_worker: Optional[bool] = False, async_sampling: bool = True, device: Optional[device] = None, **kwargs)[source]
A distributed loader that performs sampling from edges.
- Parameters:
data (tuple) – A (FeatureStore, GraphStore) data object.
num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]) – The number of neighbors to sample for each node in each iteration. If an entry is set to -1, all neighbors will be included. In heterogeneous graphs, may also take in a dictionary denoting the number of neighbors to sample for each individual edge type.
master_addr (str) – RPC address for distributed loader communication, i.e. the IP address of the master node.
master_port (Union[int, str]) – Open port for RPC communication with the master node.
current_ctx (DistContext) – Distributed context information of the current process.
concurrency (int, optional) – RPC concurrency used for defining the maximum size of the asynchronous processing queue. (default: 1)
All other arguments follow the interface of torch_geometric.loader.LinkNeighborLoader.
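One way to read the concurrency parameter is as a cap on how many sampling requests are processed at the same time. The sketch below shows that general pattern with an asyncio.Semaphore. It is an interpretation for illustration only, not the loader's actual implementation: run_bounded, process, and the use of sum() as the "work" are all hypothetical.

```python
import asyncio
from typing import List, Tuple


async def run_bounded(batches: List[List[int]],
                      concurrency: int) -> Tuple[List[int], int]:
    """Process batches with at most `concurrency` of them in flight."""
    gate = asyncio.Semaphore(concurrency)
    results = [0] * len(batches)
    in_flight = 0
    peak = 0  # highest number of simultaneously active tasks observed

    async def process(i: int, batch: List[int]) -> None:
        nonlocal in_flight, peak
        async with gate:  # blocks while `concurrency` tasks are active
            in_flight += 1
            peak = max(peak, in_flight)
            await asyncio.sleep(0)  # stand-in for an asynchronous sampling call
            results[i] = sum(batch)
            in_flight -= 1

    await asyncio.gather(*(process(i, b) for i, b in enumerate(batches)))
    return results, peak


results, peak = asyncio.run(run_bounded([[1, 2], [3], [4, 5, 6], [7]], concurrency=2))
print(results, peak)
```

A larger cap lets more requests overlap at the cost of more work held in memory at once, which is the usual trade-off behind such a setting.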