# Source code for torch_geometric.explain.metric.basic

from typing import List, Optional, Tuple, Union

from torch import Tensor

# Names of the metrics supported by `groundtruth_metrics`. Each entry is looked
# up by name on `torchmetrics.functional`, and this order is the order in which
# results are returned when no explicit `metrics` argument is given.
METRICS = [
    'accuracy',
    'recall',
    'precision',
    'f1_score',
    'auroc',
]


def groundtruth_metrics(
    pred_mask: Tensor,
    target_mask: Tensor,
    metrics: Optional[Union[str, List[str]]] = None,
    threshold: float = 0.5,
) -> Union[float, Tuple[float, ...]]:
    r"""Compares and evaluates an explanation mask with the ground-truth
    explanation mask.

    Args:
        pred_mask (torch.Tensor): The prediction mask to evaluate.
        target_mask (torch.Tensor): The ground-truth target mask.
        metrics (str or List[str], optional): The metrics to return
            (:obj:`"accuracy"`, :obj:`"recall"`, :obj:`"precision"`,
            :obj:`"f1_score"`, :obj:`"auroc"`). (default: :obj:`["accuracy",
            "recall", "precision", "f1_score", "auroc"]`)
        threshold (float, optional): The threshold value to perform hard
            thresholding of :obj:`pred_mask` and :obj:`target_mask`.
            :obj:`target_mask` is binarized via
            :obj:`target_mask >= threshold`, while the threshold is forwarded
            to :obj:`torchmetrics` for every metric except :obj:`"auroc"`.
            (default: :obj:`0.5`)

    Returns:
        float or Tuple[float, ...]: A single :obj:`float` if exactly one
        metric is requested, otherwise a tuple of floats in the same order
        as :obj:`metrics`.

    Raises:
        ValueError: If :obj:`metrics` is neither a string nor a list/tuple,
            if it is empty, or if it contains an unsupported metric name.
    """
    if metrics is None:
        metrics = METRICS
    elif isinstance(metrics, str):
        metrics = [metrics]
    elif not isinstance(metrics, (tuple, list)):
        raise ValueError(f"Expected metrics to be a string or a list of "
                         f"strings (got {type(metrics)})")

    # An empty request would otherwise surface as an opaque `IndexError` on
    # `outs[0]` at the very end of the function.
    if not metrics:
        raise ValueError("Expected at least one metric (got an empty "
                         "sequence)")

    # Validate all names up front so that a bad entry late in the list is
    # reported before any metric has been computed.
    for metric in metrics:
        if metric not in METRICS:
            raise ValueError(f"Encountered invalid metric {metric}")

    # Optional dependency: imported lazily so that importing this module does
    # not require `torchmetrics`, and only after the arguments are known to be
    # valid.
    import torchmetrics

    # `reshape` instead of `view` so that non-contiguous masks (e.g. transposed
    # or sliced tensors) are flattened instead of raising a `RuntimeError`.
    pred_mask = pred_mask.reshape(-1)
    target_mask = (target_mask >= threshold).reshape(-1)

    outs = []
    for metric in metrics:
        fn = getattr(torchmetrics.functional, metric)
        if metric == 'auroc':
            # AUROC is evaluated on the raw scores, so no decision threshold
            # is forwarded.
            out = fn(pred_mask, target_mask, 'binary')
        else:
            out = fn(pred_mask, target_mask, 'binary', threshold)
        outs.append(float(out))

    return tuple(outs) if len(outs) > 1 else outs[0]