Source code for torch_geometric.nn.conv.hypergraph_conv

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import scatter, softmax


class HypergraphConv(MessagePassing):
    r"""The hypergraph convolutional operator from the `"Hypergraph
    Convolution and Hypergraph Attention"
    <https://arxiv.org/abs/1901.08150>`_ paper.

    .. math::
        \mathbf{X}^{\prime} = \mathbf{D}^{-1} \mathbf{H} \mathbf{W}
        \mathbf{B}^{-1} \mathbf{H}^{\top} \mathbf{X} \mathbf{\Theta}

    where :math:`\mathbf{H} \in {\{ 0, 1 \}}^{N \times M}` is the incidence
    matrix, :math:`\mathbf{W} \in \mathbb{R}^M` is the diagonal hyperedge
    weight matrix, and :math:`\mathbf{D}` and :math:`\mathbf{B}` are the
    corresponding degree matrices.

    For example, in the hypergraph scenario
    :math:`\mathcal{G} = (\mathcal{V}, \mathcal{E})` with
    :math:`\mathcal{V} = \{ 0, 1, 2, 3 \}` and
    :math:`\mathcal{E} = \{ \{ 0, 1, 2 \}, \{ 1, 2, 3 \} \}`, the
    :obj:`hyperedge_index` is represented as:

    .. code-block:: python

        hyperedge_index = torch.tensor([
            [0, 1, 2, 1, 2, 3],
            [0, 0, 0, 1, 1, 1],
        ])

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        use_attention (bool, optional): If set to :obj:`True`, attention
            will be added to this layer. (default: :obj:`False`)
        attention_mode (str, optional): The mode on how to compute attention.
            If set to :obj:`"node"`, will compute attention scores of nodes
            within all nodes belonging to the same hyperedge.
            If set to :obj:`"edge"`, will compute attention scores of
            hyperedges across all hyperedges a node belongs to.
            (default: :obj:`"node"`)
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a
            stochastically sampled neighborhood during training.
            (default: :obj:`0`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          hyperedge indices :math:`(|\mathcal{V}|, |\mathcal{E}|)`,
          hyperedge weights :math:`(|\mathcal{E}|)` *(optional)*,
          hyperedge features :math:`(|\mathcal{E}|, D)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        use_attention: bool = False,
        attention_mode: str = 'node',
        heads: int = 1,
        concat: bool = True,
        negative_slope: float = 0.2,
        dropout: float = 0,
        bias: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(flow='source_to_target', node_dim=0, **kwargs)

        assert attention_mode in ['node', 'edge']

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_attention = use_attention
        self.attention_mode = attention_mode

        if self.use_attention:
            self.heads = heads
            self.concat = concat
            self.negative_slope = negative_slope
            self.dropout = dropout
            self.lin = Linear(in_channels, heads * out_channels, bias=False,
                              weight_initializer='glorot')
            self.att = Parameter(torch.empty(1, heads, 2 * out_channels))
        else:
            self.heads = 1
            self.concat = True
            self.lin = Linear(in_channels, out_channels, bias=False,
                              weight_initializer='glorot')

        if bias and concat:
            self.bias = Parameter(torch.empty(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()
    def reset_parameters(self):
        super().reset_parameters()
        self.lin.reset_parameters()
        if self.use_attention:
            glorot(self.att)
        zeros(self.bias)
    @disable_dynamic_shapes(required_args=['num_edges'])
    def forward(self, x: Tensor, hyperedge_index: Tensor,
                hyperedge_weight: Optional[Tensor] = None,
                hyperedge_attr: Optional[Tensor] = None,
                num_edges: Optional[int] = None) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor): Node feature matrix
                :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
            hyperedge_index (torch.Tensor): The hyperedge indices, *i.e.*
                the sparse incidence matrix
                :math:`\mathbf{H} \in {\{ 0, 1 \}}^{N \times M}` mapping from
                nodes to edges.
            hyperedge_weight (torch.Tensor, optional): Hyperedge weights
                :math:`\mathbf{W} \in \mathbb{R}^M`.
                (default: :obj:`None`)
            hyperedge_attr (torch.Tensor, optional): Hyperedge feature matrix
                in :math:`\mathbb{R}^{M \times F}`. These features only need
                to get passed in case :obj:`use_attention=True`.
                (default: :obj:`None`)
            num_edges (int, optional): The number of edges :math:`M`.
                (default: :obj:`None`)
        """
        num_nodes = x.size(0)

        # Infer the number of hyperedges from the incidence indices if it is
        # not given explicitly:
        if num_edges is None:
            num_edges = 0
            if hyperedge_index.numel() > 0:
                num_edges = int(hyperedge_index[1].max()) + 1

        if hyperedge_weight is None:
            hyperedge_weight = x.new_ones(num_edges)

        x = self.lin(x)

        alpha = None
        if self.use_attention:
            assert hyperedge_attr is not None
            x = x.view(-1, self.heads, self.out_channels)
            hyperedge_attr = self.lin(hyperedge_attr)
            hyperedge_attr = hyperedge_attr.view(-1, self.heads,
                                                 self.out_channels)
            x_i = x[hyperedge_index[0]]
            x_j = hyperedge_attr[hyperedge_index[1]]
            alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
            alpha = F.leaky_relu(alpha, self.negative_slope)
            if self.attention_mode == 'node':
                alpha = softmax(alpha, hyperedge_index[1],
                                num_nodes=num_edges)
            else:
                alpha = softmax(alpha, hyperedge_index[0],
                                num_nodes=num_nodes)
            alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # D: inverse (weighted) node degree; isolated nodes map to zero.
        D = scatter(hyperedge_weight[hyperedge_index[1]], hyperedge_index[0],
                    dim=0, dim_size=num_nodes, reduce='sum')
        D = 1.0 / D
        D[D == float("inf")] = 0

        # B: inverse hyperedge degree; empty hyperedges map to zero.
        B = scatter(x.new_ones(hyperedge_index.size(1)), hyperedge_index[1],
                    dim=0, dim_size=num_edges, reduce='sum')
        B = 1.0 / B
        B[B == float("inf")] = 0

        # Two-stage propagation: nodes -> hyperedges, then hyperedges -> nodes.
        out = self.propagate(hyperedge_index, x=x, norm=B, alpha=alpha,
                             size=(num_nodes, num_edges))
        out = self.propagate(hyperedge_index.flip([0]), x=out, norm=D,
                             alpha=alpha, size=(num_edges, num_nodes))

        if self.concat is True:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out = out + self.bias

        return out
    def message(self, x_j: Tensor, norm_i: Tensor, alpha: Tensor) -> Tensor:
        H, F = self.heads, self.out_channels

        out = norm_i.view(-1, 1, 1) * x_j.view(-1, H, F)

        if alpha is not None:
            out = alpha.view(-1, self.heads, 1) * out

        return out
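
# ---------------------------------------------------------------------------
# A minimal usage sketch (not part of the module source): it runs the layer on
# the 4-node, 2-hyperedge example from the class docstring above. The channel
# sizes (16 in, 32 out) and the 2-head attention variant are illustrative
# assumptions, not values prescribed by the paper or the library.
if __name__ == '__main__':
    hyperedge_index = torch.tensor([
        [0, 1, 2, 1, 2, 3],  # node indices
        [0, 0, 0, 1, 1, 1],  # hyperedge indices
    ])
    x = torch.randn(4, 16)  # 4 nodes with 16 features each

    # Plain hypergraph convolution:
    conv = HypergraphConv(16, 32)
    out = conv(x, hyperedge_index)
    assert out.size() == (4, 32)

    # Hypergraph attention: hyperedge features must be passed as well.
    conv_att = HypergraphConv(16, 32, use_attention=True, heads=2)
    hyperedge_attr = torch.randn(2, 16)  # one feature vector per hyperedge
    out_att = conv_att(x, hyperedge_index, hyperedge_attr=hyperedge_attr)
    assert out_att.size() == (4, 2 * 32)  # heads are concatenated by default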