torch_geometric.explain.metric.groundtruth_metrics

groundtruth_metrics(pred_mask: Tensor, target_mask: Tensor, metrics: Optional[Union[str, List[str]]] = None, threshold: float = 0.5) → Union[float, Tuple[float, ...]]

Compares an explanation mask against the ground-truth explanation mask and evaluates it with the requested metrics.

Parameters:
  • pred_mask (torch.Tensor) – The prediction mask to evaluate.

  • target_mask (torch.Tensor) – The ground-truth target mask.

  • metrics (str or List[str], optional) – The metrics to return ("accuracy", "recall", "precision", "f1_score", "auroc"). (default: ["accuracy", "recall", "precision", "f1_score", "auroc"])

  • threshold (float, optional) – The threshold used to binarize the predicted and ground-truth masks (hard thresholding). (default: 0.5)
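
To make the effect of threshold concrete, the following is a minimal pure-Python sketch of what the threshold-based metrics measure. It is not the library's implementation: the helper names `binarize` and `sketch_metrics` are hypothetical, plain lists stand in for tensors, and treating values at or above the threshold as positive is an assumption. "auroc" is left out because it ranks raw scores rather than thresholded labels.

```python
from typing import Dict, List


def binarize(mask: List[float], threshold: float = 0.5) -> List[int]:
    # Assumption: values at or above the threshold count as positive.
    return [1 if value >= threshold else 0 for value in mask]


def sketch_metrics(pred_mask: List[float], target_mask: List[float],
                   threshold: float = 0.5) -> Dict[str, float]:
    # Hypothetical helper: hard-threshold both masks, then count outcomes.
    pred = binarize(pred_mask, threshold)
    target = binarize(target_mask, threshold)

    tp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 1)
    tn = sum(1 for p, t in zip(pred, target) if p == 0 and t == 0)
    fp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(pred, target) if p == 0 and t == 1)

    accuracy = (tp + tn) / len(pred)
    # Guard against empty denominators when a class never occurs.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1_score = (2 * precision * recall / (precision + recall)
                if (precision + recall) else 0.0)

    return {
        "accuracy": accuracy,
        "recall": recall,
        "precision": precision,
        "f1_score": f1_score,
    }


# Toy edge masks: one true positive, one true negative,
# one false positive, and one false negative.
pred_mask = [0.9, 0.2, 0.7, 0.4]
target_mask = [1.0, 0.0, 0.0, 1.0]
result = sketch_metrics(pred_mask, target_mask)
```

With PyTorch Geometric installed, the corresponding call would be along the lines of `groundtruth_metrics(pred_mask, target_mask, metrics=["accuracy", "f1_score"])` on tensors. According to the signature above, it returns either a single float or a tuple of floats.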