# Source code for torch_geometric.distributed.local_feature_store

import copy
import os.path as osp
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.data import FeatureStore, TensorAttr
from torch_geometric.data.feature_store import _FieldStatus
from torch_geometric.distributed.partition import load_partition_info
from torch_geometric.distributed.rpc import (
    RPCCallBase,
    RPCRouter,
    rpc_async,
    rpc_register,
)
from torch_geometric.typing import EdgeType, NodeOrEdgeType, NodeType


class RPCCallFeatureLookup(RPCCallBase):
    r"""RPC callee that answers feature lookups on behalf of a feature store.

    Asynchronous requests are forwarded verbatim to the wrapped store's
    :meth:`_rpc_local_feature_get`; synchronous requests are not supported.

    Args:
        dist_feature (FeatureStore): The store that serves the requests.
    """
    def __init__(self, dist_feature: FeatureStore):
        super().__init__()
        self.dist_feature = dist_feature

    def rpc_async(self, *call_args, **call_kwargs):
        # Forward the request unchanged to the store's local lookup handler.
        handler = self.dist_feature._rpc_local_feature_get
        return handler(*call_args, **call_kwargs)

    def rpc_sync(self, *call_args, **call_kwargs):
        # Only the asynchronous path is implemented for feature lookups.
        raise NotImplementedError


@dataclass
class LocalTensorAttr(TensorAttr):
    r"""Tensor attribute for storing features without :obj:`index`.

    Identical to :class:`TensorAttr`, except that :obj:`index` defaults to
    :obj:`None` instead of being left unset.
    """
    def __init__(
        self,
        group_name: Optional[Union[NodeType, EdgeType]] = _FieldStatus.UNSET,
        attr_name: Optional[str] = _FieldStatus.UNSET,
        index=None,
    ):
        # Forward every field explicitly by name to the parent dataclass:
        super().__init__(
            group_name=group_name,
            attr_name=attr_name,
            index=index,
        )


