from abc import abstractmethod
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
from torch_geometric.data import Data, FeatureStore, HeteroData
from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
from torch_geometric.typing import InputEdges, InputNodes
class RAGFeatureStore(Protocol):
    """Feature store template for remote GNN RAG backend.

    Structural interface for the feature-store half of the
    ``(feature_store, graph_store)`` pair consumed by
    :class:`RAGQueryLoader`. The loader calls the three hooks below in
    order: seed nodes, seed edges, then (after sampling) feature loading.
    """
    @abstractmethod
    def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
        """Makes a comparison between the query and all the nodes to get all
        the closest nodes. Returns the indices of the nodes that are to be
        seeds for the RAG Sampler.

        Args:
            query (Any): The query to compare against the nodes.
            **kwargs: Backend-specific options. :class:`RAGQueryLoader`
                forwards its :obj:`seed_nodes_kwargs` here.
        """
        ...

    @abstractmethod
    def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:
        """Makes a comparison between the query and all the edges to get all
        the closest edges. Returns the edge indices that are to be the seeds
        for the RAG Sampler.

        Args:
            query (Any): The query to compare against the edges.
            **kwargs: Backend-specific options. :class:`RAGQueryLoader`
                forwards its :obj:`seed_edges_kwargs` here.
        """
        ...

    # `**kwargs` is accepted here (like the other hooks) because
    # `RAGQueryLoader.query` forwards its `loader_kwargs` to this method.
    @abstractmethod
    def load_subgraph(
        self,
        sample: Union[SamplerOutput, HeteroSamplerOutput],
        **kwargs,
    ) -> Union[Data, HeteroData]:
        """Combines sampled subgraph output with features in a Data object.

        Args:
            sample (Union[SamplerOutput, HeteroSamplerOutput]): The sampled
                subgraph, as produced by the graph store's sampler.
            **kwargs: Backend-specific options. :class:`RAGQueryLoader`
                forwards its :obj:`loader_kwargs` here.
        """
        ...
class RAGGraphStore(Protocol):
    """Graph store template for remote GNN RAG backend.

    Structural interface for the graph-store half of the
    ``(feature_store, graph_store)`` pair consumed by
    :class:`RAGQueryLoader`, which first registers the feature store and
    then asks this store to sample a subgraph per query.
    """
    @abstractmethod
    def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
                        **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
        """Sample a subgraph using the seeded nodes and edges.

        Args:
            seed_nodes (InputNodes): Seed nodes to sample around.
            seed_edges (InputEdges): Seed edges to sample around.
            **kwargs: Backend-specific options. :class:`RAGQueryLoader`
                forwards its :obj:`sampler_kwargs` here.
        """
        ...

    @abstractmethod
    def register_feature_store(self, feature_store: FeatureStore) -> None:
        """Register a feature store to be used with the sampler. Samplers need
        info from the feature store in order to work properly on HeteroGraphs.

        Args:
            feature_store (FeatureStore): The feature store paired with this
                graph store.
        """
        ...
# TODO: Make compatible with Heterographs
class RAGQueryLoader:
    """Loader meant for making RAG queries from a remote backend.

    Ties a feature store and a graph store together: for each query it
    retrieves seed nodes and seed edges from the feature store, samples a
    subgraph around them via the graph store, loads the subgraph's features,
    and optionally post-processes the result with a local filter.
    """
    def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
                 local_filter: Optional[Callable[[Data, Any], Data]] = None,
                 seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
                 seed_edges_kwargs: Optional[Dict[str, Any]] = None,
                 sampler_kwargs: Optional[Dict[str, Any]] = None,
                 loader_kwargs: Optional[Dict[str, Any]] = None):
        """Loader meant for making queries from a remote backend.

        Registers the feature store with the graph store as a side effect.

        Args:
            data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote FeatureStore
                and GraphStore to load from. Assumed to conform to the
                protocols listed above.
            local_filter (Optional[Callable[[Data, Any], Data]], optional):
                Optional local transform to apply to data after retrieval.
                Called as ``local_filter(data, query)``. Defaults to None.
            seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Parameters
                to pass into process for fetching seed nodes. Defaults to None.
            seed_edges_kwargs (Optional[Dict[str, Any]], optional): Parameters
                to pass into process for fetching seed edges. Defaults to None.
            sampler_kwargs (Optional[Dict[str, Any]], optional): Parameters to
                pass into process for sampling graph. Defaults to None.
            loader_kwargs (Optional[Dict[str, Any]], optional): Parameters to
                pass into process for loading graph features. Defaults to None.
        """
        fstore, gstore = data
        self.feature_store = fstore
        self.graph_store = gstore
        # The graph store's sampler needs access to the feature store.
        self.graph_store.register_feature_store(self.feature_store)
        self.local_filter = local_filter
        # Normalise omitted kwargs to empty dicts so they can be splatted
        # unconditionally in `query`.
        self.seed_nodes_kwargs = seed_nodes_kwargs or {}
        self.seed_edges_kwargs = seed_edges_kwargs or {}
        self.sampler_kwargs = sampler_kwargs or {}
        self.loader_kwargs = loader_kwargs or {}

    def query(self, query: Any) -> Data:
        """Retrieve a subgraph associated with the query with all its feature
        attributes.

        Args:
            query (Any): The query, passed as-is to the feature store's seed
                retrieval hooks and to :obj:`local_filter`.

        Returns:
            The value produced by the feature store's ``load_subgraph``,
            after :obj:`local_filter` has been applied if one was given.
        """
        seed_nodes = self.feature_store.retrieve_seed_nodes(
            query, **self.seed_nodes_kwargs)
        seed_edges = self.feature_store.retrieve_seed_edges(
            query, **self.seed_edges_kwargs)
        subgraph_sample = self.graph_store.sample_subgraph(
            seed_nodes, seed_edges, **self.sampler_kwargs)
        data = self.feature_store.load_subgraph(sample=subgraph_sample,
                                                **self.loader_kwargs)
        # Compare against None rather than relying on truthiness: a callable
        # object may legitimately be falsy (e.g. it defines `__len__`) and
        # must still be applied.
        if self.local_filter is not None:
            data = self.local_filter(data, query)
        return data