torch_geometric.nn.functional.bro
- bro(x: Tensor, batch: Tensor, p: Union[int, str] = 2) → Tensor
The Batch Representation Orthogonality penalty from the “Improving Molecular Graph Neural Network Explainability with Orthonormalization and Induced Sparsity” paper.
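Computes a regularization term for each graph representation in a mini-batch according to
\[\mathcal{L}_{\textrm{BRO}}^\mathrm{graph} = \left\| \mathbf{H}\mathbf{H}^T - \mathbf{I} \right\|_p,\]
where \(\mathbf{H}\) is the node feature matrix of a single graph (the rows of x assigned to that graph by batch) and \(\mathbf{I}\) is the identity matrix of matching size. The function returns the average of this term over all graphs in the batch.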
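To make the formula concrete, here is a minimal loop-based sketch of the computation, taking \(\mathbf{H}\) to be the rows of x that belong to one graph. It assumes that a numeric p denotes an entrywise norm over each \(\mathbf{H}\mathbf{H}^T - \mathbf{I}\) matrix (so p = 2 is the Frobenius norm). The helper bro_reference is hypothetical and not part of PyG; it is written for readability rather than speed.

```python
from typing import Union

import torch


def bro_reference(x: torch.Tensor, batch: torch.Tensor,
                  p: Union[int, str] = 2) -> torch.Tensor:
    # Hypothetical helper for illustration; not part of torch_geometric.
    losses = []
    for graph_id in torch.unique(batch):
        h = x[batch == graph_id]              # H: node features of one graph
        gram = h @ h.t()                      # H H^T, shape [N_i, N_i]
        eye = torch.eye(h.size(0), dtype=x.dtype, device=x.device)
        # Assumption: a numeric p is an entrywise norm over both matrix
        # dimensions (torch.norm with a 2-tuple dim), so p=2 is Frobenius.
        losses.append(torch.norm(gram - eye, p=p, dim=(0, 1)))
    return torch.stack(losses).mean()         # average over graphs


# Toy batch: graph 0 has two identical nodes, graph 1 has a single node.
x = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
batch = torch.tensor([0, 0, 1])
penalty = bro_reference(x, batch)
```

For this toy batch, graph 0 gives \(\mathbf{H}\mathbf{H}^T - \mathbf{I} = \begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}\) with 2-norm \(\sqrt{2}\), while the single node of graph 1 already has unit norm and contributes 0, so the average is \(\sqrt{2}/2 \approx 0.7071\).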
- Parameters:
x (torch.Tensor) – The node feature matrix.
batch (torch.Tensor) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example.
p (int or str, optional) – The norm order to use. (default: 2)