bro(x: Tensor, batch: Tensor, p: Union[int, str] = 2) → Tensor

The Batch Representation Orthogonality penalty from the “Improving Molecular Graph Neural Network Explainability with Orthonormalization and Induced Sparsity” paper.

Computes a regularization term for each graph representation in a mini-batch according to

\[\mathcal{L}_{\textrm{BRO}}^\mathrm{graph} = || \mathbf{HH}^T - \mathbf{I}||_p\]

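where \(\mathbf{H} \in \mathbb{R}^{N_i \times F}\) is the node feature matrix of the \(i\)-th graph and \(\mathbf{I}\) is the \(N_i \times N_i\) identity matrix. The function returns the average of this penalty over all graphs in the batch.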
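To make the per-graph term concrete, here is a minimal sketch in plain PyTorch; it is not the library's implementation. The helper `bro_single_graph` is hypothetical, and the sketch assumes \(||\cdot||_p\) is the entrywise p-norm of the residual matrix (so \(p = 2\) gives the Frobenius norm). String-valued `p` is not handled here.

```python
import torch


def bro_single_graph(h: torch.Tensor, p: float = 2) -> torch.Tensor:
    """Hypothetical helper: BRO penalty ||H H^T - I||_p for a single graph.

    h has shape [num_nodes, num_features], so the Gram matrix h @ h.T has
    shape [num_nodes, num_nodes] and is compared against the identity.
    Assumption: the norm is taken entrywise over the flattened residual.
    """
    gram = h @ h.T
    eye = torch.eye(h.size(0), dtype=h.dtype, device=h.device)
    # vector_norm flattens the matrix when dim=None (p=2 -> Frobenius norm).
    return torch.linalg.vector_norm(gram - eye, ord=p)


# Orthonormal node representations: H H^T = I, so the penalty vanishes.
h_ortho = torch.eye(3, 4)
print(bro_single_graph(h_ortho))  # tensor(0.)

# Two identical unit rows: H H^T - I = [[0, 1], [1, 0]], norm sqrt(2).
h_dup = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
print(bro_single_graph(h_dup))  # tensor(1.4142)
```

The second case shows what the penalty targets: node representations that overlap (non-zero off-diagonal entries of \(\mathbf{HH}^T\)) or deviate from unit length are penalized.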

  • x (torch.Tensor) – The node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\) of all graphs in the mini-batch.

  • batch (torch.Tensor) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.

  • p (int or str, optional) – The norm order to use. (default: 2)
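The role of the `batch` vector can be illustrated with a loop-based sketch. `bro_sketch` is a hypothetical name, not part of the library. It selects each graph's rows of `x` via `batch`, computes the per-graph penalty under the same entrywise-norm assumption (numeric `p` only), and averages the results. It also assumes every index in \(\{0, \ldots, B-1\}\) occurs in `batch`.

```python
import torch


def bro_sketch(x: torch.Tensor, batch: torch.Tensor, p: float = 2) -> torch.Tensor:
    """Hypothetical loop-based version of the batched BRO penalty.

    x:     [N, F] node features of all graphs in the mini-batch.
    batch: [N] graph index in {0, ..., B-1} for every node.
    Assumptions: numeric p, entrywise norm, no empty graph indices.
    """
    num_graphs = int(batch.max()) + 1
    penalties = []
    for g in range(num_graphs):
        h = x[batch == g]  # node features of graph g, shape [n_g, F]
        gram = h @ h.T     # shape [n_g, n_g]
        eye = torch.eye(h.size(0), dtype=x.dtype, device=x.device)
        penalties.append(torch.linalg.vector_norm(gram - eye, ord=p))
    # Average the per-graph penalties over the B graphs in the batch.
    return torch.stack(penalties).mean()


# Nodes 0-1 belong to graph 0, nodes 2-4 to graph 1.
x = torch.tensor([
    [1.0, 0.0],
    [0.0, 1.0],  # graph 0: orthonormal rows, penalty 0
    [1.0, 0.0],
    [1.0, 0.0],  # graph 1: two identical rows, penalty sqrt(2)
    [0.0, 1.0],
])
batch = torch.tensor([0, 0, 1, 1, 1])
loss = bro_sketch(x, batch)
print(loss)  # tensor(0.7071), i.e. (0 + sqrt(2)) / 2
```

The explicit loop favors readability over speed; a vectorized version would need to handle graphs with different node counts \(N_i\), for example by padding them to a common size.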