torch_geometric.metrics.LinkPredMetricCollection

class LinkPredMetricCollection(metrics: Union[List[LinkPredMetric], Dict[str, LinkPredMetric]])

Bases: torch.nn.ModuleDict

A collection of link prediction metrics that are updated and computed together, which reduces redundant work and speeds up their computation.

from torch_geometric.metrics import (
    LinkPredMAP,
    LinkPredMetricCollection,
    LinkPredPrecision,
    LinkPredRecall,
)

metrics = LinkPredMetricCollection([
    LinkPredMAP(k=10),
    LinkPredPrecision(k=100),
    LinkPredRecall(k=50),
])

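# pred_index_mat: top-k predictions per example, shape [batch_size, 100]
# (100 being the largest k in the collection); edge_label_index: ground-truth
# links in COO format, shape [2, num_ground_truth_indices].
metrics.update(pred_index_mat, edge_label_index)
out = metrics.compute()
metrics.reset()

print(out)
# {'LinkPredMAP@10': tensor(0.375),
#  'LinkPredPrecision@100': tensor(0.127),
#  'LinkPredRecall@50': tensor(0.483)}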
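
One way a collection can share work across metrics is to compute a single hit mask at the largest k and let each metric read only its first k columns. The plain-Python sketch below illustrates that idea under stated assumptions: the function names are hypothetical, this is not PyG's implementation, precision@k and recall@k follow the textbook definitions (hits / k and hits / number of ground-truth targets, which may differ from PyG's in edge cases), and every example is assumed to have at least one ground-truth target.

```python
from typing import Dict, List, Sequence, Set


def hit_mask(pred_index_mat: Sequence[Sequence[int]],
             ground_truth: Sequence[Set[int]]) -> List[List[bool]]:
    # mask[i][j] is True if the j-th prediction of example i is a ground-truth target.
    return [[p in targets for p in row]
            for row, targets in zip(pred_index_mat, ground_truth)]


def shared_precision_recall(pred_index_mat: Sequence[Sequence[int]],
                            ground_truth: Sequence[Set[int]],
                            ks: Sequence[int]) -> Dict[str, float]:
    # Hypothetical helper; assumes every example has at least one target.
    max_k = max(ks)
    # The membership test runs only once, over the top-max_k predictions.
    mask = hit_mask([row[:max_k] for row in pred_index_mat], ground_truth)
    out: Dict[str, float] = {}
    for k in sorted(ks):
        # Each metric reuses the shared mask and only looks at its first k columns.
        hits = [sum(row[:k]) for row in mask]
        precision = [h / k for h in hits]
        recall = [h / len(targets) for h, targets in zip(hits, ground_truth)]
        out[f"precision@{k}"] = sum(precision) / len(precision)
        out[f"recall@{k}"] = sum(recall) / len(recall)
    return out


print(shared_precision_recall([[3, 1, 4], [2, 7, 5]], [{1, 9}, {5}], ks=[1, 3]))
# precision@1 = recall@1 = 0.0; precision@3 is about 0.333; recall@3 = 0.75
```

The expensive part (checking each prediction against the ground truth) happens once, so adding another metric with a smaller k only costs a slice and a sum.
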
Parameters:

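metrics (Union[List[LinkPredMetric], Dict[str, LinkPredMetric]]) – The link prediction metrics to collect. When given as a list, each metric is named after its class and k (e.g., LinkPredMAP@10); when given as a dictionary, its keys are used as the metric names.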

update(pred_index_mat: Tensor, edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]], edge_label_weight: Optional[Tensor] = None) -> None

Updates the state variables based on the predictions of the current mini-batch.

update() can be called multiple times to accumulate the results of successive predictions, e.g., inside a mini-batch training or evaluation loop; compute() then returns the metric values over everything accumulated since the last reset().
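
To illustrate why accumulating state matters, the hypothetical RunningPrecisionAtK class below mimics the update()/compute()/reset() protocol in plain Python. It is a sketch, not PyG's state handling, and to keep the focus on accumulation it takes ground truth as one set per example instead of the COO edge_label_index. Keeping a running sum and count yields the mean over all examples, not the mean of per-batch means.

```python
from typing import Sequence, Set


class RunningPrecisionAtK:
    """Hypothetical metric mirroring the update()/compute()/reset() protocol."""

    def __init__(self, k: int) -> None:
        self.k = k
        self.reset()

    def reset(self) -> None:
        # Return the state variables to their default values.
        self.hit_fraction_sum = 0.0
        self.num_examples = 0

    def update(self, pred_index_mat: Sequence[Sequence[int]],
               ground_truth: Sequence[Set[int]]) -> None:
        # Accumulate per-example precision@k; nothing is averaged yet.
        for row, targets in zip(pred_index_mat, ground_truth):
            hits = sum(p in targets for p in row[:self.k])
            self.hit_fraction_sum += hits / self.k
            self.num_examples += 1

    def compute(self) -> float:
        # Mean over every example seen since the last reset().
        if self.num_examples == 0:
            raise RuntimeError("compute() called before any update()")
        return self.hit_fraction_sum / self.num_examples


metric = RunningPrecisionAtK(k=2)
for preds, targets in [([[0, 1], [2, 3]], [{1}, {3}]),  # mini-batch 1
                       ([[4, 5]], [{9}])]:               # mini-batch 2
    metric.update(preds, targets)
print(metric.compute())  # 1/3; averaging the two batch means would give 0.25
metric.reset()
```
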

Parameters:
  • pred_index_mat (torch.Tensor) – The top-k predictions of every example in the mini-batch, with shape [batch_size, k]. Here, k should be the largest k among the collected metrics.

  • edge_label_index (torch.Tensor or Tuple[torch.Tensor, torch.Tensor]) – The ground-truth indices for every example in the mini-batch, given in COO format with shape [2, num_ground_truth_indices] or as a tuple of two index tensors. The first row refers to the example's row in pred_index_mat, the second to the ground-truth index.

  • edge_label_weight (torch.Tensor, optional) – The weights of the ground-truth indices, with shape [num_ground_truth_indices]. If given, it must contain only positive values. Required for weighted metrics and ignored otherwise. (default: None)
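
To make the expected input layout concrete, the sketch below groups a COO edge_label_index into one ground-truth set per example. It uses plain Python lists in place of torch.Tensor, the helper coo_to_targets is hypothetical, and it assumes the first row indexes the example (the row of pred_index_mat) and the second row the ground-truth index. It accepts both the [2, N] layout and the (row, col) tuple form, and checks that optional weights are positive and one per ground-truth index.

```python
from typing import List, Optional, Sequence, Set


def coo_to_targets(
    batch_size: int,
    edge_label_index: Sequence[Sequence[int]],
    edge_label_weight: Optional[Sequence[float]] = None,
) -> List[Set[int]]:
    """Hypothetical helper: group COO ground truth by example."""
    # Unpacking works for a [2, N] nested list and a (row, col) tuple alike.
    row, col = edge_label_index
    if len(row) != len(col):
        raise ValueError("both rows of edge_label_index must have equal length")
    if edge_label_weight is not None:
        # One positive weight per ground-truth index.
        if len(edge_label_weight) != len(row):
            raise ValueError("expected shape [num_ground_truth_indices]")
        if any(w <= 0 for w in edge_label_weight):
            raise ValueError("edge_label_weight must be positive")
    targets: List[Set[int]] = [set() for _ in range(batch_size)]
    for i, j in zip(row, col):
        targets[i].add(j)  # row i of pred_index_mat should retrieve index j
    return targets


pred_index_mat = [[3, 1, 4], [2, 7, 5]]    # shape [batch_size=2, k=3]
edge_label_index = ([0, 0, 1], [1, 9, 5])  # shape [2, num_ground_truth_indices=3]
targets = coo_to_targets(len(pred_index_mat), edge_label_index)
# targets[0] == {1, 9} and targets[1] == {5}
```
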

Return type:

None

compute() -> Dict[str, Tensor]

Computes the final metric values, returned as a dictionary that maps each metric name to its value.

Return type:

Dict[str, Tensor]

reset() -> None

Resets the metric state variables to their default values.

Return type:

None