Source code for torch_geometric.nn.conv.gin_conv

from typing import Callable, Optional, Union

import torch
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import reset
from torch_geometric.typing import (
    Adj,
    OptPairTensor,
    OptTensor,
    Size,
    SparseTensor,
)
from torch_geometric.utils import spmm


class GINConv(MessagePassing):
    r"""The graph isomorphism operator from the `"How Powerful are
    Graph Neural Networks?" <https://arxiv.org/abs/1810.00826>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon)
        \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right)

    or

    .. math::
        \mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} +
        (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right),

    where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.*, an
    MLP.

    Args:
        nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}`
            that maps node features :obj:`x` of shape
            :obj:`[-1, in_channels]` to shape :obj:`[-1, out_channels]`,
            *e.g.*, defined by :class:`torch.nn.Sequential`.
        eps (float, optional): (Initial) :math:`\epsilon`-value.
            (default: :obj:`0.`)
        train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon`
            will be a trainable parameter. (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})` or
          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
          if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)`
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
          :math:`(|\mathcal{V}_t|, F_{out})` if bipartite
    """
    def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)
        self.nn = nn
        self.initial_eps = eps
        if train_eps:
            self.eps = torch.nn.Parameter(torch.empty(1))
        else:
            self.register_buffer('eps', torch.empty(1))
        self.reset_parameters()

    def reset_parameters(self):
        super().reset_parameters()
        reset(self.nn)
        self.eps.data.fill_(self.initial_eps)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        size: Size = None,
    ) -> Tensor:

        if isinstance(x, Tensor):
            x = (x, x)

        # Sum messages from all neighbors j of each node i.
        # propagate_type: (x: OptPairTensor)
        out = self.propagate(edge_index, x=x, size=size)

        # Add the (1 + eps)-weighted central node features.
        x_r = x[1]
        if x_r is not None:
            out = out + (1 + self.eps) * x_r

        return self.nn(out)

    def message(self, x_j: Tensor) -> Tensor:
        # Messages are the raw neighbor features.
        return x_j

    def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:
        # Fused message passing and aggregation via sparse-matrix
        # multiplication, used when the graph is given as a sparse adjacency.
        if isinstance(adj_t, SparseTensor):
            adj_t = adj_t.set_value(None, layout=None)
        return spmm(adj_t, x[0], reduce=self.aggr)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(nn={self.nn})'
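

# A minimal usage sketch of GINConv on a toy 4-node graph. The two-layer
# MLP, feature sizes, and edges below are illustrative assumptions for
# demonstration, not part of the module above:
if __name__ == '__main__':
    mlp = torch.nn.Sequential(
        torch.nn.Linear(16, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 32),
    )
    conv = GINConv(mlp, train_eps=True)  # epsilon becomes a learnable scalar

    x = torch.randn(4, 16)  # node features of shape [|V|, F_in]
    edge_index = torch.tensor([
        [0, 1, 2, 3],  # source nodes j
        [1, 0, 3, 2],  # target nodes i
    ])
    out = conv(x, edge_index)  # shape: [4, 32]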


class GINEConv(MessagePassing):
    r"""The modified :class:`GINConv` operator from the `"Strategies for
    Pre-training Graph Neural Networks"
    <https://arxiv.org/abs/1905.12265>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon)
        \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathrm{ReLU}
        ( \mathbf{x}_j + \mathbf{e}_{j,i} ) \right)

    that is able to incorporate edge features :math:`\mathbf{e}_{j,i}` into
    the aggregation procedure.

    Args:
        nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}`
            that maps node features :obj:`x` of shape
            :obj:`[-1, in_channels]` to shape :obj:`[-1, out_channels]`,
            *e.g.*, defined by :class:`torch.nn.Sequential`.
        eps (float, optional): (Initial) :math:`\epsilon`-value.
            (default: :obj:`0.`)
        train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon`
            will be a trainable parameter. (default: :obj:`False`)
        edge_dim (int, optional): Edge feature dimensionality. If set to
            :obj:`None`, node and edge feature dimensionality is expected to
            match. Otherwise, edge features are linearly transformed to match
            node feature dimensionality. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})` or
          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
          if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge features :math:`(|\mathcal{E}|, D)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
          :math:`(|\mathcal{V}_t|, F_{out})` if bipartite
    """
    def __init__(self, nn: torch.nn.Module, eps: float = 0.,
                 train_eps: bool = False, edge_dim: Optional[int] = None,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)
        self.nn = nn
        self.initial_eps = eps
        if train_eps:
            self.eps = torch.nn.Parameter(torch.empty(1))
        else:
            self.register_buffer('eps', torch.empty(1))
        if edge_dim is not None:
            # Infer the input dimensionality of `nn` so that edge features
            # can be linearly projected onto it.
            if isinstance(self.nn, torch.nn.Sequential):
                nn = self.nn[0]
            if hasattr(nn, 'in_features'):
                in_channels = nn.in_features
            elif hasattr(nn, 'in_channels'):
                in_channels = nn.in_channels
            else:
                raise ValueError("Could not infer input channels from `nn`.")
            self.lin = Linear(edge_dim, in_channels)
        else:
            self.lin = None
        self.reset_parameters()

    def reset_parameters(self):
        reset(self.nn)
        self.eps.data.fill_(self.initial_eps)
        if self.lin is not None:
            self.lin.reset_parameters()

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
    ) -> Tensor:

        if isinstance(x, Tensor):
            x = (x, x)

        # Sum edge-feature-augmented messages from all neighbors j of i.
        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

        # Add the (1 + eps)-weighted central node features.
        x_r = x[1]
        if x_r is not None:
            out = out + (1 + self.eps) * x_r

        return self.nn(out)

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        if self.lin is None and x_j.size(-1) != edge_attr.size(-1):
            raise ValueError("Node and edge feature dimensionalities do not "
                             "match. Consider setting the 'edge_dim' "
                             "attribute of 'GINEConv'")

        # Project edge features onto the node feature space, if requested.
        if self.lin is not None:
            edge_attr = self.lin(edge_attr)

        return (x_j + edge_attr).relu()

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(nn={self.nn})'
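

# A minimal usage sketch of GINEConv with edge features. The MLP, the
# 8-dimensional edge features, and `edge_dim=8` are illustrative assumptions
# for demonstration, not part of the module above:
if __name__ == '__main__':
    mlp = torch.nn.Sequential(
        torch.nn.Linear(16, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 32),
    )
    # `edge_dim=8` adds a linear layer that maps 8-dimensional edge features
    # to the 16-dimensional input space of `mlp`:
    conv = GINEConv(mlp, edge_dim=8)

    x = torch.randn(4, 16)  # node features of shape [|V|, F_in]
    edge_index = torch.tensor([
        [0, 1, 2, 3],  # source nodes j
        [1, 0, 3, 2],  # target nodes i
    ])
    edge_attr = torch.randn(edge_index.size(1), 8)  # one feature row per edge
    out = conv(x, edge_index, edge_attr)  # shape: [4, 32]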