torch_geometric.metrics.LinkPredRecall

class LinkPredRecall(k: int)

Bases: LinkPredMetric

A link prediction metric to compute Recall@\(k\), i.e., the fraction of an example's ground-truth targets that appear among its top-\(k\) predictions.

Parameters:

k (int) – The number \(k\) of top-ranked predictions per example to evaluate against.
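As an illustration of what Recall@\(k\) measures for a single example, the sketch below counts how many ground-truth targets appear among the top-\(k\) predicted indices and divides by the number of ground-truth targets. The function name `recall_at_k` is hypothetical and not part of `torch_geometric`; this is a plain-Python sketch of the metric's definition, assuming the \(k\) predictions are distinct, not the library's implementation.

```python
def recall_at_k(topk_pred: list[int], ground_truth: set[int]) -> float:
    """Hypothetical helper: Recall@k for one example.

    topk_pred holds the k predicted target indices (assumed distinct);
    ground_truth holds the indices of the true targets.
    """
    if not ground_truth:
        # Assumption for this sketch: no ground truth means a recall of 0.
        return 0.0
    # Count predictions that hit a true target.
    hits = sum(1 for idx in topk_pred if idx in ground_truth)
    return hits / len(ground_truth)


# Top-3 predictions [4, 7, 1]; true targets {1, 2, 4, 9}: 2 of 4 are retrieved.
print(recall_at_k([4, 7, 1], {1, 2, 4, 9}))  # 0.5
```

Note that the denominator is the number of ground-truth targets, not \(k\); dividing by \(k\) instead would give Precision@\(k\).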

update(pred_index_mat: Tensor, edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]]) → None

Updates the state variables based on the predictions for the current mini-batch.

update() can be called multiple times to accumulate the results of successive mini-batches, e.g., inside a training or evaluation loop.

Parameters:
  • pred_index_mat (torch.Tensor) – The indices of the top-\(k\) predictions for every example in the mini-batch, with shape [batch_size, k].

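  • edge_label_index (torch.Tensor or Tuple[torch.Tensor, torch.Tensor]) – The ground-truth indices for every example in the mini-batch, given in COO format of shape [2, num_ground_truth_indices], where the first row holds the example's position in the mini-batch and the second row holds the ground-truth target index.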
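To make the expected input layout concrete, the sketch below decodes a COO-style `edge_label_index` into per-example sets of ground-truth targets and then scores a `[batch_size, k]` prediction matrix against them. It assumes that row 0 of `edge_label_index` holds the example's position in the mini-batch and row 1 holds the ground-truth target index. Nested Python lists stand in for tensors, and `group_ground_truth` and `batch_recall` are hypothetical names, not `torch_geometric` functions.

```python
def group_ground_truth(edge_label_index: list[list[int]], batch_size: int) -> list[set[int]]:
    """Hypothetical helper: split a [2, num_ground_truth_indices] COO index
    into one set of ground-truth targets per example.

    Assumes row 0 holds the example position within the mini-batch and
    row 1 holds the ground-truth target index.
    """
    rows, cols = edge_label_index
    groups: list[set[int]] = [set() for _ in range(batch_size)]
    for example, target in zip(rows, cols):
        groups[example].add(target)
    return groups


def batch_recall(pred_index_mat: list[list[int]], edge_label_index: list[list[int]]) -> list[float]:
    """Hypothetical helper: per-example Recall@k for a [batch_size, k] matrix."""
    ground_truth = group_ground_truth(edge_label_index, len(pred_index_mat))
    recalls = []
    for preds, truth in zip(pred_index_mat, ground_truth):
        hits = sum(1 for p in preds if p in truth)
        # Examples without any ground truth get 0.0 in this sketch.
        recalls.append(hits / len(truth) if truth else 0.0)
    return recalls


# Mini-batch of 2 examples with k = 2.
pred_index_mat = [[3, 5],   # example 0 predicts targets 3 and 5
                  [0, 2]]   # example 1 predicts targets 0 and 2
edge_label_index = [[0, 0, 1],   # example positions
                    [5, 8, 2]]   # ground-truth targets
print(batch_recall(pred_index_mat, edge_label_index))  # [0.5, 1.0]
```

The COO layout lets each example carry a different number of ground-truth targets without padding, which a dense [batch_size, num_targets] matrix could not do.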

compute() → Tensor

Computes the final metric value from the accumulated state variables.

reset() → None

Resets the metric state variables to their default values.
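The update() / compute() / reset() cycle can be pictured with the plain-Python accumulator below. The class name `RecallAtKAccumulator` is hypothetical, and the reduction it uses (summing per-example recall and dividing by the number of examples that have at least one ground-truth target) is an assumption made for illustration; the exact reduction inside `LinkPredRecall` may differ.

```python
class RecallAtKAccumulator:
    """Hypothetical stand-in mirroring the update/compute/reset interface.

    Assumption: the final value is the mean per-example Recall@k over all
    examples (across every update() call) that have at least one
    ground-truth target.
    """

    def __init__(self, k: int) -> None:
        self.k = k
        self.reset()

    def reset(self) -> None:
        # Return the state variables to their defaults.
        self.recall_sum = 0.0
        self.num_examples = 0

    def update(self, pred_index_mat: list[list[int]], edge_label_index: list[list[int]]) -> None:
        # Validate the [batch_size, k] shape before touching any state.
        if any(len(preds) != self.k for preds in pred_index_mat):
            raise ValueError(f"each example needs exactly {self.k} predictions")
        # Group ground-truth targets per example (row 0 = example, row 1 = target).
        truth: list[set[int]] = [set() for _ in pred_index_mat]
        for example, target in zip(*edge_label_index):
            truth[example].add(target)
        for preds, targets in zip(pred_index_mat, truth):
            if targets:  # examples without ground truth are skipped (assumption)
                hits = sum(1 for p in preds if p in targets)
                self.recall_sum += hits / len(targets)
                self.num_examples += 1

    def compute(self) -> float:
        # Average over the examples accumulated since construction or the last reset().
        if self.num_examples == 0:
            return 0.0
        return self.recall_sum / self.num_examples


metric = RecallAtKAccumulator(k=2)
# Two successive mini-batches, e.g. from an evaluation loop.
metric.update([[3, 5], [0, 2]], [[0, 0, 1], [5, 8, 2]])  # recalls 0.5, 1.0
metric.update([[7, 9]], [[0], [4]])                        # recall 0.0
print(metric.compute())  # 0.5
metric.reset()
print(metric.compute())  # 0.0
```

Keeping a running sum and count, rather than storing every prediction, keeps memory constant no matter how many mini-batches are passed to update().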