Source code for torch_geometric.nn.conv.antisymmetric_conv

import math
from typing import Any, Callable, Dict, Optional, Union

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import GCNConv, MessagePassing
from torch_geometric.nn.inits import zeros
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.typing import Adj


class AntiSymmetricConv(torch.nn.Module):
    r"""The anti-symmetric graph convolutional operator from the
    `"Anti-Symmetric DGN: a stable architecture for Deep Graph Networks"
    <https://openreview.net/forum?id=J3Y7cgZOOS>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{x}_i + \epsilon \cdot \sigma \left(
            (\mathbf{W}-\mathbf{W}^T-\gamma \mathbf{I}) \mathbf{x}_i +
            \Phi(\mathbf{X}, \mathcal{N}_i) + \mathbf{b}\right),

    where :math:`\Phi(\mathbf{X}, \mathcal{N}_i)` denotes a
    :class:`~torch_geometric.nn.conv.MessagePassing` layer.
    The update is applied :obj:`num_iters` times, re-using the same
    :math:`\mathbf{W}`, :math:`\Phi` and :math:`\mathbf{b}` in every
    iteration.

    Args:
        in_channels (int): Size of each input sample.
        phi (MessagePassing, optional): The message passing module
            :math:`\Phi`. If set to :obj:`None`, will use a
            :class:`~torch_geometric.nn.conv.GCNConv` layer as default.
            (default: :obj:`None`)
        num_iters (int, optional): The number of times the anti-symmetric
            deep graph network operator is called. (default: :obj:`1`)
        epsilon (float, optional): The discretization step size
            :math:`\epsilon`. (default: :obj:`0.1`)
        gamma (float, optional): The strength of the diffusion :math:`\gamma`.
            It regulates the stability of the method. (default: :obj:`0.1`)
        act (str, optional): The non-linear activation function
            :math:`\sigma`, *e.g.*, :obj:`"tanh"` or :obj:`"relu"`.
            If set to :obj:`None`, no activation is applied.
            (default: :obj:`"tanh"`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge weights :math:`(|\mathcal{E}|)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|, F_{in})`
    """
    def __init__(
        self,
        in_channels: int,
        phi: Optional[MessagePassing] = None,
        num_iters: int = 1,
        epsilon: float = 0.1,
        gamma: float = 0.1,
        act: Union[str, Callable, None] = 'tanh',
        act_kwargs: Optional[Dict[str, Any]] = None,
        bias: bool = True,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.num_iters = num_iters
        self.gamma = gamma
        self.epsilon = epsilon
        self.act = activation_resolver(act, **(act_kwargs or {}))

        if phi is None:
            # The default Phi maps in_channels -> in_channels so that its
            # output can be added to the (F_in -> F_in) linear term. Its own
            # bias is disabled because this layer adds ``self.bias``.
            phi = GCNConv(in_channels, in_channels, bias=False)

        # Unconstrained weight; forward() only ever uses its anti-symmetric
        # part W - W^T.
        self.W = Parameter(torch.empty(in_channels, in_channels))
        # Identity matrix kept as a buffer so that it follows the module
        # across .to(device) / .to(dtype) calls.
        self.register_buffer('eye', torch.eye(in_channels))
        self.phi = phi

        if bias:
            self.bias = Parameter(torch.empty(in_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        # Same initialisation scheme torch.nn.Linear uses for its weight.
        torch.nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))
        self.phi.reset_parameters()
        # Called unconditionally: the original code relies on ``zeros``
        # accepting ``None`` when bias=False.
        zeros(self.bias)

    def forward(self, x: Tensor, edge_index: Adj, *args, **kwargs) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor): Node features of shape
                :math:`(|\mathcal{V}|, F_{in})`.
            edge_index (torch.Tensor or SparseTensor): The graph
                connectivity passed on to :math:`\Phi`.
            *args, **kwargs: Extra arguments (*e.g.*, edge weights) that
                are forwarded unchanged to :math:`\Phi` on every iteration.

        Returns:
            torch.Tensor: Updated node features with the same shape as
            :obj:`x`.
        """
        # W - W^T is anti-symmetric, so its eigenvalues are purely imaginary.
        # Subtracting gamma * I shifts their real part to -gamma, which is
        # the diffusion term controlling stability. The matrix is the same
        # in every iteration, so it is built once outside the loop.
        antisymmetric_W = self.W - self.W.t() - self.gamma * self.eye

        for _ in range(self.num_iters):
            h = self.phi(x, edge_index, *args, **kwargs)
            # Row-wise x_i @ A^T == A x_i for every node.
            h = x @ antisymmetric_W.t() + h
            if self.bias is not None:
                h += self.bias
            if self.act is not None:
                h = self.act(h)
            # Forward-Euler step with step size epsilon.
            x = x + self.epsilon * h

        return x

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'{self.in_channels}, '
                f'phi={self.phi}, '
                f'num_iters={self.num_iters}, '
                f'epsilon={self.epsilon}, '
                f'gamma={self.gamma})')