
class RENet(num_nodes: int, num_rels: int, hidden_channels: int, seq_len: int, num_layers: int = 1, dropout: float = 0.0, bias: bool = True)

Bases: Module

The Recurrent Event Network model from the “Recurrent Event Network for Reasoning over Temporal Knowledge Graphs” paper, which makes predictions via

\[f_{\mathbf{\Theta}}(\mathbf{e}_s, \mathbf{e}_r, \mathbf{h}^{(t-1)}(s, r))\]

based on an RNN encoder

\[\mathbf{h}^{(t)}(s, r) = \textrm{RNN}(\mathbf{e}_s, \mathbf{e}_r, g(\mathcal{O}^{(t)}_r(s)), \mathbf{h}^{(t-1)}(s, r))\]

where \(\mathbf{e}_s\) and \(\mathbf{e}_r\) denote entity and relation embeddings, and \(\mathcal{O}^{(t)}_r(s)\) represents the set of objects that interacted with subject \(s\) under relation \(r\) at timestamp \(t\). This model implements \(g\) as a mean aggregator and \(f_{\mathbf{\Theta}}\) as a linear projection.
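The aggregation step that feeds the RNN can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the PyG implementation: embeddings are plain lists of floats, the RNN input at one timestep is taken to be the concatenation of \(\mathbf{e}_s\), \(\mathbf{e}_r\) and \(g(\mathcal{O}^{(t)}_r(s))\) (hence three times `hidden_channels` wide), and an empty object set is mapped to a zero vector. The helper names `mean_aggregate` and `rnn_input` are hypothetical.

```python
from typing import List, Sequence

Vector = List[float]


def mean_aggregate(vectors: Sequence[Vector], dim: int) -> Vector:
    """Mean aggregator g: element-wise average of the object embeddings.

    Assumption of this sketch: an empty object set maps to a zero vector.
    """
    if not vectors:
        return [0.0] * dim
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def rnn_input(e_s: Vector, e_r: Vector, objects: Sequence[Vector]) -> Vector:
    """Assumed RNN input for one timestep: [e_s, e_r, g(O_r^(t)(s))] concatenated."""
    return e_s + e_r + mean_aggregate(objects, len(e_s))


# Hypothetical 2-dimensional embeddings.
e_s = [1.0, 0.0]
e_r = [0.0, 1.0]
objects = [[2.0, 4.0], [4.0, 0.0]]
print(rnn_input(e_s, e_r, objects))  # [1.0, 0.0, 0.0, 1.0, 3.0, 2.0]
```

Averaging keeps the aggregated vector the same width regardless of how many objects interacted with \(s\), which is what lets a fixed-size RNN consume neighborhoods of varying size.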

Parameters:

  • num_nodes (int) – The number of nodes (entities) in the knowledge graph.

  • num_rels (int) – The number of relations in the knowledge graph.

  • hidden_channels (int) – Hidden size of node and relation embeddings.

  • seq_len (int) – The sequence length of past events.

  • num_layers (int, optional) – The number of recurrent layers. (default: 1)

  • dropout (float, optional) – If non-zero, introduces a dropout layer before the final prediction. (default: 0.)

  • bias (bool, optional) – If set to False, no layer will learn an additive bias. (default: True)

forward(data: Data) → Tuple[Tensor, Tensor]

Given a data batch, computes the forward pass.

Parameters:

  • data (Data) – The input data, holding subject sub, relation rel and object obj information, each with shape [batch_size]. In addition, data needs to hold history information for subjects, given by a vector of node indices h_sub, their relative timestamps h_sub_t, and their batch assignments h_sub_batch. The same information must be given for objects (h_obj, h_obj_t, h_obj_batch).
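The flat history vectors can be read as a ragged batch. The sketch below shows one plausible interpretation, not the PyG code: it assumes each relative timestamp in `h_sub_t` lies in `[0, seq_len)`, so entry `i` can be placed in bucket `h_sub_batch[i] * seq_len + h_sub_t[i]` of a `batch_size * seq_len` grid, and each bucket would then be mean-aggregated into one RNN timestep. The helpers `history_buckets` and `group_history` are hypothetical.

```python
from typing import List


def history_buckets(h_t: List[int], h_batch: List[int], seq_len: int) -> List[int]:
    """Map each history entry to a flat (example, timestep) bucket index.

    Assumes relative timestamps lie in [0, seq_len).
    """
    for t in h_t:
        if not 0 <= t < seq_len:
            raise ValueError(f"relative timestamp {t} outside [0, {seq_len})")
    return [b * seq_len + t for t, b in zip(h_t, h_batch)]


def group_history(h_nodes: List[int], h_t: List[int], h_batch: List[int],
                  batch_size: int, seq_len: int) -> List[List[int]]:
    """Collect node indices per bucket; bucket i * seq_len + t is example i, step t."""
    buckets: List[List[int]] = [[] for _ in range(batch_size * seq_len)]
    for node, idx in zip(h_nodes, history_buckets(h_t, h_batch, seq_len)):
        buckets[idx].append(node)
    return buckets


# Hypothetical batch of two subjects with seq_len = 2.
h_sub = [7, 3, 5, 9]
h_sub_t = [0, 0, 1, 1]
h_sub_batch = [0, 0, 0, 1]
print(group_history(h_sub, h_sub_t, h_sub_batch, batch_size=2, seq_len=2))
# [[7, 3], [5], [], [9]]
```

Empty buckets (the third one here) correspond to timesteps at which a subject had no recorded neighbors.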

static pre_transform(seq_len: int) → Callable

Returns a transform that precomputes the history objects

\[\{ \mathcal{O}^{(t-k)}_r(s), \ldots, \mathcal{O}^{(t-1)}_r(s) \}\]

of a torch_geometric.datasets.icews.EventDataset, with \(k\) denoting the sequence length seq_len.
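The sliding-window idea behind this precomputation can be sketched in plain Python. This is a hypothetical illustration, not the PyG transform: events are assumed to be `(subject, relation, object, timestamp)` tuples sorted by time, the window is keyed by the `(s, r)` pair to mirror \(\mathcal{O}_r(s)\), and only timestamps at which that pair actually occurred are counted toward the `seq_len` limit. The function name `precompute_history` is an assumption.

```python
from collections import defaultdict, deque
from typing import Deque, Dict, List, Tuple

Event = Tuple[int, int, int, int]  # (subject, relation, object, timestamp)
Key = Tuple[int, int]  # (subject, relation)


def precompute_history(events: List[Event], seq_len: int) -> List[List[List[int]]]:
    """Return, for every event, the object sets of its (s, r) pair at up to
    seq_len earlier timestamps, oldest first. Events must be sorted by time."""
    # Sliding window per (s, r): one object list per past timestamp.
    windows: Dict[Key, Deque[List[int]]] = defaultdict(lambda: deque(maxlen=seq_len))
    # Objects seen at the current timestamp; held back so that events sharing
    # a timestamp never appear in each other's history.
    pending: Dict[Key, List[int]] = defaultdict(list)
    current_t = None
    histories: List[List[List[int]]] = []
    for s, r, o, t in events:
        if t != current_t:
            # A new timestamp begins: move buffered object sets into the
            # windows; deque(maxlen=seq_len) drops the oldest automatically.
            for key, objs in pending.items():
                windows[key].append(objs)
            pending = defaultdict(list)
            current_t = t
        histories.append(list(windows[(s, r)]))
        pending[(s, r)].append(o)
    return histories


events = [(0, 0, 1, 0), (0, 0, 2, 0), (0, 0, 3, 1), (0, 0, 4, 2), (0, 0, 5, 3)]
print(precompute_history(events, seq_len=2))
# [[], [], [[1, 2]], [[1, 2], [3]], [[3], [4]]]
```

Doing this once as a pre-transform means the per-batch forward pass only has to look up and aggregate stored indices instead of rescanning the event stream.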

test(logits: Tensor, y: Tensor) → Tensor

Given predicted logits and ground-truth labels y, computes the Mean Reciprocal Rank (MRR) and Hits@1/3/10.
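Both metrics derive from the rank of the true candidate within each row of scores. The sketch below is an illustration under stated assumptions, not the library's `test` method: scores are plain nested lists, the rank is 1 plus the number of candidates scoring strictly higher (optimistic tie-breaking), and results come back in a dictionary rather than a tensor. The function name `rank_metrics` is hypothetical.

```python
from typing import Dict, List


def rank_metrics(logits: List[List[float]], y: List[int]) -> Dict[str, float]:
    """Compute MRR and Hits@1/3/10 from per-example scores over all candidates.

    Rank = 1 + number of candidates with a strictly higher score than the
    true one (optimistic tie-breaking, an assumption of this sketch).
    """
    ranks = []
    for scores, target in zip(logits, y):
        true_score = scores[target]
        ranks.append(1 + sum(score > true_score for score in scores))
    n = len(ranks)
    metrics = {"mrr": sum(1.0 / rank for rank in ranks) / n}
    for k in (1, 3, 10):
        metrics[f"hits@{k}"] = sum(rank <= k for rank in ranks) / n
    return metrics


logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]
y = [1, 2]
print(rank_metrics(logits, y))
# mrr = (1/1 + 1/3) / 2 ≈ 0.667, hits@1 = 0.5, hits@3 = 1.0, hits@10 = 1.0
```

Counting only strictly higher scores gives the most favorable rank when ties occur; a pessimistic or averaged tie rule would lower MRR on tied rows.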