from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch_geometric.utils import cumsum, scatter
try:
import torchmetrics # noqa
WITH_TORCHMETRICS = True
BaseMetric = torchmetrics.Metric
except Exception:
WITH_TORCHMETRICS = False
BaseMetric = torch.nn.Module # type: ignore
@dataclass(repr=False)
class LinkPredMetricData:
pred_index_mat: Tensor
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]]
edge_label_weight: Optional[Tensor] = None
@property
def pred_rel_mat(self) -> Tensor:
r"""Returns a matrix indicating the relevance of the `k`-th prediction.
If :obj:`edge_label_weight` is not given, relevance will be denoted as
binary.
"""
if hasattr(self, '_pred_rel_mat'):
return self._pred_rel_mat # type: ignore
if self.edge_label_index[1].numel() == 0:
self._pred_rel_mat = torch.zeros_like(
self.pred_index_mat,
dtype=torch.bool if self.edge_label_weight is None else
torch.get_default_dtype(),
)
return self._pred_rel_mat
# Flatten both prediction and ground-truth indices, and determine
# overlaps afterwards via `torch.searchsorted`.
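# Example: with `max_index = 10`, a prediction `j` in row `i` is mapped to
# the unique key `i * 10 + j`, and a ground-truth pair `(i, j)` to the same
# key, so equal keys imply a match of both example and item index.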
max_index = max( # type: ignore
self.pred_index_mat.max()
if self.pred_index_mat.numel() > 0 else 0,
self.edge_label_index[1].max()
if self.edge_label_index[1].numel() > 0 else 0,
) + 1
arange = torch.arange(
start=0,
end=max_index * self.pred_index_mat.size(0), # type: ignore
step=max_index, # type: ignore
device=self.pred_index_mat.device,
).view(-1, 1)
flat_pred_index = (self.pred_index_mat + arange).view(-1)
flat_label_index = max_index * self.edge_label_index[0]
flat_label_index = flat_label_index + self.edge_label_index[1]
flat_label_index, perm = flat_label_index.sort()
edge_label_weight = self.edge_label_weight
if edge_label_weight is not None:
assert edge_label_weight.size() == self.edge_label_index[0].size()
edge_label_weight = edge_label_weight[perm]
pos = torch.searchsorted(flat_label_index, flat_pred_index)
pos = pos.clamp(max=flat_label_index.size(0) - 1) # Out-of-bounds.
pred_rel_mat = flat_label_index[pos] == flat_pred_index # Find matches
if edge_label_weight is not None:
pred_rel_mat = edge_label_weight[pos].where(
pred_rel_mat,
pred_rel_mat.new_zeros(1),
)
pred_rel_mat = pred_rel_mat.view(self.pred_index_mat.size())
self._pred_rel_mat = pred_rel_mat
return pred_rel_mat
@property
def label_count(self) -> Tensor:
r"""The number of ground-truth labels for every example."""
if hasattr(self, '_label_count'):
return self._label_count # type: ignore
label_count = scatter(
torch.ones_like(self.edge_label_index[0]),
self.edge_label_index[0],
dim=0,
dim_size=self.pred_index_mat.size(0),
reduce='sum',
)
self._label_count = label_count
return label_count
@property
def label_weight_sum(self) -> Tensor:
r"""The sum of edge label weights for every example."""
if self.edge_label_weight is None:
return self.label_count
if hasattr(self, '_label_weight_sum'):
return self._label_weight_sum # type: ignore
label_weight_sum = scatter(
self.edge_label_weight,
self.edge_label_index[0],
dim=0,
dim_size=self.pred_index_mat.size(0),
reduce='sum',
)
self._label_weight_sum = label_weight_sum
return label_weight_sum
@property
def edge_label_weight_pos(self) -> Optional[Tensor]:
r"""Returns the position of edge label weights in descending order
within example-wise buckets.
"""
if self.edge_label_weight is None:
return None
if hasattr(self, '_edge_label_weight_pos'):
return self._edge_label_weight_pos # type: ignore
# Get the permutation via two sorts: One globally on the weights,
# followed by a (stable) sort on the example indices.
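# Example: weights [0.2, 0.9, 0.5] with example indices [0, 0, 1] yield
# `perm1 = [1, 2, 0]` and, after the stable sort on example indices,
# `perm = [1, 0, 2]`, i.e. within example 0 the weight 0.9 precedes 0.2.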
perm1 = self.edge_label_weight.argsort(descending=True)
perm2 = self.edge_label_index[0][perm1].argsort(stable=True)
perm = perm1[perm2]
# Invert the permutation to get the final position:
pos = torch.empty_like(perm)
pos[perm] = torch.arange(perm.size(0), device=perm.device)
# Normalize position to zero within all buckets:
pos = pos - cumsum(self.label_count)[self.edge_label_index[0]]
self._edge_label_weight_pos = pos
return pos
class _LinkPredMetric(BaseMetric):
r"""An abstract class for computing link prediction retrieval metrics.
Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
is_differentiable: bool = False
full_state_update: bool = False
higher_is_better: Optional[bool] = None
def __init__(self, k: int) -> None:
super().__init__()
if k <= 0:
raise ValueError(f"'k' needs to be a positive integer in "
f"'{self.__class__.__name__}' (got {k})")
self.k = k
def update(
self,
pred_index_mat: Tensor,
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
edge_label_weight: Optional[Tensor] = None,
) -> None:
r"""Updates the state variables based on the current mini-batch
prediction.
:meth:`update` can be repeated multiple times to accumulate the results
of successive predictions, *e.g.*, inside a mini-batch training or
evaluation loop.
Args:
pred_index_mat (torch.Tensor): The top-:math:`k` predictions of
every example in the mini-batch with shape
:obj:`[batch_size, k]`.
edge_label_index (torch.Tensor): The ground-truth indices for every
example in the mini-batch, given in COO format of shape
:obj:`[2, num_ground_truth_indices]`.
edge_label_weight (torch.Tensor, optional): The weight of the
ground-truth indices for every example in the mini-batch of
shape :obj:`[num_ground_truth_indices]`. If given, needs to be
a vector of positive values. Required for weighted metrics,
ignored otherwise. (default: :obj:`None`)
"""
raise NotImplementedError
def compute(self) -> Tensor:
r"""Computes the final metric value."""
raise NotImplementedError
def reset(self) -> None:
r"""Resets metric state variables to their default value."""
if WITH_TORCHMETRICS:
super().reset()
else:
self._reset()
def _reset(self) -> None:
raise NotImplementedError
def __repr__(self) -> str:
return f'{self.__class__.__name__}(k={self.k})'
class LinkPredMetric(_LinkPredMetric):
r"""An abstract class for computing link prediction retrieval metrics.
Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
weighted: bool
def __init__(self, k: int) -> None:
super().__init__(k)
self.accum: Tensor
self.total: Tensor
if WITH_TORCHMETRICS:
self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
else:
self.register_buffer('accum', torch.tensor(0.), persistent=False)
self.register_buffer('total', torch.tensor(0), persistent=False)
def update(
self,
pred_index_mat: Tensor,
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
edge_label_weight: Optional[Tensor] = None,
) -> None:
if self.weighted and edge_label_weight is None:
raise ValueError(f"'edge_label_weight' is a required argument for "
f"weighted '{self.__class__.__name__}' metrics")
if not self.weighted:
edge_label_weight = None
data = LinkPredMetricData(
pred_index_mat=pred_index_mat,
edge_label_index=edge_label_index,
edge_label_weight=edge_label_weight,
)
self._update(data)
def _update(self, data: LinkPredMetricData) -> None:
metric = self._compute(data)
self.accum += metric.sum()
self.total += (data.label_count > 0).sum()
def compute(self) -> Tensor:
if self.total == 0:
return torch.zeros_like(self.accum)
return self.accum / self.total
def _compute(self, data: LinkPredMetricData) -> Tensor:
r"""Computes the specific metric.
To be implemented separately for each metric class.
Args:
data (LinkPredMetricData): The mini-batch data for computing a link
prediction metric per example.
"""
raise NotImplementedError
def _reset(self) -> None:
self.accum.zero_()
self.total.zero_()
def __repr__(self) -> str:
weighted_repr = ', weighted=True' if self.weighted else ''
return f'{self.__class__.__name__}(k={self.k}{weighted_repr})'
class LinkPredMetricCollection(torch.nn.ModuleDict):
r"""A collection of metrics to reduce and speed-up computation of link
prediction metrics.
.. code-block:: python
from torch_geometric.metrics import (
LinkPredMAP,
LinkPredMetricCollection,
LinkPredPrecision,
LinkPredRecall,
)
metrics = LinkPredMetricCollection([
LinkPredMAP(k=10),
LinkPredPrecision(k=100),
LinkPredRecall(k=50),
])
metrics.update(pred_index_mat, edge_label_index)
out = metrics.compute()
metrics.reset()
print(out)
>>> {'LinkPredMAP@10': tensor(0.375),
... 'LinkPredPrecision@100': tensor(0.127),
... 'LinkPredRecall@50': tensor(0.483)}
Args:
metrics: The link prediction metrics.
"""
def __init__(
self,
metrics: Union[
List[LinkPredMetric],
Dict[str, LinkPredMetric],
],
) -> None:
super().__init__()
if isinstance(metrics, (list, tuple)):
metrics = {
(f'{"Weighted" if getattr(metric, "weighted", False) else ""}'
f'{metric.__class__.__name__}@{metric.k}'):
metric
for metric in metrics
}
assert len(metrics) > 0
assert isinstance(metrics, dict)
for name, metric in metrics.items():
assert isinstance(metric, _LinkPredMetric)
self[name] = metric
@property
def max_k(self) -> int:
r"""The maximum number of top-:math:`k` predictions to evaluate
against.
"""
return max([metric.k for metric in self.values()])
@property
def weighted(self) -> bool:
r"""Returns :obj:`True` in case the collection holds at least one
weighted link prediction metric.
"""
return any(
[getattr(metric, 'weighted', False) for metric in self.values()])
def update(  # type: ignore
self,
pred_index_mat: Tensor,
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
edge_label_weight: Optional[Tensor] = None,
) -> None:
r"""Updates the state variables based on the current mini-batch
prediction.
:meth:`update` can be repeated multiple times to accumulate the results
of successive predictions, *e.g.*, inside a mini-batch training or
evaluation loop.
Args:
pred_index_mat (torch.Tensor): The top-:math:`k` predictions of
every example in the mini-batch with shape
:obj:`[batch_size, k]`.
edge_label_index (torch.Tensor): The ground-truth indices for every
example in the mini-batch, given in COO format of shape
:obj:`[2, num_ground_truth_indices]`.
edge_label_weight (torch.Tensor, optional): The weight of the
ground-truth indices for every example in the mini-batch of
shape :obj:`[num_ground_truth_indices]`. If given, needs to be
a vector of positive values. Required for weighted metrics,
ignored otherwise. (default: :obj:`None`)
"""
if self.weighted and edge_label_weight is None:
raise ValueError(f"'edge_label_weight' is a required argument for "
f"weighted '{self.__class__.__name__}' metrics")
if not self.weighted:
edge_label_weight = None
data = LinkPredMetricData( # Share metric data across metrics.
pred_index_mat=pred_index_mat,
edge_label_index=edge_label_index,
edge_label_weight=edge_label_weight,
)
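# First update all weighted metrics (which require `edge_label_weight`),
# then drop the weights and convert any cached relevance state back to its
# binary form so that unweighted metrics can reuse the shared data: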
for metric in self.values():
if isinstance(metric, LinkPredMetric) and metric.weighted:
metric._update(data)
if WITH_TORCHMETRICS:
metric._update_count += 1
data.edge_label_weight = None
if hasattr(data, '_pred_rel_mat'):
data._pred_rel_mat = data._pred_rel_mat != 0.0
if hasattr(data, '_label_weight_sum'):
del data._label_weight_sum
if hasattr(data, '_edge_label_weight_pos'):
del data._edge_label_weight_pos
for metric in self.values():
if isinstance(metric, LinkPredMetric) and not metric.weighted:
metric._update(data)
if WITH_TORCHMETRICS:
metric._update_count += 1
for metric in self.values():
if not isinstance(metric, LinkPredMetric):
metric.update(pred_index_mat, edge_label_index,
edge_label_weight)
def compute(self) -> Dict[str, Tensor]:
r"""Computes the final metric values."""
return {name: metric.compute() for name, metric in self.items()}
def reset(self) -> None:
r"""Reset metric state variables to their default value."""
for metric in self.values():
metric.reset()
def __repr__(self) -> str:
names = [f' {name}: {metric},\n' for name, metric in self.items()]
return f'{self.__class__.__name__}([\n{"".join(names)}])'
class LinkPredPrecision(LinkPredMetric):
r"""A link prediction metric to compute Precision @ :math:`k`, *i.e.* the
proportion of recommendations within the top-:math:`k` that are actually
relevant.
A higher precision indicates the model's ability to surface relevant items
early in the ranking.
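A minimal usage sketch (with made-up tensors; two users, :obj:`k = 2`):
.. code-block:: python
import torch
metric = LinkPredPrecision(k=2)
pred_index_mat = torch.tensor([[3, 1], [2, 0]])  # Top-2 items per user.
edge_label_index = torch.tensor([[0, 0, 1], [1, 4, 5]])  # (user, item) pairs.
metric.update(pred_index_mat, edge_label_index)
metric.compute()  # 0.5 for user 0, 0.0 for user 1 -> mean 0.25.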
Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
higher_is_better: bool = True
weighted: bool = False
def _compute(self, data: LinkPredMetricData) -> Tensor:
pred_rel_mat = data.pred_rel_mat[:, :self.k]
return pred_rel_mat.sum(dim=-1) / self.k
class LinkPredRecall(LinkPredMetric):
r"""A link prediction metric to compute Recall @ :math:`k`, *i.e.* the
proportion of relevant items that appear within the top-:math:`k`.
A higher recall indicates the model's ability to retrieve a larger
proportion of relevant items.
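A minimal usage sketch (with made-up tensors; the weighted variant
additionally expects :obj:`edge_label_weight`):
.. code-block:: python
import torch
metric = LinkPredRecall(k=2, weighted=True)
pred_index_mat = torch.tensor([[3, 1]])  # Top-2 items of user 0.
edge_label_index = torch.tensor([[0, 0], [1, 4]])  # User 0 likes items 1 and 4.
edge_label_weight = torch.tensor([2.0, 1.0])  # Item 1 is twice as relevant.
metric.update(pred_index_mat, edge_label_index, edge_label_weight)
metric.compute()  # Retrieved weight 2.0 out of 3.0 -> ~0.67.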
Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
higher_is_better: bool = True
def __init__(self, k: int, weighted: bool = False):
super().__init__(k=k)
self.weighted = weighted
def _compute(self, data: LinkPredMetricData) -> Tensor:
pred_rel_mat = data.pred_rel_mat[:, :self.k]
return pred_rel_mat.sum(dim=-1) / data.label_weight_sum.clamp(min=1e-7)
class LinkPredF1(LinkPredMetric):
r"""A link prediction metric to compute F1 @ :math:`k`.
Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
higher_is_better: bool = True
weighted: bool = False
def _compute(self, data: LinkPredMetricData) -> Tensor:
pred_rel_mat = data.pred_rel_mat[:, :self.k]
isin_count = pred_rel_mat.sum(dim=-1)
precision = isin_count / self.k
recall = isin_count / data.label_count.clamp(min=1e-7)
return 2 * precision * recall / (precision + recall).clamp(min=1e-7)
class LinkPredMAP(LinkPredMetric):
r"""A link prediction metric to compute MAP @ :math:`k` (Mean Average
Precision), considering the order of relevant items within the
top-:math:`k`.
MAP @ :math:`k` can provide a more comprehensive view of ranking quality
than precision alone.
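Per example, the average precision corresponds to
.. math::
\textrm{AP}@k = \frac{1}{\min(|\mathcal{G}|, k)} \sum_{i=1}^{k}
\textrm{Precision}@i \cdot \textrm{rel}_i,
where :math:`\textrm{rel}_i` indicates whether the :math:`i`-th prediction
is relevant and :math:`|\mathcal{G}|` denotes the number of ground-truth
items.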
Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
higher_is_better: bool = True
weighted: bool = False
def _compute(self, data: LinkPredMetricData) -> Tensor:
pred_rel_mat = data.pred_rel_mat[:, :self.k]
device = pred_rel_mat.device
arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device)
cum_precision = pred_rel_mat.cumsum(dim=1) / arange
return ((cum_precision * pred_rel_mat).sum(dim=-1) /
data.label_count.clamp(min=1e-7, max=self.k))
class LinkPredNDCG(LinkPredMetric):
r"""A link prediction metric to compute the NDCG @ :math:`k` (Normalized
Discounted Cumulative Gain).
In particular, it accounts for the position of relevant items by
considering relevance scores, giving higher weight to more relevant items
that appear at the top of the ranking.
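Per example, it is computed as
.. math::
\textrm{NDCG}@k = \frac{\textrm{DCG}@k}{\textrm{IDCG}@k}, \quad
\textrm{DCG}@k = \sum_{i=1}^{k} \frac{\textrm{rel}_i}{\log_2(i + 1)},
where :math:`\textrm{IDCG}@k` denotes the DCG of an ideal ordering of the
ground-truth items (using :obj:`edge_label_weight` as relevance in the
weighted case).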
Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
weighted (bool, optional): If set to :obj:`True`, assumes sorted lists
of ground-truth items according to a relevance score as given by
:obj:`edge_label_weight`. (default: :obj:`False`)
"""
higher_is_better: bool = True
def __init__(self, k: int, weighted: bool = False):
super().__init__(k=k)
self.weighted = weighted
dtype = torch.get_default_dtype()
discount = torch.arange(2, k + 2, dtype=dtype).log2()
self.discount: Tensor
self.register_buffer('discount', discount, persistent=False)
if not weighted:
self.register_buffer('idcg', cumsum(1.0 / discount),
persistent=False)
else:
self.idcg = None
def _compute(self, data: LinkPredMetricData) -> Tensor:
pred_rel_mat = data.pred_rel_mat[:, :self.k]
discount = self.discount[:pred_rel_mat.size(1)].view(1, -1)
dcg = (pred_rel_mat / discount).sum(dim=-1)
if not self.weighted:
assert self.idcg is not None
idcg = self.idcg[data.label_count.clamp(max=self.k)]
else:
assert data.edge_label_weight is not None
pos = data.edge_label_weight_pos
assert pos is not None
discount = torch.cat([
self.discount,
self.discount.new_full((1, ), fill_value=float('inf')),
])
discount = discount[pos.clamp(max=self.k)]
idcg = scatter( # Apply discount and aggregate:
data.edge_label_weight / discount,
data.edge_label_index[0],
dim_size=data.pred_index_mat.size(0),
reduce='sum',
)
out = dcg / idcg
out[out.isnan() | out.isinf()] = 0.0
return out
class LinkPredMRR(LinkPredMetric):
r"""A link prediction metric to compute the MRR @ :math:`k` (Mean
Reciprocal Rank), *i.e.* the mean reciprocal rank of the first correct
prediction within the top-:math:`k`, or zero if no correct prediction is
present.
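A minimal usage sketch (with made-up tensors; the first hit for the single
user below occurs at rank 2):
.. code-block:: python
import torch
metric = LinkPredMRR(k=3)
pred_index_mat = torch.tensor([[7, 4, 1]])  # Top-3 items of user 0.
edge_label_index = torch.tensor([[0, 0], [4, 1]])  # User 0 likes items 4 and 1.
metric.update(pred_index_mat, edge_label_index)
metric.compute()  # First relevant item at rank 2 -> 1/2.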
Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
higher_is_better: bool = True
weighted: bool = False
def _compute(self, data: LinkPredMetricData) -> Tensor:
pred_rel_mat = data.pred_rel_mat[:, :self.k]
device = pred_rel_mat.device
arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device)
return (pred_rel_mat / arange).max(dim=-1)[0]
class LinkPredHitRatio(LinkPredMetric):
r"""A link prediction metric to compute the hit ratio @ :math:`k`, *i.e.*
the percentage of users for whom at least one relevant item is present
within the top-:math:`k` recommendations.
A high ratio signifies the model's effectiveness in satisfying a broad
range of user preferences.
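Per example, it equals :math:`1` if any of the top-:math:`k` predictions is
relevant and :math:`0` otherwise:
.. math::
\textrm{HR}@k = \frac{1}{|\mathcal{U}|} \sum_{u \in \mathcal{U}}
\max_{1 \le i \le k} \textrm{rel}_{u,i},
where :math:`\mathcal{U}` denotes the set of users with at least one
ground-truth item.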
"""
higher_is_better: bool = True
weighted: bool = False
def _compute(self, data: LinkPredMetricData) -> Tensor:
pred_rel_mat = data.pred_rel_mat[:, :self.k]
return pred_rel_mat.max(dim=-1)[0].to(torch.get_default_dtype())
class LinkPredCoverage(_LinkPredMetric):
r"""A link prediction metric to compute the Coverage @ :math:`k` of
predictions, *i.e.* the percentage of unique items recommended across all
users within the top-:math:`k`.
Higher coverage indicates a wider exploration of the item catalog.
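A minimal usage sketch (with made-up tensors; three of a hypothetical ten
items are ever recommended):
.. code-block:: python
import torch
metric = LinkPredCoverage(k=2, num_dst_nodes=10)
pred_index_mat = torch.tensor([[3, 1], [3, 7]])  # Top-2 items per user.
edge_label_index = torch.empty((2, 0), dtype=torch.long)  # Not used here.
metric.update(pred_index_mat, edge_label_index)
metric.compute()  # Items {1, 3, 7} out of 10 -> 0.3.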
Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
num_dst_nodes (int): The total number of destination nodes.
"""
higher_is_better: bool = True
def __init__(self, k: int, num_dst_nodes: int) -> None:
super().__init__(k)
self.num_dst_nodes = num_dst_nodes
self.mask: Tensor
mask = torch.zeros(num_dst_nodes, dtype=torch.bool)
if WITH_TORCHMETRICS:
self.add_state('mask', mask, dist_reduce_fx='max')
else:
self.register_buffer('mask', mask, persistent=False)
def update(
self,
pred_index_mat: Tensor,
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
edge_label_weight: Optional[Tensor] = None,
) -> None:
self.mask[pred_index_mat[:, :self.k].flatten()] = True
def compute(self) -> Tensor:
return self.mask.to(torch.get_default_dtype()).mean()
def _reset(self) -> None:
self.mask.zero_()
def __repr__(self) -> str:
return (f'{self.__class__.__name__}(k={self.k}, '
f'num_dst_nodes={self.num_dst_nodes})')
class LinkPredDiversity(_LinkPredMetric):
r"""A link prediction metric to compute the Diversity @ :math:`k` of
predictions according to item categories.
Diversity is computed as
.. math::
div_{u@k} = 1 - \left( \frac{1}{k \cdot (k-1)} \right) \sum_{i \neq j}
sim(i, j)
where
.. math::
sim(i,j) = \begin{cases}
1 & \quad \text{if } i,j \text{ share category,}\\
0 & \quad \text{otherwise.}
\end{cases}
which measures the pair-wise dissimilarity of recommendations according to
item categories.
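A minimal usage sketch (with made-up tensors; items are assigned to two
hypothetical categories):
.. code-block:: python
import torch
category = torch.tensor([0, 0, 1, 1])  # Category of each of the 4 items.
metric = LinkPredDiversity(k=2, category=category)
pred_index_mat = torch.tensor([[0, 2], [0, 1]])  # Top-2 items per user.
edge_label_index = torch.empty((2, 0), dtype=torch.long)  # Not used here.
metric.update(pred_index_mat, edge_label_index)
metric.compute()  # User 0: 1.0 (distinct), user 1: 0.0 (same) -> mean 0.5.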
Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
category (torch.Tensor): A vector that assigns each destination node to
a specific category.
"""
higher_is_better: bool = True
def __init__(self, k: int, category: Tensor) -> None:
super().__init__(k)
self.accum: Tensor
self.total: Tensor
if WITH_TORCHMETRICS:
self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
else:
self.register_buffer('accum', torch.tensor(0.), persistent=False)
self.register_buffer('total', torch.tensor(0), persistent=False)
self.category: Tensor
self.register_buffer('category', category, persistent=False)
def update(
self,
pred_index_mat: Tensor,
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
edge_label_weight: Optional[Tensor] = None,
) -> None:
category = self.category[pred_index_mat[:, :self.k]]
sim = (category.unsqueeze(-2) == category.unsqueeze(-1)).sum(dim=-1)
div = 1 - 1 / (self.k * (self.k - 1)) * (sim - 1).sum(dim=-1)
self.accum += div.sum()
self.total += pred_index_mat.size(0)
def compute(self) -> Tensor:
if self.total == 0:
return torch.zeros_like(self.accum)
return self.accum / self.total
def _reset(self) -> None:
self.accum.zero_()
self.total.zero_()
class LinkPredPersonalization(_LinkPredMetric):
r"""A link prediction metric to compute the Personalization @ :math:`k`,
*i.e.* the dissimilarity of recommendations across different users.
Higher personalization suggests that the model tailors recommendations to
individual user preferences rather than providing generic results.
Dissimilarity is defined by the average inverse cosine similarity between
users' lists of recommendations.
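A minimal usage sketch (with made-up tensors; two users whose top-2 lists
share one item):
.. code-block:: python
import torch
metric = LinkPredPersonalization(k=2)
pred_index_mat = torch.tensor([[0, 1], [1, 2]])  # Top-2 items per user.
edge_label_index = torch.empty((2, 0), dtype=torch.long)  # Not used here.
metric.update(pred_index_mat, edge_label_index)
metric.compute()  # Cosine similarity 1/2 -> personalization 0.5.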
Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
max_src_nodes (int, optional): The maximum number of source nodes to
consider when computing pair-wise dissimilarity. If specified,
Personalization @ :math:`k` is approximated to avoid the computational
blow-up of its quadratic complexity. (default: :obj:`2**12`)
batch_size (int, optional): The batch size to determine how many pairs
of user recommendations should be processed at once.
(default: :obj:`2**16`)
"""
higher_is_better: bool = True
def __init__(
self,
k: int,
max_src_nodes: Optional[int] = 2**12,
batch_size: int = 2**16,
) -> None:
super().__init__(k)
self.max_src_nodes = max_src_nodes
self.batch_size = batch_size
self.preds: List[Tensor]
self.total: Tensor
if WITH_TORCHMETRICS:
self.add_state('preds', default=[], dist_reduce_fx='cat')
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
else:
self.preds = []
self.register_buffer('total', torch.tensor(0), persistent=False)
def update(
self,
pred_index_mat: Tensor,
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
edge_label_weight: Optional[Tensor] = None,
) -> None:
# NOTE Move to CPU to avoid memory blowup.
pred_index_mat = pred_index_mat[:, :self.k].cpu()
if self.max_src_nodes is None:
self.preds.append(pred_index_mat)
self.total += pred_index_mat.size(0)
elif self.total < self.max_src_nodes:
remaining = int(self.max_src_nodes - self.total)
pred_index_mat = pred_index_mat[:remaining]
self.preds.append(pred_index_mat)
self.total += pred_index_mat.size(0)
def compute(self) -> Tensor:
device = self.total.device
score = torch.tensor(0.0, device=device)
total = torch.tensor(0, device=device)
if len(self.preds) == 0:
return score
pred = torch.cat(self.preds, dim=0)
if pred.size(0) == 0:
return score
# Calculate all pairs of nodes (e.g., triu_indices with offset=1).
# NOTE We do this in chunks to avoid memory blow-up, which leads to a
# more efficient but trickier implementation.
num_pairs = (pred.size(0) * (pred.size(0) - 1)) // 2
offset = torch.arange(pred.size(0) - 1, 0, -1, device=device)
rowptr = cumsum(offset)
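# `offset[i]` counts the pairs starting at row `i`, and `rowptr` (with a
# leading zero) holds the flat pair index at which each row begins, so a
# flat pair index can be mapped back to its `(row, col)` pair below.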
for start in range(0, num_pairs, self.batch_size):
end = min(start + self.batch_size, num_pairs)
idx = torch.arange(start, end, device=device)
# Find the corresponding row:
row = torch.searchsorted(rowptr, idx, right=True) - 1
# Find the corresponding column:
col = idx - rowptr[row] + (pred.size(0) - offset[row])
left = pred[row.cpu()].to(device)
right = pred[col.cpu()].to(device)
# Use offset to work around applying `isin` along a specific dim:
i = max(left.max(), right.max()) + 1 # type: ignore
i = torch.arange(0, i * row.size(0), i, device=device).view(-1, 1)
isin = torch.isin(left + i, right + i)
# Compute personalization via average inverse cosine similarity:
cos = isin.sum(dim=-1) / pred.size(1)
score += (1 - cos).sum()
total += cos.numel()
return score / total
def _reset(self) -> None:
self.preds = []
self.total.zero_()
class LinkPredAveragePopularity(_LinkPredMetric):
r"""A link prediction metric to compute the Average Recommendation
Popularity (ARP) @ :math:`k`, which provides insights into the model's
tendency to recommend popular items by averaging the popularity scores of
items within the top-:math:`k` recommendations.
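A minimal usage sketch (with made-up tensors; popularity counts for four
hypothetical items):
.. code-block:: python
import torch
popularity = torch.tensor([100, 10, 1, 50])  # Ratings per item.
metric = LinkPredAveragePopularity(k=2, popularity=popularity)
pred_index_mat = torch.tensor([[0, 3], [1, 2]])  # Top-2 items per user.
edge_label_index = torch.empty((2, 0), dtype=torch.long)  # Not used here.
metric.update(pred_index_mat, edge_label_index)
metric.compute()  # Mean of per-user means (75.0 and 5.5) -> 40.25.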
Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
popularity (torch.Tensor): The popularity of every item in the training
set, *e.g.*, the number of times an item has been rated.
"""
higher_is_better: bool = False
def __init__(self, k: int, popularity: Tensor) -> None:
super().__init__(k)
self.accum: Tensor
self.total: Tensor
if WITH_TORCHMETRICS:
self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
else:
self.register_buffer('accum', torch.tensor(0.), persistent=False)
self.register_buffer('total', torch.tensor(0), persistent=False)
self.popularity: Tensor
self.register_buffer('popularity', popularity, persistent=False)
def update(
self,
pred_index_mat: Tensor,
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
edge_label_weight: Optional[Tensor] = None,
) -> None:
pred_index_mat = pred_index_mat[:, :self.k]
popularity = self.popularity[pred_index_mat]
popularity = popularity.to(self.accum.dtype).mean(dim=-1)
self.accum += popularity.sum()
self.total += popularity.numel()
def compute(self) -> Tensor:
if self.total == 0:
return torch.zeros_like(self.accum)
return self.accum / self.total
def _reset(self) -> None:
self.accum.zero_()
self.total.zero_()