# Source code for torch_geometric.utils.metric

from __future__ import division

import torch
import torch.nn.functional as F
from torch_scatter import scatter_add


def accuracy(pred, target):
    r"""Computes the accuracy of correct predictions.

    Returns the fraction of elements where :obj:`pred` equals
    :obj:`target`.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.

    Raises:
        ZeroDivisionError: If :obj:`target` has no elements.

    :rtype: float
    """
    # ``.item()`` yields a Python int and ``numel()`` is a Python int, so the
    # true division below produces a Python float (not an int).
    return (pred == target).sum().item() / target.numel()
def true_positive(pred, target, num_classes):
    r"""Computes the number of true positive predictions.

    For every class :math:`i` in ``range(num_classes)``, counts the entries
    where both :obj:`pred` and :obj:`target` equal :math:`i`.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`LongTensor` of shape ``[num_classes]``, on the same
        device as :obj:`pred`.
    """
    counts = [((pred == i) & (target == i)).sum() for i in range(num_classes)]
    if not counts:
        # torch.stack rejects an empty list; keep the documented int64 dtype.
        return torch.zeros(0, dtype=torch.long, device=pred.device)
    # Stacking keeps the counts on the input's device (torch.tensor(list)
    # would always build a CPU tensor).
    return torch.stack(counts)
def true_negative(pred, target, num_classes):
    r"""Computes the number of true negative predictions.

    For every class :math:`i` in ``range(num_classes)``, counts the entries
    where neither :obj:`pred` nor :obj:`target` equals :math:`i`.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`LongTensor` of shape ``[num_classes]``, on the same
        device as :obj:`pred`.
    """
    counts = [((pred != i) & (target != i)).sum() for i in range(num_classes)]
    if not counts:
        # torch.stack rejects an empty list; keep the documented int64 dtype.
        return torch.zeros(0, dtype=torch.long, device=pred.device)
    # Stacking keeps the counts on the input's device (torch.tensor(list)
    # would always build a CPU tensor).
    return torch.stack(counts)
def false_positive(pred, target, num_classes):
    r"""Computes the number of false positive predictions.

    For every class :math:`i` in ``range(num_classes)``, counts the entries
    where :obj:`pred` equals :math:`i` but :obj:`target` does not.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`LongTensor` of shape ``[num_classes]``, on the same
        device as :obj:`pred`.
    """
    counts = [((pred == i) & (target != i)).sum() for i in range(num_classes)]
    if not counts:
        # torch.stack rejects an empty list; keep the documented int64 dtype.
        return torch.zeros(0, dtype=torch.long, device=pred.device)
    # Stacking keeps the counts on the input's device (torch.tensor(list)
    # would always build a CPU tensor).
    return torch.stack(counts)
def false_negative(pred, target, num_classes):
    r"""Computes the number of false negative predictions.

    For every class :math:`i` in ``range(num_classes)``, counts the entries
    where :obj:`target` equals :math:`i` but :obj:`pred` does not.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`LongTensor` of shape ``[num_classes]``, on the same
        device as :obj:`pred`.
    """
    counts = [((pred != i) & (target == i)).sum() for i in range(num_classes)]
    if not counts:
        # torch.stack rejects an empty list; keep the documented int64 dtype.
        return torch.zeros(0, dtype=torch.long, device=pred.device)
    # Stacking keeps the counts on the input's device (torch.tensor(list)
    # would always build a CPU tensor).
    return torch.stack(counts)
def precision(pred, target, num_classes):
    r"""Computes the precision: :math:`\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FP}}`.

    Classes that are never predicted (``TP + FP == 0``) get a precision of
    :obj:`0` rather than NaN.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`Tensor`
    """
    hits = true_positive(pred, target, num_classes).to(torch.float)
    predicted = hits + false_positive(pred, target, num_classes).to(torch.float)
    ratio = hits / predicted
    # 0 / 0 only happens when a class is never predicted; report 0 there.
    return ratio.masked_fill_(torch.isnan(ratio), 0)
def recall(pred, target, num_classes):
    r"""Computes the recall: :math:`\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FN}}`.

    Classes with no positive targets (``TP + FN == 0``) get a recall of
    :obj:`0` rather than NaN.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`Tensor`
    """
    hits = true_positive(pred, target, num_classes).to(torch.float)
    relevant = hits + false_negative(pred, target, num_classes).to(torch.float)
    ratio = hits / relevant
    # 0 / 0 only happens when a class never occurs in the targets; report 0.
    return ratio.masked_fill_(torch.isnan(ratio), 0)
def f1_score(pred, target, num_classes):
    r"""Computes the :math:`F_1` score:
    :math:`2 \cdot \frac{\mathrm{precision} \cdot \mathrm{recall}}
    {\mathrm{precision}+\mathrm{recall}}`.

    The score is the per-class harmonic mean of :func:`precision` and
    :func:`recall`; classes where both are zero get a score of :obj:`0`
    rather than NaN.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`Tensor`
    """
    p = precision(pred, target, num_classes)
    r = recall(pred, target, num_classes)
    harmonic = 2 * (p * r) / (p + r)
    # 0 / 0 when precision and recall are both zero; report 0 there.
    return harmonic.masked_fill_(torch.isnan(harmonic), 0)
def mean_iou(pred, target, num_classes, batch=None, omitnans=False):
    r"""Computes the mean Intersection over Union score.

    Per class, IoU is ``|pred == c AND target == c| / |pred == c OR
    target == c|``. The per-class IoUs are then averaged over the last
    dimension (classes).

    Args:
        pred (LongTensor): The predictions.
        target (LongTensor): The targets.
        num_classes (int): The number of classes.
        batch (LongTensor, optional): The assignment vector which maps each
            pred-target pair to an example. If given, one score per example
            is returned. (default: :obj:`None`)
        omitnans (bool, optional): If set to :obj:`True`, classes with an
            empty union (absent from both :obj:`pred` and :obj:`target`) are
            left out of the mean. Otherwise they count as an IoU of :obj:`1`.
            (default: :obj:`False`)

    :rtype: :class:`Tensor`
    """
    pred, target = F.one_hot(pred, num_classes), F.one_hot(target, num_classes)

    if batch is not None:
        # scatter_add sums the one-hot rows that share the same batch index.
        i = scatter_add(pred & target, batch, dim=0).to(torch.float)
        u = scatter_add(pred | target, batch, dim=0).to(torch.float)
    else:
        i = (pred & target).sum(dim=0).to(torch.float)
        u = (pred | target).sum(dim=0).to(torch.float)

    # Yields 0 / 0 = NaN exactly where the union is empty.
    iou = i / u

    if omitnans:
        valid = u > 0
        total = torch.where(valid, iou, torch.zeros_like(iou)).sum(dim=-1)
        # An example with no valid class at all yields NaN here.
        return total / valid.sum(dim=-1).to(torch.float)

    iou[torch.isnan(iou)] = 1
    return iou.mean(dim=-1)