torch_geometric.nn.models.TGNMemory

class TGNMemory(num_nodes: int, raw_msg_dim: int, memory_dim: int, time_dim: int, message_module: Callable, aggregator_module: Callable)[source]

Bases: Module

The Temporal Graph Network (TGN) memory model from the “Temporal Graph Networks for Deep Learning on Dynamic Graphs” paper.

Note

For an example of using TGN, see examples/tgn.py.

Parameters:
  • num_nodes (int) – The number of nodes to save memories for.

  • raw_msg_dim (int) – The raw message dimensionality.

  • memory_dim (int) – The hidden memory dimensionality.

  • time_dim (int) – The time encoding dimensionality.

  • message_module (torch.nn.Module) – The message function which combines source and destination node memory embeddings, the raw message and the time encoding.

  • aggregator_module (torch.nn.Module) – The message aggregator function which aggregates messages to the same destination into a single representation.

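A minimal construction sketch, assuming the IdentityMessage and LastAggregator helper modules from torch_geometric.nn.models.tgn (as used in examples/tgn.py); all sizes are placeholder values.

    import torch
    from torch_geometric.nn import TGNMemory
    from torch_geometric.nn.models.tgn import IdentityMessage, LastAggregator

    num_nodes, raw_msg_dim = 1000, 16   # placeholder sizes
    memory_dim, time_dim = 100, 100

    memory = TGNMemory(
        num_nodes=num_nodes,
        raw_msg_dim=raw_msg_dim,
        memory_dim=memory_dim,
        time_dim=time_dim,
        # IdentityMessage concatenates source/destination memory, the raw
        # message and the time encoding (no learnable weights).
        message_module=IdentityMessage(raw_msg_dim, memory_dim, time_dim),
        # LastAggregator keeps only the most recent message per destination.
        aggregator_module=LastAggregator(),
    )
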
forward(n_id: Tensor) → Tuple[Tensor, Tensor][source]

Returns, for all nodes in n_id, their current memory and their last updated timestamp.
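
A call sketch, reusing the memory instance from the construction example above:

    n_id = torch.arange(10)        # any 1-dimensional tensor of node indices
    z, last_update = memory(n_id)  # z: [10, memory_dim], last_update: [10]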

reset_parameters()[source]

Resets all learnable parameters of the module.

reset_state()[source]

Resets the memory to its initial state.

detach()[source]

Detaches the memory from gradient computation.

update_state(src: Tensor, dst: Tensor, t: Tensor, raw_msg: Tensor)[source]

Updates the memory with newly encountered interactions (src, dst, t, raw_msg).
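
A sketch of feeding one batch of interactions, continuing the example above; the tensors are synthetic placeholders.

    src = torch.tensor([0, 1, 2])            # source node indices
    dst = torch.tensor([3, 4, 5])            # destination node indices
    t = torch.tensor([10, 11, 12])           # event timestamps
    raw_msg = torch.randn(3, raw_msg_dim)    # raw message per event

    memory.update_state(src, dst, t, raw_msg)

In a typical training loop (such as examples/tgn.py), update_state() is called once per batch of events, detach() after each optimizer step to truncate backpropagation, and reset_state() at the start of every epoch before the event stream is replayed.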

train(mode: bool = True)[source]

Sets the module in training mode.