Source code for torch_geometric.metrics.link_pred

from typing import Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.utils import cumsum, scatter

try:
    import torchmetrics  # noqa
    WITH_TORCHMETRICS = True
    BaseMetric = torchmetrics.Metric
except Exception:
    WITH_TORCHMETRICS = False
    BaseMetric = torch.nn.Module  # type: ignore


class LinkPredMetric(BaseMetric):
    r"""An abstract class for computing link prediction retrieval metrics.

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
    """
    is_differentiable: bool = False
    full_state_update: bool = False
    higher_is_better: Optional[bool] = None

    def __init__(self, k: int) -> None:
        super().__init__()

        if k <= 0:
            raise ValueError(f"'k' needs to be a positive integer in "
                             f"'{self.__class__.__name__}' (got {k})")

        self.k = k

        self.accum: Tensor
        self.total: Tensor

        if WITH_TORCHMETRICS:
            self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
            self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
        else:
            self.register_buffer('accum', torch.tensor(0.))
            self.register_buffer('total', torch.tensor(0))

    def update(
        self,
        pred_index_mat: Tensor,
        edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
    ) -> None:
        r"""Updates the state variables based on the current mini-batch
        prediction.

        :meth:`update` can be repeated multiple times to accumulate the results
        of successive predictions, *e.g.*, inside a mini-batch training or
        evaluation loop.

        Args:
            pred_index_mat (torch.Tensor): The top-:math:`k` predictions of
                every example in the mini-batch with shape
                :obj:`[batch_size, k]`.
            edge_label_index (torch.Tensor): The ground-truth indices for every
                example in the mini-batch, given in COO format of shape
                :obj:`[2, num_ground_truth_indices]`.
        """
        if pred_index_mat.size(1) != self.k:
            raise ValueError(f"Expected 'pred_index_mat' to hold {self.k} "
                             f"many indices for every entry "
                             f"(got {pred_index_mat.size(1)})")

        # Compute a boolean matrix indicating whether each of the top-k
        # predictions is part of the ground-truth. We do this by flattening
        # both prediction and target indices, and then determining overlaps
        # via `torch.isin`.
        max_index = max(  # type: ignore
            pred_index_mat.max() if pred_index_mat.numel() > 0 else 0,
            edge_label_index[1].max()
            if edge_label_index[1].numel() > 0 else 0,
        ) + 1
        arange = torch.arange(
            start=0,
            end=max_index * pred_index_mat.size(0),  # type: ignore
            step=max_index,  # type: ignore
            device=pred_index_mat.device,
        ).view(-1, 1)
        flat_pred_index = (pred_index_mat + arange).view(-1)
        flat_y_index = max_index * edge_label_index[0] + edge_label_index[1]

        pred_isin_mat = torch.isin(flat_pred_index, flat_y_index)
        pred_isin_mat = pred_isin_mat.view(pred_index_mat.size())
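        # For example (hypothetical values): with `max_index = 10`, row 0 of
        # `pred_index_mat = [[1, 3], [1, 2]]` keeps its indices {1, 3} while
        # row 1 is shifted to {11, 12}, so equal indices from different
        # examples cannot collide in the flattened `torch.isin` lookup.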

        # Compute the number of targets per example:
        y_count = scatter(
            torch.ones_like(edge_label_index[0]),
            edge_label_index[0],
            dim=0,
            dim_size=pred_index_mat.size(0),
            reduce='sum',
        )
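        # E.g. (hypothetical values), `edge_label_index[0] = [0, 0, 2]` with a
        # batch size of 3 yields `y_count = [2, 0, 1]`.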

        metric = self._compute(pred_isin_mat, y_count)

        self.accum += metric.sum()
        self.total += (y_count > 0).sum()

    def compute(self) -> Tensor:
        r"""Computes the final metric value."""
        if self.total == 0:
            return torch.zeros_like(self.accum)
        return self.accum / self.total

    def reset(self) -> None:
        r"""Reset metric state variables to their default value."""
        if WITH_TORCHMETRICS:
            super().reset()
        else:
            self.accum.zero_()
            self.total.zero_()

    def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
        r"""Compute the specific metric.
        To be implemented separately for each metric class.

        Args:
            pred_isin_mat (torch.Tensor): A boolean matrix whose :obj:`(i,k)`
                element indicates if the :obj:`k`-th prediction for the
                :obj:`i`-th example is correct or not.
            y_count (torch.Tensor): A vector indicating the number of
                ground-truth labels for each example.
        """
        raise NotImplementedError

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(k={self.k})'


class LinkPredPrecision(LinkPredMetric):
    r"""A link prediction metric to compute Precision @ :math:`k`.

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
    """
    higher_is_better: bool = True

    def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
        return pred_isin_mat.sum(dim=-1) / self.k
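
# A minimal usage sketch (hypothetical values, not part of the library):
#
#   metric = LinkPredPrecision(k=2)
#   pred_index_mat = torch.tensor([[0, 1], [2, 3]])  # top-2 per example
#   edge_label_index = torch.tensor([[0, 0, 1], [0, 2, 3]])  # COO ground-truth
#   metric.update(pred_index_mat, edge_label_index)
#   metric.compute()  # (1/2 + 1/2) / 2 = 0.5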


class LinkPredRecall(LinkPredMetric):
    r"""A link prediction metric to compute Recall @ :math:`k`.

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
    """
    higher_is_better: bool = True

    def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
        return pred_isin_mat.sum(dim=-1) / y_count.clamp(min=1e-7)
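
# With the same hypothetical inputs as above, Recall@2 is 1/2 for example 0
# (one of its two ground-truth labels was retrieved) and 1/1 for example 1,
# giving (1/2 + 1) / 2 = 0.75 after `compute()`.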


class LinkPredF1(LinkPredMetric):
    r"""A link prediction metric to compute F1 @ :math:`k`.

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
    """
    higher_is_better: bool = True

    def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
        isin_count = pred_isin_mat.sum(dim=-1)
        precision = isin_count / self.k
        recall = isin_count / y_count.clamp(min=1e-7)
        return 2 * precision * recall / (precision + recall).clamp(min=1e-7)
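
# Continuing the hypothetical example: example 1 has precision 1/2 and
# recall 1, so F1@2 = 2 * (1/2 * 1) / (1/2 + 1) = 2/3.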


class LinkPredMAP(LinkPredMetric):
    r"""A link prediction metric to compute MAP @ :math:`k` (Mean Average
    Precision).

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
    """
    higher_is_better: bool = True

    def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
        cum_precision = (torch.cumsum(pred_isin_mat, dim=1) /
                         torch.arange(1, self.k + 1, device=y_count.device))
        return ((cum_precision * pred_isin_mat).sum(dim=-1) /
                y_count.clamp(min=1e-7, max=self.k))
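
# A hypothetical single-example check: for `pred_isin_mat = [[True, False,
# True]]` (k=3) and `y_count = [2]`, the cumulative precisions at the two
# hits are 1/1 and 2/3, so MAP@3 = (1 + 2/3) / 2 = 5/6.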


class LinkPredNDCG(LinkPredMetric):
    r"""A link prediction metric to compute the NDCG @ :math:`k` (Normalized
    Discounted Cumulative Gain).

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
    """
    higher_is_better: bool = True

    def __init__(self, k: int):
        super().__init__(k=k)

        dtype = torch.get_default_dtype()
        multiplier = 1.0 / torch.arange(2, k + 2, dtype=dtype).log2()

        self.multiplier: Tensor
        self.register_buffer('multiplier', multiplier)

        self.idcg: Tensor
        self.register_buffer('idcg', cumsum(multiplier))

    def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
        dcg = (pred_isin_mat * self.multiplier.view(1, -1)).sum(dim=-1)
        idcg = self.idcg[y_count.clamp(max=self.k)]

        out = dcg / idcg
        out[out.isnan() | out.isinf()] = 0.0
        return out
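
# A hypothetical sanity check for k=2: `multiplier = [1/log2(2), 1/log2(3)]
# = [1.0, ~0.631]` and, with the leading zero prepended by
# `torch_geometric.utils.cumsum`, `idcg = [0.0, 1.0, ~1.631]`. For
# `pred_isin_mat = [[False, True]]` and `y_count = [1]`, DCG = ~0.631 and
# IDCG = idcg[1] = 1.0, so NDCG@2 is roughly 0.631.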


class LinkPredMRR(LinkPredMetric):
    r"""A link prediction metric to compute the MRR @ :math:`k` (Mean
    Reciprocal Rank).

    Args:
        k (int): The number of top-:math:`k` predictions to evaluate against.
    """
    higher_is_better: bool = True

    def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
        # `argmax` over the boolean matrix returns the position of the first
        # `True` entry (or 0 if the row contains no hit, masked out below):
        rank = pred_isin_mat.type(torch.uint8).argmax(dim=-1)
        is_correct = pred_isin_mat.gather(1, rank.view(-1, 1)).view(-1)
        reciprocals = 1.0 / (rank + 1)
        reciprocals[~is_correct] = 0.0
        return reciprocals
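
# A hypothetical sanity check: for `pred_isin_mat = [[False, True], [False,
# False]]`, the first hit of example 0 is at rank 2 (reciprocal 1/2), while
# example 1 has no hit within the top-k, so its reciprocal rank is 0.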