from typing import Optional, Tuple
import torch
from torch import Tensor
import torch.nn.functional as F
from torch_scatter import scatter_add
def accuracy(pred: Tensor, target: Tensor) -> float:
    r"""Computes the fraction of predictions that match their targets.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.

    :rtype: float
    """
    num_correct = pred.eq(target).sum()
    num_total = target.numel()
    return int(num_correct) / num_total
def true_positive(pred: Tensor, target: Tensor, num_classes: int) -> Tensor:
    r"""Computes the number of true positive predictions per class.

    For every class :obj:`c` in :obj:`range(num_classes)`, counts the
    elements where both :obj:`pred` and :obj:`target` equal :obj:`c`.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`LongTensor`
    """
    counts = [((pred == c) & (target == c)).sum()
              for c in range(num_classes)]
    if not counts:
        # An empty Python list would be turned into a float tensor; keep the
        # documented integer dtype when there are no classes to count.
        return torch.empty(0, dtype=torch.long, device=pred.device)
    # Stacking keeps the per-class counts as tensors instead of converting
    # each one to a Python scalar first.
    return torch.stack(counts)
def true_negative(pred: Tensor, target: Tensor, num_classes: int) -> Tensor:
    r"""Computes the number of true negative predictions per class.

    For every class :obj:`c` in :obj:`range(num_classes)`, counts the
    elements where neither :obj:`pred` nor :obj:`target` equals :obj:`c`.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`LongTensor`
    """
    counts = [((pred != c) & (target != c)).sum()
              for c in range(num_classes)]
    if not counts:
        # An empty Python list would be turned into a float tensor; keep the
        # documented integer dtype when there are no classes to count.
        return torch.empty(0, dtype=torch.long, device=pred.device)
    # Stacking keeps the per-class counts as tensors instead of converting
    # each one to a Python scalar first.
    return torch.stack(counts)
def false_positive(pred: Tensor, target: Tensor, num_classes: int) -> Tensor:
    r"""Computes the number of false positive predictions per class.

    For every class :obj:`c` in :obj:`range(num_classes)`, counts the
    elements where :obj:`pred` equals :obj:`c` but :obj:`target` does not.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`LongTensor`
    """
    counts = [((pred == c) & (target != c)).sum()
              for c in range(num_classes)]
    if not counts:
        # An empty Python list would be turned into a float tensor; keep the
        # documented integer dtype when there are no classes to count.
        return torch.empty(0, dtype=torch.long, device=pred.device)
    # Stacking keeps the per-class counts as tensors instead of converting
    # each one to a Python scalar first.
    return torch.stack(counts)
def false_negative(pred: Tensor, target: Tensor, num_classes: int) -> Tensor:
    r"""Computes the number of false negative predictions per class.

    For every class :obj:`c` in :obj:`range(num_classes)`, counts the
    elements where :obj:`target` equals :obj:`c` but :obj:`pred` does not.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`LongTensor`
    """
    counts = [((pred != c) & (target == c)).sum()
              for c in range(num_classes)]
    if not counts:
        # An empty Python list would be turned into a float tensor; keep the
        # documented integer dtype when there are no classes to count.
        return torch.empty(0, dtype=torch.long, device=pred.device)
    # Stacking keeps the per-class counts as tensors instead of converting
    # each one to a Python scalar first.
    return torch.stack(counts)
def precision(pred: Tensor, target: Tensor, num_classes: int) -> Tensor:
    r"""Computes the per-class precision
    :math:`\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FP}}` of predictions.

    Classes for which :math:`\mathrm{TP}+\mathrm{FP}` is zero get a
    precision of :obj:`0`.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`Tensor`
    """
    hits = true_positive(pred, target, num_classes).to(torch.float)
    false_alarms = false_positive(pred, target, num_classes).to(torch.float)
    ratio = hits / (hits + false_alarms)
    # 0 / 0 produces NaN for classes that were never predicted.
    return ratio.masked_fill(torch.isnan(ratio), 0)
def recall(pred: Tensor, target: Tensor, num_classes: int) -> Tensor:
    r"""Computes the per-class recall
    :math:`\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FN}}` of predictions.

    Classes for which :math:`\mathrm{TP}+\mathrm{FN}` is zero get a recall
    of :obj:`0`.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`Tensor`
    """
    hits = true_positive(pred, target, num_classes).to(torch.float)
    misses = false_negative(pred, target, num_classes).to(torch.float)
    ratio = hits / (hits + misses)
    # 0 / 0 produces NaN for classes that never occur in the targets.
    return ratio.masked_fill(torch.isnan(ratio), 0)
def f1_score(pred: Tensor, target: Tensor, num_classes: int) -> Tensor:
    r"""Computes the per-class :math:`F_1` score
    :math:`2 \cdot \frac{\mathrm{precision} \cdot \mathrm{recall}}
    {\mathrm{precision}+\mathrm{recall}}` of predictions.

    Classes whose precision and recall are both zero get a score of
    :obj:`0`.

    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.

    :rtype: :class:`Tensor`
    """
    p = precision(pred, target, num_classes)
    r = recall(pred, target, num_classes)
    numerator = 2 * (p * r)
    denominator = p + r
    f1 = numerator / denominator
    # 0 / 0 produces NaN when precision and recall are both zero.
    return f1.masked_fill(torch.isnan(f1), 0)
def intersection_and_union(
        pred: Tensor, target: Tensor, num_classes: int,
        batch: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    r"""Computes per-class intersection and union counts of predictions.

    Both inputs are one-hot encoded; the intersection is their element-wise
    AND and the union their element-wise OR, summed over the first
    dimension (or scatter-added per example when :obj:`batch` is given).

    Args:
        pred (LongTensor): The predictions.
        target (LongTensor): The targets.
        num_classes (int): The number of classes.
        batch (LongTensor): The assignment vector which maps each pred-target
            pair to an example.

    :rtype: (:class:`LongTensor`, :class:`LongTensor`)
    """
    pred_hot = F.one_hot(pred, num_classes)
    target_hot = F.one_hot(target, num_classes)
    overlap = pred_hot & target_hot
    combined = pred_hot | target_hot

    if batch is not None:
        # Aggregate the counts separately for every example in the batch.
        return (scatter_add(overlap, batch, dim=0),
                scatter_add(combined, batch, dim=0))

    return overlap.sum(dim=0), combined.sum(dim=0)
def mean_iou(pred: Tensor, target: Tensor, num_classes: int,
             batch: Optional[Tensor] = None, omitnans: bool = False) -> Tensor:
    r"""Computes the mean intersection over union score of predictions.

    Args:
        pred (LongTensor): The predictions.
        target (LongTensor): The targets.
        num_classes (int): The number of classes.
        batch (LongTensor): The assignment vector which maps each pred-target
            pair to an example.
        omitnans (bool, optional): If set to :obj:`True`, will ignore any
            :obj:`NaN` values encountered during computation and average
            over all remaining entries. Otherwise, will treat them as
            :obj:`1` and average over the last dimension.
            (default: :obj:`False`)

    :rtype: :class:`Tensor`
    """
    inter, union = intersection_and_union(pred, target, num_classes, batch)
    ratio = inter.to(torch.float) / union.to(torch.float)

    if not omitnans:
        # A class absent from both pred and target (0 / 0) counts as a
        # perfect match.
        ratio[torch.isnan(ratio)] = 1.
        return ratio.mean(dim=-1)

    # Boolean-mask indexing flattens the result, so this is a single mean
    # over every non-NaN entry.
    return ratio[~ratio.isnan()].mean()