class LocalFeatureStore(FeatureStore):
    r"""Implements the :class:`~torch_geometric.data.FeatureStore` interface
    to act as a local feature store for distributed training.
    """
    def __init__(self):
        super().__init__(tensor_attr_cls=LocalTensorAttr)

        self._feat: Dict[Tuple[Union[NodeType, EdgeType], str], Tensor] = {}

        # Save the global node/edge IDs:
        self._global_id: Dict[Union[NodeType, EdgeType], Tensor] = {}

        # Save the mapping from global node/edge IDs to indices in `_feat`:
        self._global_id_to_index: Dict[Union[NodeType, EdgeType], Tensor] = {}

        # For partition/RPC info related to distributed features:
        self.num_partitions: int = 1
        self.partition_idx: int = 0
        # Mapping between node ID and partition ID.
        # NOTE: only annotated here (not assigned); `from_partition` sets it.
        self.node_feat_pb: Union[Tensor, Dict[NodeType, Tensor]]
        # Mapping between edge ID and partition ID (set by `from_partition`):
        self.edge_feat_pb: Union[Tensor, Dict[EdgeType, Tensor]]
        # Node labels:
        self.labels: Optional[Tensor] = None

        self.local_only: bool = False
        self.rpc_router: Optional[RPCRouter] = None
        self.meta: Optional[Dict] = None
        self.rpc_call_id: Optional[int] = None

    @staticmethod
    def key(attr: TensorAttr) -> Tuple[Union[NodeType, EdgeType], str]:
        r"""Returns the ``(group_name, attr_name)`` key used in ``_feat``."""
        return (attr.group_name, attr.attr_name)

    # Global IDs ##############################################################

    def put_global_id(
        self,
        global_id: Tensor,
        group_name: Union[NodeType, EdgeType],
    ) -> bool:
        r"""Stores the global IDs of ``group_name`` and rebuilds the
        global-ID-to-local-index mapping for it.
        """
        self._global_id[group_name] = global_id
        self._set_global_id_to_index(group_name)
        return True

    def get_global_id(
        self,
        group_name: Union[NodeType, EdgeType],
    ) -> Optional[Tensor]:
        r"""Returns the global IDs of ``group_name``, or :obj:`None`."""
        return self._global_id.get(group_name)

    def remove_global_id(self, group_name: Union[NodeType, EdgeType]) -> bool:
        r"""Removes the global IDs of ``group_name``.

        Returns :obj:`True` if IDs were stored for ``group_name`` and
        :obj:`False` otherwise.
        """
        # Drop the derived lookup table too, so that no stale mapping is left
        # behind for `get_tensor_from_global_id` to use:
        self._global_id_to_index.pop(group_name, None)
        return self._global_id.pop(group_name, None) is not None

    def _set_global_id_to_index(self, group_name: Union[NodeType, EdgeType]):
        r"""Builds a dense lookup table mapping each global ID of
        ``group_name`` to its row position in the stored features. Global IDs
        that are not present map to ``-1``.
        """
        global_id = self.get_global_id(group_name)
        if global_id is None:
            return

        # `max()` is undefined for an empty tensor, so size the table to 0:
        size = int(global_id.max()) + 1 if global_id.numel() > 0 else 0

        # TODO Compute this mapping without materializing a full-sized tensor:
        global_id_to_index = global_id.new_full((size, ), fill_value=-1)
        global_id_to_index[global_id] = torch.arange(
            global_id.numel(), device=global_id.device)
        self._global_id_to_index[group_name] = global_id_to_index

    # Tensor storage ##########################################################

    def _put_tensor(self, tensor: Tensor, attr: TensorAttr) -> bool:
        assert attr.index is None
        self._feat[self.key(attr)] = tensor
        return True

    def _get_tensor(self, attr: TensorAttr) -> Optional[Tensor]:
        tensor = self._feat.get(self.key(attr))

        if tensor is None:
            return None

        if attr.index is None:  # Empty indices return the full tensor:
            return tensor

        return tensor[attr.index]

    def _remove_tensor(self, attr: TensorAttr) -> bool:
        assert attr.index is None
        return self._feat.pop(self.key(attr), None) is not None

    def get_tensor_from_global_id(self, *args, **kwargs) -> Optional[Tensor]:
        r"""Like :meth:`get_tensor`, but ``index`` holds *global* IDs that are
        translated to local row positions first.
        """
        attr = self._tensor_attr_cls.cast(*args, **kwargs)
        assert attr.index is not None

        attr = copy.copy(attr)  # Do not mutate the caller's attribute.
        # NOTE(review): global IDs not owned by this store map to -1 and would
        # silently select the last row; callers must pass local IDs only.
        attr.index = self._global_id_to_index[attr.group_name][attr.index]

        return self.get_tensor(attr)

    def _get_tensor_size(self, attr: TensorAttr) -> Tuple[int, ...]:
        return self._get_tensor(attr).size()

    def get_all_tensor_attrs(self) -> List[LocalTensorAttr]:
        return [self._tensor_attr_cls.cast(*key) for key in self._feat.keys()]

    def set_rpc_router(self, rpc_router: RPCRouter):
        r"""Attaches an RPC router and, unless ``local_only`` is set, registers
        this store as the callee for remote feature lookups.

        Raises:
            ValueError: If ``rpc_router`` is :obj:`None` and the store is not
                ``local_only``.
        """
        self.rpc_router = rpc_router

        if not self.local_only:
            if self.rpc_router is None:
                raise ValueError("An RPC router must be provided")
            rpc_call = RPCCallFeatureLookup(self)
            self.rpc_call_id = rpc_register(rpc_call)
        else:
            self.rpc_call_id = None

    def has_edge_attr(self) -> bool:
        r"""Returns whether any group holds an ``'edge_attr'`` tensor."""
        return any(attr_name == 'edge_attr' for _, attr_name in self._feat)

    # Feature lookup ##########################################################

    def _feature_attr_kwargs(
        self,
        is_node_feat: bool,
        input_type: Optional[Union[NodeType, EdgeType]],
    ) -> Dict[str, Any]:
        r"""Returns the ``group_name``/``attr_name`` under which node
        (``'x'``) or edge (``'edge_attr'``) features are stored.
        """
        if self.meta['is_hetero']:
            group_name = input_type
        elif is_node_feat:
            group_name = None
        else:
            group_name = (None, None)
        attr_name = 'x' if is_node_feat else 'edge_attr'
        return dict(group_name=group_name, attr_name=attr_name)

    def _get_partition_ids(
        self,
        index: Tensor,
        is_node_feat: bool,
        input_type: Optional[Union[NodeType, EdgeType]],
    ) -> Tensor:
        r"""Returns the partition ID owning each node/edge ID in ``index``."""
        pb = self.node_feat_pb if is_node_feat else self.edge_feat_pb
        if self.meta['is_hetero']:
            return pb[input_type][index]
        return pb[index]

    def lookup_features(
        self,
        index: Tensor,
        is_node_feat: bool = True,
        input_type: Optional[NodeOrEdgeType] = None,
    ) -> torch.futures.Future:
        r"""Lookup of local/remote features.

        Returns a future that resolves to a tensor whose ``i``-th row holds
        the feature of ``index[i]``, regardless of which partition owns it.
        """
        remote_fut = self._remote_lookup_features(index, is_node_feat,
                                                  input_type)
        local_feature = self._local_lookup_features(index, is_node_feat,
                                                    input_type)
        res_fut = torch.futures.Future()

        def when_finish(*_):
            try:
                remote_feature_list = remote_fut.wait()
                local_feat, local_index = local_feature
                # Combine the features from remote and local partitions. The
                # trailing feature dimensions are taken from the local result,
                # so 1-D as well as multi-dimensional features are supported:
                result = torch.zeros(
                    (index.size(0), ) + tuple(local_feat.size()[1:]),
                    dtype=local_feat.dtype,
                )
                result[local_index] = local_feat
                for remote_feat, remote_index in remote_feature_list:
                    result[remote_index] = remote_feat
            except Exception as e:
                res_fut.set_exception(e)
            else:
                res_fut.set_result(result)

        remote_fut.add_done_callback(when_finish)
        return res_fut

    def _local_lookup_features(
        self,
        index: Tensor,
        is_node_feat: bool = True,
        input_type: Optional[Union[NodeType, EdgeType]] = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""Lookup the features in local nodes based on node/edge IDs.

        Returns the features owned by this partition, together with their
        positions in ``index``.
        """
        partition_ids = self._get_partition_ids(index, is_node_feat,
                                                input_type)
        input_order = torch.arange(index.size(0), dtype=torch.long)

        local_mask = partition_ids == self.partition_idx
        local_ids = torch.masked_select(index, local_mask)
        local_index = torch.masked_select(input_order, local_mask)

        ret_feat = self.get_tensor_from_global_id(
            index=local_ids,
            **self._feature_attr_kwargs(is_node_feat, input_type),
        )
        return ret_feat, local_index

    def _remote_lookup_features(
        self,
        index: Tensor,
        is_node_feat: bool = True,
        input_type: Optional[Union[NodeType, EdgeType]] = None,
    ) -> torch.futures.Future:
        r"""Fetch the remote features with the remote node/edge IDs.

        Returns a future that resolves to a list of
        ``(features, positions_in_index)`` pairs, one per remote partition
        that owns at least one requested ID.
        """
        partition_ids = self._get_partition_ids(index, is_node_feat,
                                                input_type)
        input_order = torch.arange(index.size(0), dtype=torch.long)

        futs, indexes = [], []
        for pidx in range(0, self.num_partitions):
            if pidx == self.partition_idx:
                continue
            remote_mask = partition_ids == pidx
            remote_ids = index[remote_mask]
            if remote_ids.shape[0] > 0:
                to_worker = self.rpc_router.get_to_worker(pidx)
                futs.append(
                    rpc_async(
                        to_worker,
                        self.rpc_call_id,
                        args=(remote_ids.cpu(), is_node_feat, input_type),
                    ))
                indexes.append(torch.masked_select(input_order, remote_mask))

        collect_fut = torch.futures.collect_all(futs)
        res_fut = torch.futures.Future()

        def when_finish(*_):
            try:
                fut_list = collect_fut.wait()
                # `futs` and `indexes` are appended in lockstep above:
                result = [(fut.wait(), positions)
                          for fut, positions in zip(fut_list, indexes)]
            except Exception as e:
                res_fut.set_exception(e)
            else:
                res_fut.set_result(result)

        collect_fut.add_done_callback(when_finish)
        return res_fut

    def _rpc_local_feature_get(
        self,
        index: Tensor,
        is_node_feat: bool = True,
        input_type: Optional[Union[NodeType, EdgeType]] = None,
    ) -> Tensor:
        r"""Lookup of features in remote nodes.

        Serves incoming RPC requests: ``index`` holds global IDs that are
        owned by this partition.
        """
        return self.get_tensor_from_global_id(
            index=index,
            **self._feature_attr_kwargs(is_node_feat, input_type),
        )

    # Initialization ##########################################################

    @classmethod
    def from_data(
        cls,
        node_id: Tensor,
        x: Optional[Tensor] = None,
        y: Optional[Tensor] = None,
        edge_id: Optional[Tensor] = None,
        edge_attr: Optional[Tensor] = None,
    ) -> 'LocalFeatureStore':
        r"""Creates a local feature store from homogeneous :pyg:`PyG` tensors.

        Args:
            node_id (torch.Tensor): The global identifier for every local node.
            x (torch.Tensor, optional): The node features.
                (default: :obj:`None`)
            y (torch.Tensor, optional): The node labels. (default: :obj:`None`)
            edge_id (torch.Tensor, optional): The global identifier for every
                local edge. (default: :obj:`None`)
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)

        Raises:
            ValueError: If ``edge_attr`` is given without ``edge_id``.
        """
        # Validate up front so that no partially-filled store is built:
        if edge_attr is not None and edge_id is None:
            raise ValueError("'edge_id' needs to be present in case "
                             "'edge_attr' is passed")

        feat_store = cls()
        feat_store.put_global_id(node_id, group_name=None)
        if x is not None:
            feat_store.put_tensor(x, group_name=None, attr_name='x')
        if y is not None:
            feat_store.put_tensor(y, group_name=None, attr_name='y')

        if edge_id is not None:
            feat_store.put_global_id(edge_id, group_name=(None, None))
        if edge_attr is not None:
            feat_store.put_tensor(edge_attr, group_name=(None, None),
                                  attr_name='edge_attr')
        return feat_store

    @classmethod
    def from_hetero_data(
        cls,
        node_id_dict: Dict[NodeType, Tensor],
        x_dict: Optional[Dict[NodeType, Tensor]] = None,
        y_dict: Optional[Dict[NodeType, Tensor]] = None,
        edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None,
        edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None,
    ) -> 'LocalFeatureStore':
        r"""Creates a local feature store from heterogeneous :pyg:`PyG`
        tensors.

        Args:
            node_id_dict (Dict[NodeType, torch.Tensor]): The global identifier
                for every local node of every node type.
            x_dict (Dict[NodeType, torch.Tensor], optional): The node features
                of every node type. (default: :obj:`None`)
            y_dict (Dict[NodeType, torch.Tensor], optional): The node labels of
                every node type. (default: :obj:`None`)
            edge_id_dict (Dict[EdgeType, torch.Tensor], optional): The global
                identifier for every local edge of every edge type.
                (default: :obj:`None`)
            edge_attr_dict (Dict[EdgeType, torch.Tensor], optional): The edge
                features of every edge type. (default: :obj:`None`)

        Raises:
            ValueError: If an edge type in ``edge_attr_dict`` has no entry in
                ``edge_id_dict``.
        """
        # Validate up front so that no partially-filled store is built:
        if edge_attr_dict is not None:
            for edge_type in edge_attr_dict:
                if edge_id_dict is None or edge_type not in edge_id_dict:
                    raise ValueError("'edge_id' needs to be present in case "
                                     "'edge_attr' is passed")

        feat_store = cls()

        for node_type, node_id in node_id_dict.items():
            feat_store.put_global_id(node_id, group_name=node_type)
        if x_dict is not None:
            for node_type, x in x_dict.items():
                feat_store.put_tensor(x, group_name=node_type, attr_name='x')
        if y_dict is not None:
            for node_type, y in y_dict.items():
                feat_store.put_tensor(y, group_name=node_type, attr_name='y')

        if edge_id_dict is not None:
            for edge_type, edge_id in edge_id_dict.items():
                feat_store.put_global_id(edge_id, group_name=edge_type)
        if edge_attr_dict is not None:
            for edge_type, edge_attr in edge_attr_dict.items():
                feat_store.put_tensor(edge_attr, group_name=edge_type,
                                      attr_name='edge_attr')
        return feat_store

    def _put_partition_node_feats(
        self,
        node_feat: Dict[str, Any],
        group_name: Optional[NodeType],
    ) -> None:
        r"""Stores one partition's node payload (``'global_id'``, ``'feats'``
        and optional ``'time'``) under ``group_name``.
        """
        self.put_global_id(node_feat['global_id'], group_name=group_name)
        for key, value in node_feat['feats'].items():
            self.put_tensor(value, group_name=group_name, attr_name=key)
        if 'time' in node_feat:
            self.put_tensor(node_feat['time'], group_name=group_name,
                            attr_name='time')

    def _put_partition_edge_feats(
        self,
        edge_feat: Dict[str, Any],
        group_name: Union[EdgeType, Tuple[None, None]],
    ) -> None:
        r"""Stores one partition's edge payload (optional ``'global_id'``,
        ``'feats'`` and ``'edge_time'``) under ``group_name``.
        """
        if 'global_id' in edge_feat:
            self.put_global_id(edge_feat['global_id'], group_name=group_name)
        if 'feats' in edge_feat:
            for key, value in edge_feat['feats'].items():
                self.put_tensor(value, group_name=group_name, attr_name=key)
        if 'edge_time' in edge_feat:
            self.put_tensor(edge_feat['edge_time'], group_name=group_name,
                            attr_name='edge_time')

    @classmethod
    def from_partition(cls, root: str, pid: int) -> 'LocalFeatureStore':
        r"""Creates a local feature store from partition ``pid`` stored
        under ``root``, including partition book-keeping used for
        distributed lookups.
        """
        part_dir = osp.join(root, f'part_{pid}')
        assert osp.exists(part_dir), f"Partition directory '{part_dir}' " \
                                     f"does not exist"
        feat_store = cls()

        (
            meta,
            num_partitions,
            partition_idx,
            node_pb,
            edge_pb,
        ) = load_partition_info(root, pid)
        feat_store.num_partitions = num_partitions
        feat_store.partition_idx = partition_idx
        feat_store.node_feat_pb = node_pb
        feat_store.edge_feat_pb = edge_pb
        feat_store.meta = meta

        # NOTE: `torch.load` unpickles its input; only load partitions from
        # trusted sources.
        node_feats: Optional[Dict[str, Any]] = None
        if osp.exists(osp.join(part_dir, 'node_feats.pt')):
            node_feats = torch.load(osp.join(part_dir, 'node_feats.pt'))

        edge_feats: Optional[Dict[str, Any]] = None
        if osp.exists(osp.join(part_dir, 'edge_feats.pt')):
            edge_feats = torch.load(osp.join(part_dir, 'edge_feats.pt'))

        if not meta['is_hetero']:
            # Homogeneous payloads are stored directly (not keyed by type):
            if node_feats is not None:
                feat_store._put_partition_node_feats(node_feats, None)
            if edge_feats is not None:
                feat_store._put_partition_edge_feats(edge_feats, (None, None))
        else:
            # Heterogeneous payloads are keyed by node/edge type:
            if node_feats is not None:
                for node_type, node_feat in node_feats.items():
                    feat_store._put_partition_node_feats(node_feat, node_type)
            if edge_feats is not None:
                for edge_type, edge_feat in edge_feats.items():
                    feat_store._put_partition_edge_feats(edge_feat, edge_type)

        return feat_store