torch_geometric.metrics.LinkPredMetricCollection
- class LinkPredMetricCollection(metrics: Union[List[LinkPredMetric], Dict[str, LinkPredMetric]])
Bases: torch.nn.ModuleDict
A collection of link prediction metrics that reduces redundant work and speeds up their computation when several metrics are evaluated together.
from torch_geometric.metrics import (
    LinkPredMAP,
    LinkPredMetricCollection,
    LinkPredPrecision,
    LinkPredRecall,
)

metrics = LinkPredMetricCollection([
    LinkPredMAP(k=10),
    LinkPredPrecision(k=100),
    LinkPredRecall(k=50),
])

# pred_index_mat: top-k predictions, shape [batch_size, k]
# edge_label_index: ground-truth indices, shape [2, num_ground_truth_indices]
metrics.update(pred_index_mat, edge_label_index)
out = metrics.compute()
metrics.reset()

print(out)
>>> {'LinkPredMAP@10': tensor(0.375),
...  'LinkPredPrecision@100': tensor(0.127),
...  'LinkPredRecall@50': tensor(0.483)}
- Parameters:
metrics (Union[List[LinkPredMetric], Dict[str, LinkPredMetric]]) – The link prediction metrics, given either as a list or as a dictionary mapping names to metrics.
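The description above does not spell out where the savings come from. One plausible mechanism, assumed here and not taken from the PyG source, is to determine once, at the largest k, which predictions are ground-truth hits, and then let every metric slice that shared result. The pure-Python sketch below illustrates the idea with plain lists instead of tensors; `hit_matrix`, `precision_at_k` and `recall_at_k` are hypothetical helpers, not PyG functions.

```python
from typing import Dict, List, Set

# Hypothetical helpers for illustration; not part of torch_geometric.


def hit_matrix(pred: List[List[int]], truth: List[Set[int]]) -> List[List[bool]]:
    """Mark, for every example, which of its ranked predictions are ground-truth hits."""
    return [[p in t for p in row] for row, t in zip(pred, truth)]


def precision_at_k(hits: List[List[bool]], k: int) -> float:
    # Fraction of the first k predictions that are hits, averaged over examples.
    return sum(sum(row[:k]) / k for row in hits) / len(hits)


def recall_at_k(hits: List[List[bool]], truth: List[Set[int]], k: int) -> float:
    # Fraction of each example's ground truth found among its first k predictions,
    # averaged over examples (assumes every example has at least one target).
    return sum(sum(row[:k]) / len(t) for row, t in zip(hits, truth)) / len(hits)


# Top-3 predictions for two examples (3 = the largest k any metric needs).
pred = [[4, 1, 7], [2, 9, 5]]
truth = [{1, 7}, {3}]

hits = hit_matrix(pred, truth)  # computed once, shared by every metric below
results: Dict[str, float] = {
    "Precision@1": precision_at_k(hits, 1),
    "Precision@3": precision_at_k(hits, 3),
    "Recall@3": recall_at_k(hits, truth, 3),
}
print(results)
```

Under this assumption, adding another metric or another k only costs a cheap slice of `hits` instead of a new membership check against the ground truth.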
- update(pred_index_mat: Tensor, edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]], edge_label_weight: Optional[Tensor] = None) -> None
Updates the state variables based on the predictions of the current mini-batch.
update() can be called repeatedly to accumulate the results of successive predictions, e.g., inside a mini-batch training or evaluation loop.
- Parameters:
pred_index_mat (torch.Tensor) – The top-k predictions of every example in the mini-batch, of shape [batch_size, k].
edge_label_index (torch.Tensor or Tuple[torch.Tensor, torch.Tensor]) – The ground-truth indices for every example in the mini-batch, given in COO format of shape [2, num_ground_truth_indices].
edge_label_weight (torch.Tensor, optional) – The weights of the ground-truth indices for every example in the mini-batch, of shape [num_ground_truth_indices]. If given, it must contain only positive values. Required for weighted metrics, ignored otherwise. (default: None)
- Return type: None
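The update()/compute()/reset() cycle can be mimicked in plain Python to make the input format concrete. This sketch assumes that row 0 of edge_label_index gives the position of an example within the mini-batch (its row in pred_index_mat) and row 1 gives a ground-truth target index. It also assumes that recall@k is averaged over examples with at least one ground-truth target. `RunningRecall` is a hypothetical class for illustration, not PyG's LinkPredRecall, and plain lists stand in for tensors.

```python
from collections import defaultdict
from typing import Dict, List, Set


class RunningRecall:
    """Hypothetical recall@k accumulator mimicking update()/compute()/reset()."""

    def __init__(self, k: int) -> None:
        self.k = k
        self.reset()

    def reset(self) -> None:
        self.total = 0.0  # sum of per-example recall values
        self.count = 0    # number of examples accumulated so far

    def update(self, pred_index_mat: List[List[int]], edge_label_index: List[List[int]]) -> None:
        # Assumed COO layout: column j says that example edge_label_index[0][j]
        # has the ground-truth target edge_label_index[1][j].
        truth: Dict[int, Set[int]] = defaultdict(set)
        for row, col in zip(edge_label_index[0], edge_label_index[1]):
            truth[row].add(col)
        for i, preds in enumerate(pred_index_mat):
            if not truth[i]:
                continue  # assumption: examples without ground truth are skipped
            hits = sum(p in truth[i] for p in preds[: self.k])
            self.total += hits / len(truth[i])
            self.count += 1

    def compute(self) -> float:
        return self.total / self.count if self.count else 0.0


metric = RunningRecall(k=2)
# Mini-batch 1: two examples with their top-2 predictions.
metric.update([[4, 1], [2, 9]], [[0, 0, 1], [1, 7, 3]])
# Mini-batch 2: one example; row-0 indices restart at 0 for each batch.
metric.update([[5, 6]], [[0], [6]])
result = metric.compute()
print(result)
metric.reset()  # clear the state before the next evaluation run
```

Here the per-example recalls are 0.5, 0.0 and 1.0, so the accumulated average is 0.5, independent of how the examples were split into mini-batches.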