# Source code for torch_geometric.graphgym.loss

import torch
import torch.nn.functional as F

import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg


def compute_loss(pred, true):
    r"""Compute the loss and the normalized prediction score.

    The built-in loss is selected by ``cfg.model.loss_fun``
    (``'cross_entropy'`` or ``'mse'``). ``cfg.model.size_average`` is
    passed as the ``reduction`` argument of every built-in loss
    (e.g. ``'mean'`` or ``'sum'``).

    Before any loss is evaluated, a trailing singleton dimension is removed
    from :obj:`pred` and :obj:`true`. Then every custom loss registered in
    ``register.loss_dict`` is tried in order. The first one that returns a
    value other than :obj:`None` short-circuits the built-in losses.

    Args:
        pred (torch.Tensor): Unnormalized prediction (logits, or the raw
            output for regression).
        true (torch.Tensor): Ground-truth targets.

    Returns:
        tuple: ``(loss, pred_score)``. ``pred_score`` is the normalized
        prediction:

        - log-probabilities for multiclass cross-entropy,
        - sigmoid probabilities for binary/multilabel cross-entropy,
        - the raw prediction for MSE.

        If a registered custom loss handles the input, its return value is
        passed through unchanged.

    Raises:
        ValueError: If ``cfg.model.loss_fun`` is not supported and no
            registered custom loss handled the input.
    """
    reduction = cfg.model.size_average

    # Default manipulation of pred and true; skip it via a custom loss if
    # special handling is needed. squeeze(-1) is a no-op unless the last
    # dimension has size 1.
    pred = pred.squeeze(-1) if pred.ndim > 1 else pred
    true = true.squeeze(-1) if true.ndim > 1 else true

    # Custom losses get first refusal; returning None means "not handled".
    for func in register.loss_dict.values():
        value = func(pred, true)
        if value is not None:
            return value

    if cfg.model.loss_fun == 'cross_entropy':
        if pred.ndim > 1 and true.ndim == 1:
            # Multiclass: one score per class and integer class labels.
            pred = F.log_softmax(pred, dim=-1)
            # Honor the configured reduction, consistent with the other
            # branches (previously this always used the default 'mean').
            return F.nll_loss(pred, true, reduction=reduction), pred
        # Binary or multilabel: BCE requires floating-point targets.
        true = true.float()
        loss = F.binary_cross_entropy_with_logits(pred, true,
                                                  reduction=reduction)
        return loss, torch.sigmoid(pred)

    if cfg.model.loss_fun == 'mse':
        true = true.float()
        return F.mse_loss(pred, true, reduction=reduction), pred

    raise ValueError(f"Loss function '{cfg.model.loss_fun}' not supported")