torch_geometric.nn.models.GRetriever

class GRetriever(llm: LLM, gnn: Module, use_lora: bool = False, mlp_out_channels: int = 4096)

Bases: Module

The G-Retriever model from the “G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and Question Answering” paper.

Parameters:
  • llm (LLM) – The LLM to use.

  • gnn (torch.nn.Module) – The GNN to use.

  • use_lora (bool, optional) – If set to True, fine-tunes the LLM with LoRA (low-rank adaptation) from the peft library. (default: False)

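  • mlp_out_channels (int, optional) – The size of each graph embedding after projection. This should match the embedding size of the LLM, since the projected graph embedding is passed to the LLM alongside its token embeddings. (default: 4096)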
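In G-Retriever, the GNN output is pooled into one embedding per graph, projected, and passed to the LLM next to the token embeddings of the prompt, so the projected width (mlp_out_channels) has to equal the width of those token embeddings. The stdlib-only sketch below illustrates that constraint under simplifying assumptions: a single linear layer stands in for the projection MLP, the graph embedding is placed in front of the prompt tokens, and linear and prepend_graph_token are hypothetical helpers rather than PyG APIs.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major, shape [out_features][in_features]


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """Computes weight @ x + bias, like a single torch.nn.Linear layer."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weight, bias)]


def prepend_graph_token(graph_emb: Vector, weight: Matrix, bias: Vector,
                        token_embs: List[Vector]) -> List[Vector]:
    """Hypothetical helper: projects a pooled graph embedding to
    mlp_out_channels (= len(bias)) and prepends it to the prompt's token
    embeddings as one extra "soft token"."""
    projected = linear(graph_emb, weight, bias)
    width = len(token_embs[0])  # embedding width of the LLM
    if len(projected) != width:
        raise ValueError(
            f"mlp_out_channels={len(projected)} does not match the LLM "
            f"embedding width {width}")
    return [projected] + token_embs


# A 2-dim graph embedding projected to width 3, followed by two 3-dim tokens.
seq = prepend_graph_token([1.0, 2.0],
                          [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
                          [0.0, 0.0, 0.5],
                          [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
print(seq[0])  # [1.0, 2.0, 3.5]
```

If the projection produced, say, 4096 values while the LLM's token embeddings were narrower, the sequence could not be assembled, which is why the sketch raises instead of silently truncating.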

Warning

This module has been tested with the following HuggingFace models:

  • llm_to_use="meta-llama/Llama-2-7b-chat-hf"

  • llm_to_use="google/gemma-7b"

and may not work with other models. See other models at HuggingFace Models and let us know if you encounter any issues.
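Assuming mlp_out_channels must equal the LLM's embedding width, the default of 4096 matches meta-llama/Llama-2-7b-chat-hf (hidden size 4096), while google/gemma-7b (hidden size 3072) would need mlp_out_channels=3072. A minimal sketch of picking the value per model follows; HIDDEN_SIZES and pick_mlp_out_channels are hypothetical helpers, not part of PyG, and the sizes are the hidden_size entries of the two models' published configs.

```python
# Hypothetical lookup table: "hidden_size" from each tested model's
# published Hugging Face config.
HIDDEN_SIZES = {
    "meta-llama/Llama-2-7b-chat-hf": 4096,
    "google/gemma-7b": 3072,
}


def pick_mlp_out_channels(llm_to_use: str) -> int:
    """Hypothetical helper: returns an mlp_out_channels value equal to the
    LLM's embedding width, for the tested models only."""
    try:
        return HIDDEN_SIZES[llm_to_use]
    except KeyError:
        raise ValueError(
            f"{llm_to_use!r} is not one of the tested models; look up its "
            "hidden_size in the model's config.json") from None


print(pick_mlp_out_channels("google/gemma-7b"))  # 3072
```

For example, GRetriever(llm=llm, gnn=gnn, mlp_out_channels=pick_mlp_out_channels("google/gemma-7b")) would pass 3072 instead of the default.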

Note

For an example of using GRetriever, see examples/llm/g_retriever.py.

forward(question: List[str], x: Tensor, edge_index: Tensor, batch: Tensor, label: List[str], edge_attr: Optional[Tensor] = None, additional_text_context: Optional[List[str]] = None)

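The forward pass. Returns the language-modeling loss of the LLM for generating each label, conditioned on its question, the encoded graph and any additional text context.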

Parameters:
  • question (List[str]) – The questions/prompts.

  • x (torch.Tensor) – The input node features.

  • edge_index (torch.Tensor) – The edge indices.

  • batch (torch.Tensor) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each element to a specific example.

  • label (List[str]) – The answers/labels.

  • edge_attr (torch.Tensor, optional) – The edge features (if supported by the GNN). (default: None)

  • additional_text_context (List[str], optional) – Additional context to give to the LLM, such as textified knowledge graphs. (default: None)
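The batch vector follows the usual PyG convention: node i belongs to example batch[i], and there is one entry of question (and of label) per example. As a stdlib-only sketch of how per-node features can be reduced to one embedding per example, the helper below mean-pools node features by batch index, similar in spirit to torch_geometric.nn.global_mean_pool. mean_pool_by_batch is hypothetical, and returning None for an example without any nodes is a choice made for this sketch.

```python
from typing import List, Optional

Vector = List[float]


def mean_pool_by_batch(x: List[Vector], batch: List[int],
                       num_examples: int) -> List[Optional[Vector]]:
    """Hypothetical helper: averages the features of the nodes assigned to
    each example. batch[i] in {0, ..., num_examples - 1} is the example of
    node i; an example with no nodes gets None (a choice of this sketch)."""
    sums: List[Optional[Vector]] = [None] * num_examples
    counts = [0] * num_examples
    for feat, b in zip(x, batch):
        if sums[b] is None:
            sums[b] = [0.0] * len(feat)
        sums[b] = [s + f for s, f in zip(sums[b], feat)]
        counts[b] += 1
    return [None if s is None else [v / c for v in s]
            for s, c in zip(sums, counts)]


# Three examples (B = 3): nodes 0 and 1 form example 0, example 1 has no
# nodes, and node 2 forms example 2.
print(mean_pool_by_batch([[1.0, 2.0], [3.0, 4.0], [10.0, 0.0]], [0, 0, 2], 3))
# [[2.0, 3.0], None, [10.0, 0.0]]
```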

inference(question: List[str], x: Tensor, edge_index: Tensor, batch: Tensor, edge_attr: Optional[Tensor] = None, additional_text_context: Optional[List[str]] = None, max_out_tokens: Optional[int] = 32)

The inference pass. Generates an answer for each question and returns the decoded answers as a list of strings.

Parameters:
  • question (List[str]) – The questions/prompts.

  • x (torch.Tensor) – The input node features.

  • edge_index (torch.Tensor) – The edge indices.

  • batch (torch.Tensor) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each element to a specific example.

  • edge_attr (torch.Tensor, optional) – The edge features (if supported by the GNN). (default: None)

  • additional_text_context (List[str], optional) – Additional context to give to the LLM, such as textified knowledge graphs. (default: None)

  • max_out_tokens (int, optional) – The maximum number of new tokens the LLM generates for each question. (default: 32)
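max_out_tokens caps how many new tokens are generated per question, much like the max_new_tokens argument of Hugging Face's generate(). The stdlib-only sketch below shows such a cap in a greedy decoding loop that also stops early at an end-of-sequence token; generate and next_token are hypothetical stand-ins (next_token plays the role of the LLM), and sampling, caching and batching are left out.

```python
from typing import Callable, List


def generate(prompt_ids: List[int], next_token: Callable[[List[int]], int],
             eos_id: int, max_out_tokens: int = 32) -> List[int]:
    """Hypothetical greedy decoding loop: runs at most max_out_tokens steps
    and stops early when next_token returns eos_id. Returns only the newly
    generated tokens, without the end-of-sequence token."""
    ids = list(prompt_ids)
    new_tokens: List[int] = []
    for _ in range(max_out_tokens):
        tok = next_token(ids)
        if tok == eos_id:
            break
        ids.append(tok)
        new_tokens.append(tok)
    return new_tokens


# Toy "model" that always predicts the previous token id plus one.
count_up = lambda ids: ids[-1] + 1
print(generate([0], count_up, eos_id=99, max_out_tokens=5))  # [1, 2, 3, 4, 5]
print(generate([0], count_up, eos_id=3))                     # [1, 2]
```

The first call hits the cap before any end-of-sequence token appears, so the answer is cut off after five tokens; the second stops on its own well within the default of 32.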