torch_geometric.nn.models.TGNMemory
- class TGNMemory(num_nodes: int, raw_msg_dim: int, memory_dim: int, time_dim: int, message_module: Callable, aggregator_module: Callable)
Bases: torch.nn.Module
The Temporal Graph Network (TGN) memory model from the “Temporal Graph Networks for Deep Learning on Dynamic Graphs” paper.
Note
For an example of using TGN, see examples/tgn.py.
- Parameters:
num_nodes (int) – The number of nodes to save memories for.
raw_msg_dim (int) – The raw message dimensionality.
memory_dim (int) – The hidden memory dimensionality.
time_dim (int) – The time encoding dimensionality.
message_module (torch.nn.Module) – The message function that combines the source and destination node memory embeddings, the raw message, and the time encoding.
aggregator_module (torch.nn.Module) – The message aggregator function that aggregates all messages sent to the same destination node into a single representation.
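The roles of the two callables can be pictured with a minimal, stdlib-only sketch. It assumes vectors are plain Python lists and uses placeholder numbers where TGN would use a learned time encoding; the function names `identity_message` and `last_aggregator` are hypothetical. The message function simply concatenates its four inputs (roughly what PyG's `IdentityMessage` in `torch_geometric.nn.models.tgn` does), and the aggregator keeps only the most recent message per destination node (the idea behind PyG's `LastAggregator`).

```python
from typing import Dict, List, Sequence, Tuple

Vector = List[float]


def identity_message(mem_src: Vector, mem_dst: Vector,
                     raw_msg: Vector, t_enc: Vector) -> Vector:
    # Hypothetical message function: concatenate source memory, destination
    # memory, raw message and time encoding (no learned parameters).
    # Output length = 2 * memory_dim + raw_msg_dim + time_dim.
    return list(mem_src) + list(mem_dst) + list(raw_msg) + list(t_enc)


def last_aggregator(msgs: Sequence[Vector], dst: Sequence[int],
                    t: Sequence[int]) -> Dict[int, Vector]:
    # Hypothetical aggregator: for every destination node, keep only the
    # message with the largest timestamp (later entries win ties).
    latest: Dict[int, Tuple[int, Vector]] = {}
    for msg, node, ts in zip(msgs, dst, t):
        if node not in latest or ts >= latest[node][0]:
            latest[node] = (ts, msg)
    return {node: msg for node, (_, msg) in latest.items()}


zero = [0.0, 0.0]  # memory_dim = 2; memories start at zero
# raw_msg_dim = 1 and time_dim = 1; the time encodings are placeholder numbers
m1 = identity_message(zero, zero, [0.5], [1.0])
m2 = identity_message(zero, zero, [0.9], [3.0])
agg = last_aggregator([m1, m2], dst=[7, 7], t=[1, 3])
print(len(m1))  # 6
print(agg[7])   # [0.0, 0.0, 0.0, 0.0, 0.9, 3.0]
```

With `memory_dim = 2`, `raw_msg_dim = 1` and `time_dim = 1`, each message has length 2 * 2 + 1 + 1 = 6. PyG's `IdentityMessage` exposes this width as its `out_channels` attribute, which is the size the memory update has to accept.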
- forward(n_id: Tensor) → Tuple[Tensor, Tensor]
Returns, for all nodes n_id, their current memory and their last updated timestamp.
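To make the return value concrete, here is a hypothetical, stdlib-only stand-in for the state that `forward` reads: one memory vector and one last-update timestamp per node, both starting at zero. The class `ToyMemory` and its `write` helper are illustrative assumptions, not part of PyG, and the sketch leaves out how memories are actually updated from messages; it only shows the per-node lookup. In PyG itself the call is `memory(n_id)`, and the two returned tensors should have shapes `[len(n_id), memory_dim]` and `[len(n_id)]`.

```python
from typing import List, Sequence, Tuple


class ToyMemory:
    """Hypothetical stand-in for TGNMemory's state: one memory vector and one
    last-update timestamp per node, both initialised to zero."""

    def __init__(self, num_nodes: int, memory_dim: int) -> None:
        self.memory = [[0.0] * memory_dim for _ in range(num_nodes)]
        self.last_update = [0] * num_nodes

    def write(self, node: int, vec: Sequence[float], t: int) -> None:
        # Hypothetical helper: overwrite a node's memory and record when it changed.
        self.memory[node] = list(vec)
        self.last_update[node] = t

    def forward(self, n_id: Sequence[int]) -> Tuple[List[List[float]], List[int]]:
        # Mirror forward(n_id): return the memory and last-update time of each
        # requested node, in the order given by n_id (copies, not references).
        mem = [list(self.memory[i]) for i in n_id]
        last = [self.last_update[i] for i in n_id]
        return mem, last


store = ToyMemory(num_nodes=4, memory_dim=2)
store.write(2, [0.1, 0.2], t=10)
mem, last = store.forward([2, 0])
print(mem)   # [[0.1, 0.2], [0.0, 0.0]]
print(last)  # [10, 0]
```

Nodes that have never received a message come back with a zero memory and a zero timestamp, and the outputs follow the order of `n_id`, so repeated or reordered ids are handled by plain indexing.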