torch_geometric.nn.functional.bro
- bro(x: Tensor, batch: Tensor, p: Union[int, str] = 2) → Tensor
The Batch Representation Orthogonality penalty from the “Improving Molecular Graph Neural Network Explainability with Orthonormalization and Induced Sparsity” paper.
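Computes a regularization term for each graph representation in a mini-batch according to
\[\mathcal{L}_{\textrm{BRO}}^\mathrm{graph} = \| \mathbf{H}\mathbf{H}^T - \mathbf{I} \|_p\]
where \(\mathbf{H}\) holds the rows of x that batch assigns to that graph and \(\mathbf{I}\) is the identity matrix, and returns the average over all graphs in the batch.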
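To make the formula concrete, here is a dependency-free sketch in plain Python. It is not the PyG implementation: the helper name bro_reference is hypothetical, x and batch are plain lists rather than tensors, and the sketch assumes that a numeric p denotes the entry-wise p-norm of the matrix (so p = 2 is the Frobenius norm). String values of p are not handled.

```python
from collections import defaultdict
from typing import List, Sequence


def bro_reference(x: Sequence[Sequence[float]], batch: Sequence[int], p: float = 2) -> float:
    """Hypothetical pure-Python reference for the BRO penalty (not the PyG API).

    x: node features, one row per node (N x F).
    batch: graph index of each node (length N).
    p: numeric order of the entry-wise matrix norm (assumption: p=2 is Frobenius);
       string norms are not supported in this sketch.
    """
    # Group node rows by graph: H for graph g holds the rows of x with batch index g.
    graphs = defaultdict(list)
    for row, g in zip(x, batch):
        graphs[g].append(list(row))

    penalties: List[float] = []
    for H in graphs.values():
        n = len(H)
        total = 0.0
        for a in range(n):
            for b in range(n):
                # Entry (a, b) of H H^T is the dot product of node rows a and b.
                gram = sum(u * v for u, v in zip(H[a], H[b]))
                # Subtract the identity matrix I.
                diff = gram - (1.0 if a == b else 0.0)
                total += abs(diff) ** p
        # Entry-wise p-norm of (H H^T - I) for this graph.
        penalties.append(total ** (1.0 / p))

    # Average the per-graph penalties over all graphs in the batch.
    return sum(penalties) / len(penalties)


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
batch = [0, 0, 1]
print(bro_reference(x, batch))  # 0.5
```

In this example the two rows of the first graph are orthonormal, so \(\mathbf{H}\mathbf{H}^T = \mathbf{I}\) and its penalty is 0, while the second graph has a single row with squared length 2, giving a penalty of |2 − 1| = 1; the batch average is therefore 0.5.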