Source code for torch_geometric.data.graph_store

r"""This class defines the abstraction for a backend-agnostic graph store. The
goal of the graph store is to abstract away all graph edge index memory
management so that varying implementations can allow for independent scale-out.

This particular graph store abstraction makes a few key assumptions:
* The edge indices we care about storing are represented either in COO, CSC,
  or CSR format. They can be uniquely identified by an edge type (in PyG,
  this is a tuple of the source node, relation type, and destination node).
* Edge indices are static once they are stored in the graph. That is, we do not
  support dynamic modification of edge indices once they have been inserted
  into the graph store.

It is the job of a graph store implementor class to handle these assumptions
properly. For example, a simple in-memory graph store implementation may
concatenate all metadata values of an edge index and use the result as a
unique key in a KV store. More complicated implementations may choose to
partition the graph in different ways based on the provided metadata.
"""
import copy
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

from torch import Tensor

from torch_geometric.typing import EdgeTensorType, EdgeType, OptTensor
from torch_geometric.utils import index_sort
from torch_geometric.utils.mixin import CastMixin
from torch_geometric.utils.sparse import index2ptr, ptr2index

# The output of converting between two edge layouts in the GraphStore is a
# tuple of three dictionaries: row, col, and perm. The dictionaries are keyed
# by the edge type of the input edge attribute.
#   * The row dictionary contains the row tensor for COO, the row pointer for
#     CSR, or the row tensor for CSC
#   * The col dictionary contains the col tensor for COO, the col tensor for
#     CSR, or the col pointer for CSC
#   * The perm dictionary contains the permutation of edges that was applied
#     in converting between formats, or `None` if no reordering was needed.
# NOTE: When the graph store is treated as homogeneous,
# `GraphStore._edges_to_layout` returns a single `(row, col, perm)` tuple of
# tensors instead of these dictionaries.
ConversionOutputType = Tuple[Dict[EdgeType, Tensor], Dict[EdgeType, Tensor],
                             Dict[EdgeType, OptTensor]]


class EdgeLayout(Enum):
    r"""The layout in which an edge index is stored in a
    :class:`GraphStore`."""
    COO = 'coo'  # Coordinate format: (row, col) index tensors.
    CSC = 'csc'  # Compressed sparse column: (row, col pointer).
    CSR = 'csr'  # Compressed sparse row: (row pointer, col).


@dataclass
class EdgeAttr(CastMixin):
    r"""Describes one edge index held by a :obj:`GraphStore`.

    An :class:`EdgeAttr` bundles every parameter required to uniquely
    identify an edge index inside a :class:`GraphStore`. The declaration
    order of the attributes below matters: it is the order in which they
    must be supplied positionally in indexing calls. :class:`GraphStore`
    implementations that need a different ordering can override
    :meth:`EdgeAttr.__init__`.
    """
    # Edge type this edge index belongs to:
    edge_type: EdgeType

    # Layout of the stored edge representation:
    layout: EdgeLayout

    # Whether the edges are sorted by destination node. This avoids sorting
    # costs during neighbor sampling and is only meaningful for COO; CSC is
    # always sorted and CSR never is (both enforced in `__init__` below):
    is_sorted: bool = False

    # Number of source and destination nodes of this edge type, if known:
    size: Optional[Tuple[int, int]] = None

    # NOTE: A hand-written `__init__` is kept by `@dataclass` instead of the
    # generated one, so `layout` is always coerced into an `EdgeLayout`.
    def __init__(
        self,
        edge_type: EdgeType,
        layout: EdgeLayout,
        is_sorted: bool = False,
        size: Optional[Tuple[int, int]] = None,
    ):
        edge_layout = EdgeLayout(layout)

        if edge_layout == EdgeLayout.CSR:
            if is_sorted:
                raise ValueError("Cannot create a 'CSR' edge attribute with "
                                 "option 'is_sorted=True'")
        elif edge_layout == EdgeLayout.CSC:
            # CSC is sorted by destination by construction:
            is_sorted = True

        self.edge_type, self.layout = edge_type, edge_layout
        self.is_sorted, self.size = is_sorted, size
