from __future__ import annotations
import torch
from torch import nn
from torch_geometric.data import Batch, Data
from torch_geometric.utils import scatter
__all__ = [
'MGNAN',
]
def _init_weights(module: nn.Module, std: float = 1.0) -> None:
"""Utility that mimics Xavier initialisation with configurable std."""
for name, param in module.named_parameters():
if 'weight' in name:
nn.init.xavier_normal_(param, gain=std)
elif 'bias' in name:
nn.init.constant_(param, 0.0)
class _PerFeatureMLP(nn.Module):
"""Simple MLP that is applied to a single scalar feature.
Args:
out_channels (int): Output dimension per feature ("f" in the paper).
n_layers (int): Number of layers. If ``1``, the MLP is a single Linear.
hidden_channels (int, optional): Hidden dimension. Required when
``n_layers > 1``.
bias (bool, optional): Use bias terms. (default: ``True``)
dropout (float, optional): Dropout probability after hidden layers.
(default: ``0.0``)
"""
def __init__(
self,
out_channels: int,
n_layers: int,
hidden_channels: int | None = None,
*,
bias: bool = True,
dropout: float = 0.0,
) -> None:
super().__init__()
if n_layers == 1:
self.net = nn.Linear(1, out_channels, bias=bias)
else:
assert hidden_channels is not None
layers: list[nn.Module] = [
nn.Linear(1, hidden_channels, bias=bias),
nn.ReLU(),
nn.Dropout(dropout),
]
for _ in range(1, n_layers - 1):
layers += [
nn.Linear(hidden_channels, hidden_channels, bias=bias),
nn.ReLU(),
nn.Dropout(dropout),
]
layers.append(nn.Linear(hidden_channels, out_channels, bias=bias))
self.net = nn.Sequential(*layers)
_init_weights(self)
def forward(self, x: torch.Tensor) -> torch.Tensor: # [N]
return self.net(x.view(-1, 1)) # [N, out_channels]
class _MultiFeatureMLP(nn.Module):
"""MLP that processes multiple features together.
Args:
in_channels (int): Number of input features to process together.
out_channels (int): Output dimension per feature group.
n_layers (int): Number of layers. If ``1``, the MLP is a single Linear.
hidden_channels (int, optional): Hidden dimension. Required when
``n_layers > 1``.
bias (bool, optional): Use bias terms. (default: ``True``)
dropout (float, optional): Dropout probability after hidden layers.
(default: ``0.0``)
"""
def __init__(
self,
in_channels: int,
out_channels: int,
n_layers: int,
hidden_channels: int | None = None,
*,
bias: bool = True,
dropout: float = 0.0,
) -> None:
super().__init__()
if n_layers == 1:
self.net = nn.Linear(in_channels, out_channels, bias=bias)
else:
assert hidden_channels is not None
layers: list[nn.Module] = [
nn.Linear(in_channels, hidden_channels, bias=bias),
nn.ReLU(),
nn.Dropout(dropout),
]
for _ in range(1, n_layers - 1):
layers += [
nn.Linear(hidden_channels, hidden_channels, bias=bias),
nn.ReLU(),
nn.Dropout(dropout),
]
layers.append(nn.Linear(hidden_channels, out_channels, bias=bias))
self.net = nn.Sequential(*layers)
_init_weights(self)
def forward(self, x: torch.Tensor) -> torch.Tensor: # [N, in_channels]
return self.net(x) # [N, out_channels]
class _RhoMLP(nn.Module):
"""MLP that turns a scalar distance into a scalar or vector weight."""
def __init__(
self,
out_channels: int,
n_layers: int,
hidden_channels: int | None = None,
*,
bias: bool = True,
dropout: float = 0.0,
) -> None:
super().__init__()
if n_layers == 1:
self.net = nn.Linear(1, out_channels, bias=bias)
else:
assert hidden_channels is not None
layers: list[nn.Module] = [
nn.Linear(1, hidden_channels, bias=bias),
nn.ReLU()
]
for _ in range(1, n_layers - 1):
layers += [
nn.Linear(hidden_channels, hidden_channels, bias=bias),
nn.ReLU(),
]
layers.append(nn.Linear(hidden_channels, out_channels, bias=bias))
self.net = nn.Sequential(*layers)
_init_weights(self)
def forward(self, d: torch.Tensor) -> torch.Tensor: # [...]
return self.net(d.view(-1, 1))
[docs]class MGNAN(nn.Module):
r"""M-GNAN: an extension of the
`Graph Neural Additive Network (GNAN)
<https://arxiv.org/abs/2406.01317>`_ to *multivariate* shape functions.
Whereas the original GNAN learns a univariate shape function per input
feature, M-GNAN allows arbitrary groups of features to be modelled jointly
by a single multivariate shape function (an MLP that takes all features in
the group as input). The univariate per-feature case is recovered when
every feature lives in its own group, which is the default behaviour.
By default it aggregates node scores to produce *graph‐level* predictions
(shape ``[batch_size, out_channels]``). Set ``graph_level=False`` to
obtain *node‐level* predictions instead, in which case the forward returns
a tensor of shape ``[num_nodes, out_channels]``.
Args:
in_channels (int): Number of input node features.
out_channels (int): Output dimension.
n_layers (int): Number of layers in the MLPs.
hidden_channels (int, optional): Hidden dimension in the MLPs.
bias (bool, optional): Use bias terms. (default: ``True``)
dropout (float, optional): Dropout probability. (default: ``0.0``)
normalize_rho (bool, optional): Whether to normalize rho weights.
(default: ``True``)
graph_level (bool, optional): Whether to produce graph-level
predictions. (default: ``True``)
feature_groups (List[List[int]], optional): Groups of feature indices
to process together. Each group will be processed by a single MLP
that takes multiple features as input. If None, each feature is
processed by its own MLP (default behavior). (default: ``None``)
"""
def __init__(
self,
in_channels: int,
out_channels: int,
n_layers: int,
*,
hidden_channels: int | None = None,
bias: bool = True,
dropout: float = 0.0,
normalize_rho: bool = True,
graph_level: bool = True,
feature_groups: list[list[int]] | None = None,
) -> None:
super().__init__()
self.normalize_rho = normalize_rho
self.graph_level = graph_level
self.out_channels = out_channels
self.in_channels = in_channels
# Set up feature groups - default is each feature in its own group
if feature_groups is None:
self.feature_groups = [[i] for i in range(in_channels)]
else:
self.feature_groups = feature_groups
# Validate feature groups
all_features = set()
for group in feature_groups:
if not group:
raise ValueError("Feature groups cannot be empty")
for feat_idx in group:
if feat_idx < 0 or feat_idx >= in_channels:
raise ValueError(
f"Feature index {feat_idx} out of range "
f"[0, {in_channels})")
if feat_idx in all_features:
raise ValueError(
f"Feature index {feat_idx} appears in "
"multiple groups")
all_features.add(feat_idx)
if len(all_features) != in_channels:
missing = set(range(in_channels)) - all_features
raise ValueError(
f"Missing feature indices in groups: {missing}")
# Create MLPs for each feature group
self.fs = nn.ModuleList()
for group in self.feature_groups:
group_size = len(group)
if group_size == 1:
# Single feature - use original MLP
mlp = _PerFeatureMLP(out_channels, n_layers, hidden_channels,
bias=bias, dropout=dropout)
else:
# Multiple features - use new multi-feature MLP
mlp = _MultiFeatureMLP(group_size, out_channels, n_layers,
hidden_channels, bias=bias,
dropout=dropout)
self.fs.append(mlp)
self.rho = _RhoMLP(out_channels, n_layers, hidden_channels, bias=True)
def _process_feature_groups(self, x: torch.Tensor):
"""Process features according to groups and return fx and f_sum."""
fx_list: list[torch.Tensor] = []
for group, mlp in zip(self.feature_groups, self.fs):
if len(group) == 1:
feat_tensor = x[:, group[0]] # [N]
else:
feat_tensor = x[:, group] # [N, |group|]
fx_list.append(mlp(feat_tensor)) # [N, C]
fx = torch.stack(fx_list, dim=1) # [N, num_groups, C]
f_sum = fx.sum(dim=1) # [N, C]
return fx, f_sum
def _compute_rho(self, dist: torch.Tensor, norm: torch.Tensor,
data) -> torch.Tensor:
"""Compute rho tensor, including normalization and masking."""
x = data.x # type: ignore
inv_dist = 1.0 / (1.0 + dist) # [N, N]
rho = self.rho(inv_dist.flatten().view(-1, 1)) # [(N*N), C]
rho = rho.view(x.size(0), x.size(0), self.out_channels) # [N, N, C]
if self.normalize_rho:
norm_safe = norm.clone()
norm_safe[norm_safe == 0] = 1.0
rho = rho / norm_safe.unsqueeze(-1) # broadcast division
if hasattr(data, 'batch') and data.batch is not None:
batch_i = data.batch.view(-1, 1) # type: ignore
batch_j = data.batch.view(1, -1) # type: ignore
mask = (batch_i == batch_j).unsqueeze(-1)
rho = rho * mask
return rho
[docs] def forward(
self,
data: Data | Batch,
) -> torch.Tensor:
x: torch.Tensor = data.x # type: ignore # [N, F]
dist: torch.Tensor = data.node_distances # type: ignore # [N, N]
norm: torch.Tensor = data.normalization_matrix # type: ignore # [N, N]
_, f_sum = self._process_feature_groups(x)
rho = self._compute_rho(dist, norm, data)
# Perform Σ_i Σ_j ρ(d_ij) Σ_k f_k(x_jk)
out = torch.einsum('ijc,jc->ic', rho, f_sum) # [N, C]
if self.graph_level:
batch = data.batch # type: ignore
if batch is not None:
graph_out = scatter(out, batch, dim=0, reduce='add')
else:
graph_out = out.sum(dim=0, keepdim=True) # [1, C]
return graph_out
return out
[docs] def node_importance(self, data: Data | Batch) -> torch.Tensor:
"""Returns the contribution of every node to the
graph‐level prediction. UsingEq. (3) in the paper.
"""
x: torch.Tensor = data.x # type: ignore # [N, F]
dist: torch.Tensor = data.node_distances # type: ignore # [N, N]
norm: torch.Tensor = data.normalization_matrix # type: ignore
# [N, N]
_, f_sum = self._process_feature_groups(x)
rho = self._compute_rho(dist, norm, data)
# Aggregate over *receiver* nodes i to obtain \sum_i rho(d_{ij}).
rho_sum_over_i = rho.sum(dim=0) # [N, C]
# Node contribution s_j = (sum_k f_k(x_jk)) * (sum_i rho(d_{ij})).
node_contrib = f_sum * rho_sum_over_i # [N, C]
return node_contrib