from typing import Tuple
import torch
from torch import Tensor
from torch.nn import Embedding
from tqdm import tqdm
from torch_geometric.nn.kge.loader import KGTripletLoader


class KGEModel(torch.nn.Module):
    r"""An abstract base class for implementing custom KGE models.

    Args:
        num_nodes (int): The number of nodes/entities in the graph.
        num_relations (int): The number of relations in the graph.
        hidden_channels (int): The hidden embedding size.
        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. the
            embedding matrices will be sparse. (default: :obj:`False`)
    """
    def __init__(
        self,
        num_nodes: int,
        num_relations: int,
        hidden_channels: int,
        sparse: bool = False,
    ):
        super().__init__()

        self.num_nodes = num_nodes
        self.num_relations = num_relations
        self.hidden_channels = hidden_channels

        self.node_emb = Embedding(num_nodes, hidden_channels, sparse=sparse)
        self.rel_emb = Embedding(num_relations, hidden_channels,
                                 sparse=sparse)

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.node_emb.reset_parameters()
        self.rel_emb.reset_parameters()

    def forward(
        self,
        head_index: Tensor,
        rel_type: Tensor,
        tail_index: Tensor,
    ) -> Tensor:
        r"""Returns the score for the given triplet.

        Args:
            head_index (torch.Tensor): The head indices.
            rel_type (torch.Tensor): The relation type.
            tail_index (torch.Tensor): The tail indices.
        """
        raise NotImplementedError

    def loss(
        self,
        head_index: Tensor,
        rel_type: Tensor,
        tail_index: Tensor,
    ) -> Tensor:
        r"""Returns the loss value for the given triplet.

        Args:
            head_index (torch.Tensor): The head indices.
            rel_type (torch.Tensor): The relation type.
            tail_index (torch.Tensor): The tail indices.
        """
        raise NotImplementedError

    def loader(
        self,
        head_index: Tensor,
        rel_type: Tensor,
        tail_index: Tensor,
        **kwargs,
    ) -> KGTripletLoader:
        r"""Returns a mini-batch loader that samples a subset of triplets.

        Args:
            head_index (torch.Tensor): The head indices.
            rel_type (torch.Tensor): The relation type.
            tail_index (torch.Tensor): The tail indices.
            **kwargs (optional): Additional arguments of
                :class:`torch.utils.data.DataLoader`, such as
                :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last`
                or :obj:`num_workers`.
        """
        return KGTripletLoader(head_index, rel_type, tail_index, **kwargs)

    @torch.no_grad()
    def test(
        self,
        head_index: Tensor,
        rel_type: Tensor,
        tail_index: Tensor,
        batch_size: int,
        k: int = 10,
        log: bool = True,
    ) -> Tuple[float, float, float]:
        r"""Evaluates the model quality by computing Mean Rank, MRR and
        Hits@:math:`k` across all possible tail entities.

        Args:
            head_index (torch.Tensor): The head indices.
            rel_type (torch.Tensor): The relation type.
            tail_index (torch.Tensor): The tail indices.
            batch_size (int): The batch size to use for evaluating.
            k (int, optional): The :math:`k` in Hits @ :math:`k`.
                (default: :obj:`10`)
            log (bool, optional): If set to :obj:`False`, will not print a
                progress bar to the console. (default: :obj:`True`)
        """
        arange = range(head_index.numel())
        arange = tqdm(arange) if log else arange

        mean_ranks, reciprocal_ranks, hits_at_k = [], [], []
        for i in arange:
            h, r, t = head_index[i], rel_type[i], tail_index[i]

            # Score the triplet (h, r, t') for every candidate tail t':
            scores = []
            tail_indices = torch.arange(self.num_nodes, device=t.device)
            for ts in tail_indices.split(batch_size):
                scores.append(self(h.expand_as(ts), r.expand_as(ts), ts))

            # Zero-based position of the true tail in the descending ranking:
            rank = int((torch.cat(scores).argsort(
                descending=True) == t).nonzero().view(-1))
            mean_ranks.append(rank + 1)  # Mean Rank uses one-based ranks.
            reciprocal_ranks.append(1 / (rank + 1))
            hits_at_k.append(rank < k)

        mean_rank = float(torch.tensor(mean_ranks, dtype=torch.float).mean())
        mrr = float(torch.tensor(reciprocal_ranks, dtype=torch.float).mean())
        hits_at_k = int(torch.tensor(hits_at_k).sum()) / len(hits_at_k)

        return mean_rank, mrr, hits_at_k

    @torch.no_grad()
    def random_sample(
        self,
        head_index: Tensor,
        rel_type: Tensor,
        tail_index: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Randomly samples negative triplets by either replacing the head or
        the tail (but not both).

        Args:
            head_index (torch.Tensor): The head indices.
            rel_type (torch.Tensor): The relation type.
            tail_index (torch.Tensor): The tail indices.
        """
        # Random sample either `head_index` or `tail_index` (but not both):
        num_negatives = head_index.numel() // 2
        rnd_index = torch.randint(self.num_nodes, head_index.size(),
                                  device=head_index.device)

        head_index = head_index.clone()
        head_index[:num_negatives] = rnd_index[:num_negatives]
        tail_index = tail_index.clone()
        tail_index[num_negatives:] = rnd_index[num_negatives:]

        return head_index, rel_type, tail_index

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.num_nodes}, '
                f'num_relations={self.num_relations}, '
                f'hidden_channels={self.hidden_channels})')
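

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal subclass and
# training loop illustrating how `forward`, `loss`, `loader`, `random_sample`
# and `test` fit together.  `ToyDistanceModel`, its TransE-style distance
# score, the margin of 1.0 and all sizes below are illustrative assumptions,
# not the reference implementation of any particular KGE model.
if __name__ == '__main__':
    import torch.nn.functional as F

    class ToyDistanceModel(KGEModel):
        def forward(self, head_index, rel_type, tail_index):
            head = self.node_emb(head_index)
            rel = self.rel_emb(rel_type)
            tail = self.node_emb(tail_index)
            # Higher scores should mean more plausible triplets, so negate
            # the translational distance:
            return -(head + rel - tail).norm(p=1, dim=-1)

        def loss(self, head_index, rel_type, tail_index):
            pos_score = self(head_index, rel_type, tail_index)
            neg_score = self(*self.random_sample(head_index, rel_type,
                                                 tail_index))
            return F.margin_ranking_loss(pos_score, neg_score,
                                         target=torch.ones_like(pos_score),
                                         margin=1.0)

    # Tiny random graph: 100 entities, 5 relation types, 1000 triplets.
    head = torch.randint(100, (1000, ))
    rel = torch.randint(5, (1000, ))
    tail = torch.randint(100, (1000, ))

    model = ToyDistanceModel(num_nodes=100, num_relations=5,
                             hidden_channels=32)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # The loader yields (head_index, rel_type, tail_index) mini-batches:
    loader = model.loader(head, rel, tail, batch_size=256, shuffle=True)
    for epoch in range(5):
        for h, r, t in loader:
            optimizer.zero_grad()
            loss = model.loss(h, r, t)
            loss.backward()
            optimizer.step()

    mean_rank, mrr, hits = model.test(head, rel, tail, batch_size=512, k=10)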