torch_geometric.nn.kge.ComplEx

class ComplEx(num_nodes: int, num_relations: int, hidden_channels: int, sparse: bool = False)

Bases: KGEModel

The ComplEx model from the “Complex Embeddings for Simple Link Prediction” paper.

ComplEx models relations as complex-valued bilinear mappings between head and tail entities using the Hermitian dot product. Entities and relations are both embedded in the complex space \(\mathbb{C}^d\), where \(d\) is hidden_channels, resulting in the scoring function:

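\[d(h, r, t) = \operatorname{Re}\left(\langle \mathbf{e}_h, \mathbf{e}_r, \overline{\mathbf{e}}_t \rangle\right) = \operatorname{Re}\left(\sum_{k=1}^{d} e_{h,k} \, e_{r,k} \, \overline{e}_{t,k}\right)\]

where the overline denotes complex conjugation.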
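To make the role of the conjugated tail embedding concrete, the sketch below evaluates the score Re(<e_h, e_r, conj(e_t)>) with plain Python complex numbers. The helpers `complex_score` and `random_embedding` are hypothetical and not part of torch_geometric; the sketch illustrates the formula only, not how the library stores embeddings. It also checks a consequence of the conjugate: swapping head and tail gives the same score as conjugating the relation, so the score is symmetric in head and tail only when the relation embedding is purely real.

```python
import random


def complex_score(e_h, e_r, e_t):
    # Re(sum_k e_h[k] * e_r[k] * conj(e_t[k])): the ComplEx trilinear score
    return sum(h * r * t.conjugate() for h, r, t in zip(e_h, e_r, e_t)).real


def random_embedding(dim, rng):
    # Hypothetical helper: one complex vector with Gaussian real/imaginary parts
    return [complex(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)) for _ in range(dim)]


rng = random.Random(0)
e_h, e_r, e_t = (random_embedding(4, rng) for _ in range(3))

forward_score = complex_score(e_h, e_r, e_t)
swapped_score = complex_score(e_t, e_r, e_h)

# Swapping head and tail equals scoring (h, conj(r), t),
# so the two scores above differ whenever e_r has an imaginary part.
conj_rel_score = complex_score(e_h, [r.conjugate() for r in e_r], e_t)

# With a purely real relation embedding the score becomes symmetric in h and t.
real_rel = [complex(r.real, 0.0) for r in e_r]
symmetric_gap = complex_score(e_h, real_rel, e_t) - complex_score(e_t, real_rel, e_h)
```

Because of this property, a relation embedding with a nonzero imaginary part can assign different scores to (h, r, t) and (t, r, h), which is what lets ComplEx represent antisymmetric relations.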

Note

For an example of using the ComplEx model, see examples/kge_fb15k_237.py.

Parameters:
  • num_nodes (int) – The number of nodes/entities in the graph.

  • num_relations (int) – The number of relations in the graph.

  • hidden_channels (int) – The hidden embedding size.

  • sparse (bool, optional) – If set to True, gradients with respect to the embedding matrices will be sparse. (default: False)

reset_parameters()

Resets all learnable parameters of the module.
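The following is a rough picture of what the constructor arguments allocate and what reset_parameters() refills. It assumes that the complex embeddings are kept as separate real and imaginary tables of shape (num_nodes, hidden_channels) and (num_relations, hidden_channels), and that they are initialized with Xavier-uniform values. Both points are assumptions made for illustration. The class `ComplExTablesSketch` is hypothetical and uses nested lists instead of torch.nn.Embedding.

```python
import math
import random


class ComplExTablesSketch:
    """Hypothetical stand-in for the embedding tables of a ComplEx model."""

    def __init__(self, num_nodes, num_relations, hidden_channels, seed=0):
        self.num_nodes = num_nodes
        self.num_relations = num_relations
        self.hidden_channels = hidden_channels
        self.rng = random.Random(seed)
        self.reset_parameters()

    def _xavier_uniform(self, rows, cols):
        # Assumed init: uniform in [-b, b] with b = sqrt(6 / (fan_in + fan_out)),
        # treating the table like a weight of shape (rows, cols).
        bound = math.sqrt(6.0 / (rows + cols))
        return [[self.rng.uniform(-bound, bound) for _ in range(cols)]
                for _ in range(rows)]

    def reset_parameters(self):
        # Real and imaginary parts are stored in separate real-valued tables.
        n, m, d = self.num_nodes, self.num_relations, self.hidden_channels
        self.node_re = self._xavier_uniform(n, d)
        self.node_im = self._xavier_uniform(n, d)
        self.rel_re = self._xavier_uniform(m, d)
        self.rel_im = self._xavier_uniform(m, d)

    def num_parameters(self):
        # Total count of real numbers across the four tables
        tables = (self.node_re, self.node_im, self.rel_re, self.rel_im)
        return sum(len(t) * len(t[0]) for t in tables)


model = ComplExTablesSketch(num_nodes=5, num_relations=2, hidden_channels=3)
before = model.node_re[0][0]
model.reset_parameters()  # draws fresh values for every table
```

Under these assumptions, the model holds 2 * (num_nodes + num_relations) * hidden_channels real parameters, which is 42 for the toy sizes used here.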

forward(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) → Tensor

Returns the score of each given (head, relation, tail) triplet.

Parameters:
  • head_index (torch.Tensor) – The head indices.

  • rel_type (torch.Tensor) – The relation type.

  • tail_index (torch.Tensor) – The tail indices.
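A forward pass over a batch of index triplets can be sketched by expanding Re(<e_h, e_r, conj(e_t)>) into four real triple products over separate real and imaginary tables. This plain-Python sketch relies on that storage assumption. The names `complex_forward`, `triple_dot`, and the table variables are hypothetical, and the actual module operates on torch.Tensor inputs. The block cross-checks each score against the complex-number form of the same formula.

```python
import random


def triple_dot(x, y, z):
    # Elementwise product of three real vectors, summed over the embedding axis
    return sum(a * b * c for a, b, c in zip(x, y, z))


def complex_forward(node_re, node_im, rel_re, rel_im,
                    head_index, rel_type, tail_index):
    """Score each (head, relation, tail) index triplet; returns one float each."""
    scores = []
    for h, r, t in zip(head_index, rel_type, tail_index):
        # Re(h * r * conj(t)) written with real and imaginary parts only
        scores.append(
            triple_dot(node_re[h], rel_re[r], node_re[t])
            + triple_dot(node_im[h], rel_re[r], node_im[t])
            + triple_dot(node_re[h], rel_im[r], node_im[t])
            - triple_dot(node_im[h], rel_im[r], node_re[t])
        )
    return scores


rng = random.Random(1)


def table(rows, cols):
    # Random real-valued table standing in for an embedding matrix
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


node_re, node_im = table(4, 3), table(4, 3)
rel_re, rel_im = table(2, 3), table(2, 3)

head_index, rel_type, tail_index = [0, 1, 2], [0, 1, 0], [3, 2, 1]
scores = complex_forward(node_re, node_im, rel_re, rel_im,
                         head_index, rel_type, tail_index)


def as_complex(re_row, im_row):
    return [complex(a, b) for a, b in zip(re_row, im_row)]


# Reference values computed directly with complex arithmetic
reference = [
    sum(hh * rr * tt.conjugate() for hh, rr, tt in zip(
        as_complex(node_re[h], node_im[h]),
        as_complex(rel_re[r], rel_im[r]),
        as_complex(node_re[t], node_im[t]))).real
    for h, r, t in zip(head_index, rel_type, tail_index)
]
```

Keeping the real and imaginary parts in separate real tables lets every operation stay in ordinary real arithmetic, at the cost of writing out the four cross terms explicitly.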

loss(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) → Tensor

Returns the loss value for the given triplets.

Parameters:
  • head_index (torch.Tensor) – The head indices.

  • rel_type (torch.Tensor) – The relation type.

  • tail_index (torch.Tensor) – The tail indices.
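The source states only that loss() returns a loss value. One common way to train ComplEx, assumed here rather than taken from the source, is binary cross-entropy on logits. Each positive triplet is paired with a negative triplet made by replacing its head or tail with a random entity. `kge_loss`, `corrupt`, and `bce_with_logits` are hypothetical helpers. Any scoring function, such as a ComplEx score, can be passed as `score_fn`.

```python
import math
import random


def bce_with_logits(logit, target):
    # Numerically stable binary cross-entropy on a raw score (logit)
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def corrupt(triplets, num_nodes, rng):
    # Assumed negative sampling: replace head or tail with a random entity
    negatives = []
    for h, r, t in triplets:
        if rng.random() < 0.5:
            negatives.append((rng.randrange(num_nodes), r, t))
        else:
            negatives.append((h, r, rng.randrange(num_nodes)))
    return negatives


def kge_loss(score_fn, triplets, num_nodes, seed=0):
    """Mean BCE over positives (target 1) and corrupted negatives (target 0)."""
    rng = random.Random(seed)
    pos = [score_fn(h, r, t) for h, r, t in triplets]
    neg = [score_fn(h, r, t) for h, r, t in corrupt(triplets, num_nodes, rng)]
    terms = ([bce_with_logits(s, 1.0) for s in pos]
             + [bce_with_logits(s, 0.0) for s in neg])
    return sum(terms) / len(terms)


triplets = [(0, 0, 1), (1, 1, 2)]
# An uninformative scorer that always returns 0
loss_at_zero = kge_loss(lambda h, r, t: 0.0, triplets, num_nodes=3)
```

With an uninformative scorer that always returns 0, every term equals log 2. This gives a handy sanity check for the loss value before any training has happened.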