# Source code for torch_geometric.nn.functional.bro

from typing import Union

import torch

def bro(
    x: torch.Tensor,
    batch: torch.Tensor,
    p: Union[int, str] = 2,
) -> torch.Tensor:
    r"""The Batch Representation Orthogonality penalty from the `"Improving
    Molecular Graph Neural Network Explainability with Orthonormalization
    and Induced Sparsity" <https://arxiv.org/abs/2105.04854>`_ paper.

    Computes a regularization for each graph representation in a mini-batch
    according to

    .. math::
        \mathcal{L}_{\textrm{BRO}}^\mathrm{graph} =
        || \mathbf{HH}^T - \mathbf{I}||_p,

    and returns an average over all graphs in the batch.

    Args:
        x (torch.Tensor): node-wise feature tensor
        batch (torch.Tensor): one-dimensional tensor indicating node
            membership within batch
        p (int or str, optional): order of the norm. See the `torch.norm
            documentation
            <https://pytorch.org/docs/stable/generated/torch.norm.html>`_.
            (default: :obj:`2`)

    Returns:
        average BRO penalty in the mini-batch
    """
    # Number of nodes per graph; computed once and reused for both splits.
    _, counts = torch.unique(batch, return_counts=True)
    sizes = counts.tolist()

    # Build a (num_graphs, max_nodes, max_nodes) stack of identity-like
    # targets: ones on the diagonal for real nodes, zeros for padded slots,
    # so padding contributes nothing to the norm below.
    diags = torch.stack([
        torch.diag(ones)
        for ones in torch.nn.utils.rnn.pad_sequence(
            sequences=torch.ones_like(batch).split_with_sizes(sizes),
            padding_value=0.,
            batch_first=True,
        )
    ])

    # Pad the node features to (num_graphs, max_nodes, num_features) so the
    # per-graph Gram matrices H @ H^T can be computed in one batched matmul.
    x = torch.nn.utils.rnn.pad_sequence(
        sequences=x.split_with_sizes(split_sizes=sizes),
        padding_value=0.,
        batch_first=True,
    )

    # ||H H^T - I||_p for each graph, averaged over the mini-batch.
    return torch.sum(
        torch.norm(x @ x.transpose(1, 2) - diags, p=p, dim=(1, 2))
    ) / counts.shape[0]