torch_geometric.nn.kge.KGEModel
- class KGEModel(num_nodes: int, num_relations: int, hidden_channels: int, sparse: bool = False)
Bases: Module
An abstract base class for implementing custom KGE models.
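- Parameters:
num_nodes (int) – The number of nodes/entities in the graph.
num_relations (int) – The number of relations in the graph.
hidden_channels (int) – The hidden embedding size.
sparse (bool, optional) – If set to True, gradients w.r.t. the embedding matrices will be sparse. (default: False)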
- forward(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) → Tensor
Returns the score for the given triplet.
- Parameters:
head_index (torch.Tensor) – The head indices.
rel_type (torch.Tensor) – The relation type.
tail_index (torch.Tensor) – The tail indices.
- Return type:
Tensor
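As an illustration of what a triplet score can look like, here is a minimal pure-Python sketch of a translation-based (TransE-style) score, the negative distance between `head + rel` and `tail`. `KGEModel` itself leaves the scoring function to subclasses, so the function name `transe_score` and the toy embedding values below are assumptions made for this example, not part of the library.

```python
import math


def transe_score(head, rel, tail):
    """Score one triplet as the negative L2 norm of (head + rel - tail).

    A higher (less negative) score means a more plausible triplet.
    """
    return -math.sqrt(sum((h + r - t) ** 2 for h, r, t in zip(head, rel, tail)))


# Toy 2-dimensional embeddings (hypothetical values, for illustration only).
entity_emb = {0: [0.0, 0.0], 1: [1.0, 1.0], 2: [3.0, -1.0]}
relation_emb = {0: [1.0, 1.0]}

# Entity 1 sits exactly at entity 0 + relation 0, so it scores highest.
plausible = transe_score(entity_emb[0], relation_emb[0], entity_emb[1])
implausible = transe_score(entity_emb[0], relation_emb[0], entity_emb[2])
print(plausible > implausible)
```

In a real subclass the embeddings would be learned tensors and the score would be computed for a whole batch of index tensors at once.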
- loss(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) → Tensor
Returns the loss value for the given triplet.
- Parameters:
head_index (torch.Tensor) – The head indices.
rel_type (torch.Tensor) – The relation type.
tail_index (torch.Tensor) – The tail indices.
- Return type:
Tensor
- loader(head_index: Tensor, rel_type: Tensor, tail_index: Tensor, **kwargs) → Tensor
Returns a mini-batch loader that samples a subset of triplets.
- Parameters:
head_index (torch.Tensor) – The head indices.
rel_type (torch.Tensor) – The relation type.
tail_index (torch.Tensor) – The tail indices.
**kwargs (optional) – Additional arguments of torch.utils.data.DataLoader, such as batch_size, shuffle, drop_last or num_workers.
- Return type:
Tensor
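The idea of sampling mini-batches of triplets can be sketched without any dependencies. The generator below is only a conceptual stand-in for the loader: it uses plain lists instead of tensors, and the name `triplet_batches` and its `seed` argument are hypothetical. It keeps the three index sequences aligned so that position `i` in each batch still describes one triplet.

```python
import random


def triplet_batches(head_index, rel_type, tail_index, batch_size,
                    shuffle=False, seed=None):
    """Yield aligned (head, rel, tail) mini-batches from three index lists."""
    order = list(range(len(head_index)))
    if shuffle:
        # Shuffle positions, not values, so triplets stay intact.
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield (
            [head_index[i] for i in idx],
            [rel_type[i] for i in idx],
            [tail_index[i] for i in idx],
        )


batches = list(triplet_batches([0, 1, 2, 3, 4], [0, 0, 1, 1, 0],
                               [1, 2, 3, 4, 0], batch_size=2))
for heads, rels, tails in batches:
    print(heads, rels, tails)
```

The last batch is smaller when the number of triplets is not a multiple of `batch_size`, which is the case that `drop_last` controls in `torch.utils.data.DataLoader`.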
- test(head_index: Tensor, rel_type: Tensor, tail_index: Tensor, batch_size: int, k: int = 10, log: bool = True) → Tuple[float, float, float]
Evaluates the model quality by computing Mean Rank, MRR and Hits@\(k\) across all possible tail entities.
- Parameters:
head_index (torch.Tensor) – The head indices.
rel_type (torch.Tensor) – The relation type.
tail_index (torch.Tensor) – The tail indices.
batch_size (int) – The batch size to use for evaluating.
k (int, optional) – The \(k\) in Hits@\(k\). (default: 10)
log (bool, optional) – If set to False, will not print a progress bar to the console. (default: True)
- Return type:
Tuple[float, float, float]
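To make the three metrics concrete, here is a small pure-Python sketch that ranks the true tail among all candidate tails and then aggregates the ranks. It assumes 1-based ranks and breaks ties in favor of the true tail; the library's own conventions (for example, whether ranks start at 0 or 1) may differ. The function names are hypothetical.

```python
def rank_of_true_tail(scores, true_tail):
    """Return the 1-based rank of `true_tail` given one score per candidate tail.

    Higher scores are better; ties are resolved in favor of the true tail.
    """
    true_score = scores[true_tail]
    return 1 + sum(1 for s in scores if s > true_score)


def ranking_metrics(ranks, k=10):
    """Aggregate 1-based ranks into (mean rank, MRR, Hits@k)."""
    n = len(ranks)
    mean_rank = sum(ranks) / n
    mrr = sum(1.0 / r for r in ranks) / n          # mean reciprocal rank
    hits_at_k = sum(1 for r in ranks if r <= k) / n  # fraction ranked in top k
    return mean_rank, mrr, hits_at_k


# One test triplet: three candidate tails, the true tail is entity 2.
example_rank = rank_of_true_tail([0.1, 0.9, 0.5], true_tail=2)

# Aggregate over three hypothetical test triplets.
mean_rank, mrr, hits_at_2 = ranking_metrics([1, 2, 4], k=2)
print(example_rank, mean_rank, mrr, hits_at_2)
```

Lower Mean Rank is better, while higher MRR and Hits@\(k\) are better.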
- random_sample(head_index: Tensor, rel_type: Tensor, tail_index: Tensor) → Tuple[Tensor, Tensor, Tensor]
Randomly samples negative triplets by either replacing the head or the tail (but not both).
- Parameters:
head_index (torch.Tensor) – The head indices.
rel_type (torch.Tensor) – The relation type.
tail_index (torch.Tensor) – The tail indices.
- Return type:
Tuple[Tensor, Tensor, Tensor]
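The corruption scheme described above can be sketched in plain Python. This is not the library's implementation: the per-triplet coin flip, the function name `corrupt_triplets`, and the `num_nodes` and `seed` arguments are assumptions for illustration. Each negative triplet replaces either the head or the tail with a random entity and always keeps the relation.

```python
import random


def corrupt_triplets(head_index, rel_type, tail_index, num_nodes, seed=None):
    """Build one negative triplet per input triplet.

    For each triplet, either the head or the tail (never both) is replaced
    by an entity drawn uniformly from range(num_nodes).
    """
    rng = random.Random(seed)
    neg_head, neg_tail = list(head_index), list(tail_index)
    for i in range(len(neg_head)):
        if rng.random() < 0.5:
            neg_head[i] = rng.randrange(num_nodes)  # corrupt the head
        else:
            neg_tail[i] = rng.randrange(num_nodes)  # corrupt the tail
    return neg_head, list(rel_type), neg_tail


heads, rels, tails = [0, 1, 2, 3], [0, 1, 0, 1], [1, 2, 3, 0]
neg_heads, neg_rels, neg_tails = corrupt_triplets(heads, rels, tails,
                                                  num_nodes=5, seed=0)
print(list(zip(neg_heads, neg_rels, neg_tails)))
```

Note that this sketch does no filtering: a sampled entity may equal the one it replaces, or the corrupted triplet may happen to be a true fact in the graph.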