torch_geometric.explain.metric.groundtruth_metrics
- groundtruth_metrics(pred_mask: Tensor, target_mask: Tensor, metrics: Optional[Union[str, List[str]]] = None, threshold: float = 0.5) → Union[float, Tuple[float, ...]]
Evaluates an explanation mask by comparing it against the ground-truth explanation mask.
- Parameters:
pred_mask (torch.Tensor) – The prediction mask to evaluate.
target_mask (torch.Tensor) – The ground-truth target mask.
metrics (str or List[str], optional) – The metrics to return ("accuracy", "recall", "precision", "f1_score", "auroc"). (default: ["accuracy", "recall", "precision", "f1_score", "auroc"])
threshold (float, optional) – The threshold value used to perform hard thresholding of pred_mask and target_mask. (default: 0.5)
- Return type:
Union[float, Tuple[float, ...]]
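To make the `metrics` and `threshold` parameters concrete, here is a minimal pure-Python sketch that uses plain lists in place of flattened `torch.Tensor` masks, so it runs without PyTorch. `groundtruth_metrics_sketch` and its helpers are hypothetical names and are not part of PyG. The sketch also makes several assumptions that the library's own implementation may not share: values `>=` the threshold count as positive, a metric with an empty denominator evaluates to 0.0, and AUROC is computed from the raw prediction scores.

```python
from typing import List, Optional, Sequence, Tuple, Union

METRICS = ["accuracy", "recall", "precision", "f1_score", "auroc"]


def _binarize(values: Sequence[float], threshold: float) -> List[bool]:
    # Hard thresholding (assumed convention): entries >= threshold count as positive.
    return [v >= threshold for v in values]


def _safe_div(num: float, den: float) -> float:
    # Assumed convention: a metric with an empty denominator evaluates to 0.0.
    return num / den if den else 0.0


def _auroc(scores: Sequence[float], target: List[bool]) -> float:
    # Threshold-free: the probability that a randomly chosen positive entry scores
    # higher than a randomly chosen negative one (ties count as one half).
    pos = [s for s, t in zip(scores, target) if t]
    neg = [s for s, t in zip(scores, target) if not t]
    if not pos or not neg:
        raise ValueError("AUROC needs both positive and negative target entries")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def groundtruth_metrics_sketch(
    pred_mask: Sequence[float],
    target_mask: Sequence[float],
    metrics: Optional[Union[str, List[str]]] = None,
    threshold: float = 0.5,
) -> Union[float, Tuple[float, ...]]:
    # Hypothetical pure-Python stand-in; both masks are assumed to be flattened.
    if len(pred_mask) != len(target_mask):
        raise ValueError("pred_mask and target_mask must have the same length")
    if metrics is None:
        metrics = METRICS
    elif isinstance(metrics, str):
        metrics = [metrics]

    # Hard-threshold both masks, then count confusion-matrix entries.
    pred = _binarize(pred_mask, threshold)
    target = _binarize(target_mask, threshold)
    pairs = list(zip(pred, target))
    tp = sum(p and t for p, t in pairs)
    fp = sum(p and not t for p, t in pairs)
    fn = sum(t and not p for p, t in pairs)
    tn = sum(not p and not t for p, t in pairs)

    precision = _safe_div(tp, tp + fp)
    recall = _safe_div(tp, tp + fn)
    scores = {
        "accuracy": _safe_div(tp + tn, len(pairs)),
        "recall": recall,
        "precision": precision,
        "f1_score": _safe_div(2 * precision * recall, precision + recall),
    }

    outs = []
    for name in metrics:
        if name == "auroc":
            # AUROC uses the raw (soft) prediction scores, not the thresholded ones.
            outs.append(_auroc(pred_mask, target))
        elif name in scores:
            outs.append(scores[name])
        else:
            raise ValueError(f"Unknown metric {name!r}; expected one of {METRICS}")
    # One requested metric gives a single float; several give a tuple in request order.
    return tuple(outs) if len(outs) > 1 else outs[0]
```

For example, with pred_mask = [0.9, 0.8, 0.3, 0.6, 0.1] and target_mask = [1.0, 1.0, 1.0, 0.0, 0.0], hard thresholding at 0.5 yields 2 true positives, 1 false positive, 1 false negative, and 1 true negative. Accuracy is therefore 0.6, and precision, recall, and F1 are each 2/3. AUROC is 5/6, because five of the six positive/negative score pairs are ranked correctly.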