class GraphStore(ABC):
    r"""An abstract base class to access edges from a remote graph store.

    Args:
        edge_attr_cls (EdgeAttr, optional): A user-defined
            :class:`EdgeAttr` class to customize the required attributes and
            their ordering to uniquely identify edges. (default: :obj:`None`)
    """
    def __init__(self, edge_attr_cls: Optional[Any] = None):
        super().__init__()
        # Written through `__dict__` rather than normal attribute assignment;
        # presumably so a subclass overriding `__setattr__` is not triggered
        # here (NOTE(review): confirm intent).
        self.__dict__['_edge_attr_cls'] = edge_attr_cls or EdgeAttr

    # Core (CRUD) #############################################################

    @abstractmethod
    def _put_edge_index(self, edge_index: EdgeTensorType,
                        edge_attr: EdgeAttr) -> bool:
        r"""To be implemented by :class:`GraphStore` subclasses."""
        pass

    def put_edge_index(self, edge_index: EdgeTensorType, *args,
                       **kwargs) -> bool:
        r"""Synchronously adds an :obj:`edge_index` tuple to the
        :class:`GraphStore`.
        Returns whether insertion was successful.

        Args:
            edge_index (Tuple[torch.Tensor, torch.Tensor]): The
                :obj:`edge_index` tuple in a format specified in
                :class:`EdgeAttr`.
            *args: Arguments passed to :class:`EdgeAttr`.
            **kwargs: Keyword arguments passed to :class:`EdgeAttr`.
        """
        edge_attr = self._edge_attr_cls.cast(*args, **kwargs)
        return self._put_edge_index(edge_index, edge_attr)

    @abstractmethod
    def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
        r"""To be implemented by :class:`GraphStore` subclasses."""
        pass

    def get_edge_index(self, *args, **kwargs) -> EdgeTensorType:
        r"""Synchronously obtains an :obj:`edge_index` tuple from the
        :class:`GraphStore`.

        Args:
            *args: Arguments passed to :class:`EdgeAttr`.
            **kwargs: Keyword arguments passed to :class:`EdgeAttr`.

        Raises:
            KeyError: If the :obj:`edge_index` corresponding to the input
                :class:`EdgeAttr` was not found.
        """
        edge_attr = self._edge_attr_cls.cast(*args, **kwargs)
        edge_index = self._get_edge_index(edge_attr)
        if edge_index is None:
            raise KeyError(f"'edge_index' for '{edge_attr}' not found")
        return edge_index

    @abstractmethod
    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:
        r"""To be implemented by :class:`GraphStore` subclasses."""
        pass

    def remove_edge_index(self, *args, **kwargs) -> bool:
        r"""Synchronously deletes an :obj:`edge_index` tuple from the
        :class:`GraphStore`.
        Returns whether deletion was successful.

        Args:
            *args: Arguments passed to :class:`EdgeAttr`.
            **kwargs: Keyword arguments passed to :class:`EdgeAttr`.
        """
        edge_attr = self._edge_attr_cls.cast(*args, **kwargs)
        return self._remove_edge_index(edge_attr)

    @abstractmethod
    def get_all_edge_attrs(self) -> List[EdgeAttr]:
        r"""Returns all registered edge attributes."""
        pass

    # Layout Conversion #######################################################

    def coo(
        self,
        edge_types: Optional[List[Any]] = None,
        store: bool = False,
    ) -> ConversionOutputType:
        r"""Returns the edge indices in the :class:`GraphStore` in COO format.

        Args:
            edge_types (List[Any], optional): The edge types of edge indices
                to obtain. If set to :obj:`None`, will return the edge indices
                of all existing edge types. (default: :obj:`None`)
            store (bool, optional): Whether to store converted edge indices in
                the :class:`GraphStore`. (default: :obj:`False`)
        """
        return self._edges_to_layout(EdgeLayout.COO, edge_types, store)

    def csr(
        self,
        edge_types: Optional[List[Any]] = None,
        store: bool = False,
    ) -> ConversionOutputType:
        r"""Returns the edge indices in the :class:`GraphStore` in CSR format.

        Args:
            edge_types (List[Any], optional): The edge types of edge indices
                to obtain. If set to :obj:`None`, will return the edge indices
                of all existing edge types. (default: :obj:`None`)
            store (bool, optional): Whether to store converted edge indices in
                the :class:`GraphStore`. (default: :obj:`False`)
        """
        return self._edges_to_layout(EdgeLayout.CSR, edge_types, store)

    def csc(
        self,
        edge_types: Optional[List[Any]] = None,
        store: bool = False,
    ) -> ConversionOutputType:
        r"""Returns the edge indices in the :class:`GraphStore` in CSC format.

        Args:
            edge_types (List[Any], optional): The edge types of edge indices
                to obtain. If set to :obj:`None`, will return the edge indices
                of all existing edge types. (default: :obj:`None`)
            store (bool, optional): Whether to store converted edge indices in
                the :class:`GraphStore`. (default: :obj:`False`)
        """
        return self._edges_to_layout(EdgeLayout.CSC, edge_types, store)

    # Python built-ins ########################################################

    def __setitem__(self, key: EdgeAttr, value: EdgeTensorType):
        self.put_edge_index(value, key)

    def __getitem__(self, key: EdgeAttr) -> Optional[EdgeTensorType]:
        return self.get_edge_index(key)

    def __delitem__(self, key: EdgeAttr):
        return self.remove_edge_index(key)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'

    # Helper methods ##########################################################

    @staticmethod
    def _num_index_nodes(index: Tensor) -> int:
        r"""Infers the number of nodes referenced by an index tensor as its
        maximum value plus one, or :obj:`0` if the tensor is empty (where
        :meth:`torch.Tensor.max` would raise)."""
        return int(index.max()) + 1 if index.numel() > 0 else 0

    def _edge_to_layout(
        self,
        attr: EdgeAttr,
        layout: EdgeLayout,
        store: bool = False,
    ) -> Tuple[Tensor, Tensor, OptTensor]:
        r"""Converts the edge index identified by :obj:`attr` into
        :obj:`layout`.

        Returns a :obj:`(row, col, perm)` tuple, where :obj:`perm` is the
        permutation applied to the edges during conversion, or :obj:`None` if
        no reordering took place. If :obj:`store` is :obj:`True` and the
        layout changed, the converted edge index is also written back to the
        :class:`GraphStore` under a copy of :obj:`attr` with the new layout.
        """
        (row, col), perm = self.get_edge_index(attr), None

        if layout == EdgeLayout.COO:  # COO output requested:
            if attr.layout == EdgeLayout.CSR:  # CSR->COO
                row = ptr2index(row)
            elif attr.layout == EdgeLayout.CSC:  # CSC->COO
                col = ptr2index(col)

        elif layout == EdgeLayout.CSR:  # CSR output requested:
            if attr.layout == EdgeLayout.CSC:  # CSC->COO
                col = ptr2index(col)

            if attr.layout != EdgeLayout.CSR:  # COO->CSR
                num_rows = (attr.size[0]
                            if attr.size else self._num_index_nodes(row))
                row, perm = index_sort(row, max_value=num_rows)
                col = col[perm]
                row = index2ptr(row, num_rows)

        else:  # CSC output requested:
            if attr.layout == EdgeLayout.CSR:  # CSR->COO
                row = ptr2index(row)

            if attr.layout != EdgeLayout.CSC:  # COO->CSC
                if hasattr(self, 'meta') and self.meta.get('is_hetero', False):
                    # Hotfix for `LocalGraphStore`, where in heterogeneous
                    # graphs, edge indices for different edge types have
                    # continuous indices not starting at 0.
                    num_cols = self._num_index_nodes(col)
                elif attr.size is not None:
                    num_cols = attr.size[1]
                else:
                    num_cols = self._num_index_nodes(col)

                if not attr.is_sorted:  # Not sorted by destination.
                    col, perm = index_sort(col, max_value=num_cols)
                    row = row[perm]
                col = index2ptr(col, num_cols)

        if attr.layout != layout and store:
            attr = copy.copy(attr)
            attr.layout = layout
            if layout == EdgeLayout.CSC:
                # CSC is sorted by destination by definition; keep the stored
                # attribute consistent with `EdgeAttr.__init__`, which forces
                # `is_sorted=True` for CSC:
                attr.is_sorted = True
            elif perm is not None:
                attr.is_sorted = False
            self.put_edge_index((row, col), attr)

        return row, col, perm

    def _edges_to_layout(
        self,
        layout: EdgeLayout,
        edge_types: Optional[List[Any]] = None,
        store: bool = False,
    ) -> ConversionOutputType:
        r"""Converts the stored edge indices into :obj:`layout`.

        For a heterogeneous store, returns :obj:`(row_dict, col_dict,
        perm_dict)` keyed by edge type, restricted to :obj:`edge_types` if
        given. For a homogeneous store, returns the :obj:`(row, col, perm)`
        tuple of the first registered edge attribute and ignores
        :obj:`edge_types`.

        Raises:
            ValueError: If a requested edge type is not present, or if no
                stored layout of an edge type is a known :class:`EdgeLayout`.
        """
        edge_attrs: List[EdgeAttr] = self.get_all_edge_attrs()

        if hasattr(self, 'meta'):  # `LocalGraphStore` hack.
            is_hetero = self.meta.get('is_hetero', False)
        else:
            is_hetero = all(attr.edge_type is not None for attr in edge_attrs)

        if not is_hetero:
            return self._edge_to_layout(edge_attrs[0], layout, store)

        # Group the already-fetched edge attributes by edge type:
        edge_type_attrs: Dict[EdgeType, List[EdgeAttr]] = defaultdict(list)
        for attr in edge_attrs:
            edge_type_attrs[attr.edge_type].append(attr)

        # Check that requested edge types exist and filter:
        if edge_types is not None:
            for edge_type in edge_types:
                if edge_type not in edge_type_attrs:
                    raise ValueError(f"The 'edge_index' of type '{edge_type}' "
                                     f"was not found in the graph store.")
            requested = set(edge_types)
            edge_type_attrs = {
                key: attrs
                for key, attrs in edge_type_attrs.items() if key in requested
            }

        # Convert each edge type from its most favorable stored layout: the
        # requested layout itself (no conversion needed), then COO, CSC, CSR.
        preference = (layout, EdgeLayout.COO, EdgeLayout.CSC, EdgeLayout.CSR)
        row_dict, col_dict, perm_dict = {}, {}, {}
        for edge_type, attrs in edge_type_attrs.items():
            layouts = [attr.layout for attr in attrs]
            source = next((attrs[layouts.index(candidate)]
                           for candidate in preference if candidate in layouts),
                          None)
            if source is None:
                raise ValueError(f"No supported layout found for the "
                                 f"'edge_index' of type '{edge_type}' "
                                 f"(got {layouts})")
            row_dict[edge_type], col_dict[edge_type], perm_dict[edge_type] = (
                self._edge_to_layout(source, layout, store))

        return row_dict, col_dict, perm_dict