class TGNMemory(num_nodes: int, raw_msg_dim: int, memory_dim: int, time_dim: int, message_module: Callable, aggregator_module: Callable)

Bases: Module

The Temporal Graph Network (TGN) memory model from the “Temporal Graph Networks for Deep Learning on Dynamic Graphs” paper.


For an example of using TGN, see examples/

Parameters:
  • num_nodes (int) – The number of nodes to save memories for.

  • raw_msg_dim (int) – The raw message dimensionality.

  • memory_dim (int) – The hidden memory dimensionality.

  • time_dim (int) – The time encoding dimensionality.

  • message_module (torch.nn.Module) – The message function which combines source and destination node memory embeddings, the raw message and the time encoding.

  • aggregator_module (torch.nn.Module) – The message aggregator function which aggregates messages to the same destination into a single representation.
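The roles of message_module and aggregator_module are easier to see with plain lists of floats. In the sketch below, `encode_time`, `identity_message` and `last_aggregator` are hypothetical stand-ins written only for illustration, not the library's implementations. The message function concatenates its four inputs, which is one simple choice, so a message has length 2 * memory_dim + raw_msg_dim + time_dim. The aggregator keeps only the most recent message for each destination.

```python
import math
from typing import Dict, List, Sequence


def encode_time(dt: float, time_dim: int) -> List[float]:
    # Hypothetical fixed cosine time encoding with geometrically spaced frequencies.
    return [math.cos(dt / (10.0 ** i)) for i in range(time_dim)]


def identity_message(mem_src: Sequence[float], mem_dst: Sequence[float],
                     raw_msg: Sequence[float], t_enc: Sequence[float]) -> List[float]:
    # Combine source memory, destination memory, raw message and time encoding
    # by plain concatenation.
    return list(mem_src) + list(mem_dst) + list(raw_msg) + list(t_enc)


def last_aggregator(msgs: Sequence[List[float]], dst_index: Sequence[int],
                    timestamps: Sequence[float]) -> Dict[int, List[float]]:
    # For each destination node, keep the message with the latest timestamp.
    # On ties, the message that appears later in the input wins.
    latest: Dict[int, tuple] = {}
    for msg, d, t in zip(msgs, dst_index, timestamps):
        if d not in latest or t >= latest[d][0]:
            latest[d] = (t, msg)
    return {d: msg for d, (_, msg) in latest.items()}
```

With memory_dim = 2, raw_msg_dim = 1 and time_dim = 2, `identity_message` yields a 7-element list. If node 0 receives messages at times 5 and 7, `last_aggregator` keeps only the message sent at time 7.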

forward(n_id: Tensor) → Tuple[Tensor, Tensor]

Returns the current memory and the timestamp of the last update for every node in n_id.
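The lookup that forward performs can be pictured with a toy memory table. In this minimal sketch, the class `ToyMemory` and its fields are hypothetical. It assumes memories start as zero vectors and last-update timestamps start at 0. It also ignores any pending, not-yet-applied messages and simply reads the stored values. The rows come back in the order given by `n_id`, including duplicates.

```python
from typing import List, Tuple


class ToyMemory:
    """Hypothetical stand-in for a node memory table; not the library class."""

    def __init__(self, num_nodes: int, memory_dim: int) -> None:
        # One memory vector and one last-update timestamp per node.
        self.memory = [[0.0] * memory_dim for _ in range(num_nodes)]
        self.last_update = [0] * num_nodes

    def forward(self, n_id: List[int]) -> Tuple[List[List[float]], List[int]]:
        # Gather the memory row and last-update time of every requested node,
        # preserving the order (and any duplicates) of n_id.
        mem = [list(self.memory[i]) for i in n_id]
        last = [self.last_update[i] for i in n_id]
        return mem, last
```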

reset_parameters()

Resets all learnable parameters of the module.

reset_state()

Resets the memory to its initial state.

detach()

Detaches the memory from gradient computation.

update_state(src: Tensor, dst: Tensor, t: Tensor, raw_msg: Tensor)

Updates the memory with newly encountered interactions (src, dst, t, raw_msg).
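One way to picture the bookkeeping behind update_state is the minimal sketch below. It uses a toy dictionary-based store. In place of the learned update a real memory module would use, it assumes an exponential moving average, which also requires memory_dim == raw_msg_dim. The function `toy_update_state` and its `alpha` parameter are hypothetical. In this sketch, each interaction touches both endpoints, the latest raw message per node is blended into its memory, and each touched node's timestamp becomes the newest `t` it appeared with.

```python
from typing import Dict, List, Sequence


def toy_update_state(memory: Dict[int, List[float]],
                     last_update: Dict[int, int],
                     src: Sequence[int], dst: Sequence[int],
                     t: Sequence[int], raw_msg: Sequence[List[float]],
                     alpha: float = 0.5) -> None:
    # Collect, per node, the latest (timestamp, raw message) pair.
    # Each event (s, d, ts, msg) is recorded for both endpoints.
    latest: Dict[int, tuple] = {}
    for s, d, ts, msg in zip(src, dst, t, raw_msg):
        for node in (s, d):
            if node not in latest or ts >= latest[node][0]:
                latest[node] = (ts, msg)
    # Blend the latest message into each touched node's memory (in place)
    # and record the newest timestamp seen for that node.
    for node, (ts, msg) in latest.items():
        old = memory[node]
        memory[node] = [(1 - alpha) * m + alpha * x for m, x in zip(old, msg)]
        last_update[node] = max(last_update[node], ts)
```

Nodes that do not appear in the batch keep both their memory and their timestamp unchanged.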

train(mode: bool = True)

Sets the module in training mode.
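The methods above are typically combined in a training loop. The sketch below uses a hypothetical `RecordingMemory` stub that only logs which method is called, so it runs without PyTorch. `train_epoch` is likewise illustrative. The order shown reflects a common usage pattern rather than a requirement stated here: switch to training mode and reset the state at the start of an epoch, update the state with each batch's events, and detach after each optimizer step.

```python
from typing import List, Sequence, Tuple

# (src, dst, t, raw_msg) for one batch of events.
Batch = Tuple[List[int], List[int], List[int], List[List[float]]]


class RecordingMemory:
    """Hypothetical stub mimicking the TGNMemory call surface; it only logs calls."""

    def __init__(self) -> None:
        self.calls: List[str] = []
        self.training = False

    def train(self, mode: bool = True) -> "RecordingMemory":
        self.training = mode
        self.calls.append(f"train({mode})")
        return self

    def reset_state(self) -> None:
        self.calls.append("reset_state")

    def update_state(self, src, dst, t, raw_msg) -> None:
        self.calls.append(f"update_state(n={len(src)})")

    def detach(self) -> None:
        self.calls.append("detach")


def train_epoch(memory: RecordingMemory, batches: Sequence[Batch]) -> None:
    memory.train()
    memory.reset_state()  # start each epoch from a fresh memory
    for src, dst, t, raw_msg in batches:
        # ... read memory(n_id), compute embeddings and the loss ...
        memory.update_state(src, dst, t, raw_msg)  # store this batch's events
        # ... loss.backward(); optimizer.step() ...
        memory.detach()  # stop gradients from flowing into earlier batches
```

Calling `train_epoch` with two batches logs train(True) and reset_state once, followed by one update_state/detach pair per batch.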