torch_geometric.nn.models.GRetriever
- class GRetriever(llm: LLM, gnn: Module, use_lora: bool = False, mlp_out_channels: int = 4096)[source]
Bases: Module
The G-Retriever model from the “G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and Question Answering” paper.
- Parameters:
llm (LLM) – The LLM to use.
gnn (torch.nn.Module) – The GNN to use.
use_lora (bool, optional) – If set to True, uses LoRA (via the peft package) for training the LLM. (default: False)
mlp_out_channels (int, optional) – The size of each graph embedding after projection. (default: 4096)
Warning
This module has been tested with the following HuggingFace models and may not work with other models:
meta-llama/Llama-2-7b-chat-hf
google/gemma-7b
See HuggingFace Models for further options, and let us know if you encounter any issues.
Note
For an example of using GRetriever, see examples/llm/g_retriever.py.
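The following is a minimal construction sketch. The LLM wrapper import path (torch_geometric.nn.nlp.LLM), its model_name argument, and all channel sizes are illustrative assumptions rather than prescriptions from this page:

    from torch_geometric.nn import GAT
    from torch_geometric.nn.models import GRetriever
    from torch_geometric.nn.nlp import LLM  # assumed import path for the LLM wrapper

    # One of the tested models from the warning above:
    llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf')

    # Any torch.nn.Module GNN should work; a GAT with illustrative sizes:
    gnn = GAT(
        in_channels=1024,  # must match the node feature dimension of your data
        hidden_channels=1024,
        out_channels=1024,
        num_layers=4,
        heads=4,
    )

    model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=4096)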
- forward(question: List[str], x: Tensor, edge_index: Tensor, batch: Tensor, label: List[str], edge_attr: Optional[Tensor] = None, additional_text_context: Optional[List[str]] = None)[source]
The forward pass.
- Parameters:
question (List[str]) – The questions/prompts.
x (torch.Tensor) – The input node features.
edge_index (torch.Tensor) – The edge indices.
batch (torch.Tensor) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each element to a specific example.
label (List[str]) – The answers/labels.
edge_attr (torch.Tensor, optional) – The edge features (if supported by the GNN). (default: None)
additional_text_context (List[str], optional) – Additional context to give to the LLM, such as textified knowledge graphs. (default: None)
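The sketch below illustrates one training step, continuing from the construction sketch above. It assumes that the mini-batches carry question, label, and desc string attributes, and that forward() returns the language-modeling loss over label, as in examples/llm/g_retriever.py; both details are assumptions for illustration:

    import torch

    # `loader` yields torch_geometric.data.Batch objects (assumed fields:
    # `question`, `label`, and `desc` holding textified graph context):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    for batch in loader:
        optimizer.zero_grad()
        loss = model(
            question=batch.question,
            x=batch.x,
            edge_index=batch.edge_index,
            batch=batch.batch,
            label=batch.label,
            edge_attr=batch.edge_attr,
            additional_text_context=batch.desc,  # assumed attribute name
        )
        loss.backward()
        optimizer.step()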
- inference(question: List[str], x: Tensor, edge_index: Tensor, batch: Tensor, edge_attr: Optional[Tensor] = None, additional_text_context: Optional[List[str]] = None, max_out_tokens: Optional[int] = 32)[source]
The inference pass.
- Parameters:
question (List[str]) – The questions/prompts.
x (torch.Tensor) – The input node features.
edge_index (torch.Tensor) – The edge indices.
batch (torch.Tensor) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each element to a specific example.
edge_attr (torch.Tensor, optional) – The edge features (if supported by the GNN). (default: None)
additional_text_context (List[str], optional) – Additional context to give to the LLM, such as textified knowledge graphs. (default: None)
max_out_tokens (int, optional) – How many tokens for the LLM to generate. (default: 32)
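At evaluation time, inference() generates answers instead of computing a loss. A sketch under the same assumptions as the training snippet above (the return value is assumed to be the generated answer strings):

    import torch

    model.eval()
    with torch.no_grad():
        for batch in loader:
            out = model.inference(
                question=batch.question,
                x=batch.x,
                edge_index=batch.edge_index,
                batch=batch.batch,
                edge_attr=batch.edge_attr,
                additional_text_context=batch.desc,  # assumed attribute name
                max_out_tokens=32,  # cap on generated tokens per answer
            )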