from typing import List, Optional, Union
from torch import Tensor
from torch.nn import (
BatchNorm1d,
Dropout,
InstanceNorm1d,
LayerNorm,
ReLU,
Sequential,
)
from torch_sparse import SparseTensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.norm import MessageNorm
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
from ..inits import reset
class MLP(Sequential):
def __init__(self, channels: List[int], norm: Optional[str] = None,
bias: bool = True, dropout: float = 0.):
m = []
for i in range(1, len(channels)):
m.append(Linear(channels[i - 1], channels[i], bias=bias))
if i < len(channels) - 1:
if norm and norm == 'batch':
m.append(BatchNorm1d(channels[i], affine=True))
elif norm and norm == 'layer':
m.append(LayerNorm(channels[i], elementwise_affine=True))
elif norm and norm == 'instance':
m.append(InstanceNorm1d(channels[i], affine=False))
elif norm:
raise NotImplementedError(
f'Normalization layer "{norm}" not supported.')
m.append(ReLU())
m.append(Dropout(dropout))
super().__init__(*m)
[docs]class GENConv(MessagePassing):
r"""The GENeralized Graph Convolution (GENConv) from the `"DeeperGCN: All
You Need to Train Deeper GCNs" <https://arxiv.org/abs/2006.07739>`_ paper.
Supports SoftMax & PowerMean aggregation. The message construction is:
.. math::
\mathbf{x}_i^{\prime} = \mathrm{MLP} \left( \mathbf{x}_i +
\mathrm{AGG} \left( \left\{
\mathrm{ReLU} \left( \mathbf{x}_j + \mathbf{e_{ji}} \right) +\epsilon
: j \in \mathcal{N}(i) \right\} \right)
\right)
.. note::
For an example of using :obj:`GENConv`, see
`examples/ogbn_proteins_deepgcn.py
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
ogbn_proteins_deepgcn.py>`_.
Args:
in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
derive the size from the first input(s) to the forward method.
A tuple corresponds to the sizes of source and target
dimensionalities.
out_channels (int): Size of each output sample.
aggr (str, optional): The aggregation scheme to use (:obj:`"softmax"`,
:obj:`"powermean"`, :obj:`"add"`, :obj:`"mean"`, :obj:`max`).
(default: :obj:`"softmax"`)
t (float, optional): Initial inverse temperature for softmax
aggregation. (default: :obj:`1.0`)
learn_t (bool, optional): If set to :obj:`True`, will learn the value
:obj:`t` for softmax aggregation dynamically.
(default: :obj:`False`)
p (float, optional): Initial power for power mean aggregation.
(default: :obj:`1.0`)
learn_p (bool, optional): If set to :obj:`True`, will learn the value
:obj:`p` for power mean aggregation dynamically.
(default: :obj:`False`)
msg_norm (bool, optional): If set to :obj:`True`, will use message
normalization. (default: :obj:`False`)
learn_msg_scale (bool, optional): If set to :obj:`True`, will learn the
scaling factor of message normalization. (default: :obj:`False`)
norm (str, optional): Norm layer of MLP layers (:obj:`"batch"`,
:obj:`"layer"`, :obj:`"instance"`) (default: :obj:`batch`)
num_layers (int, optional): The number of MLP layers.
(default: :obj:`2`)
eps (float, optional): The epsilon value of the message construction
function. (default: :obj:`1e-7`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.GenMessagePassing`.
Shapes:
- **input:**
node features :math:`(|\mathcal{V}|, F_{in})` or
:math:`((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{t})`
if bipartite,
edge indices :math:`(2, |\mathcal{E}|)`,
edge attributes :math:`(|\mathcal{E}|, D)` *(optional)*
- **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
:math:`(|\mathcal{V}_t|, F_{out})` if bipartite
"""
def __init__(self, in_channels: int, out_channels: int,
aggr: str = 'softmax', t: float = 1.0, learn_t: bool = False,
p: float = 1.0, learn_p: bool = False, msg_norm: bool = False,
learn_msg_scale: bool = False, norm: str = 'batch',
num_layers: int = 2, eps: float = 1e-7, **kwargs):
# Backward compatibility:
aggr = 'softmax' if aggr == 'softmax_sg' else aggr
aggr = 'powermean' if aggr == 'power' else aggr
aggr_kwargs = {}
if aggr == 'softmax':
aggr_kwargs = dict(t=t, learn=learn_t)
elif aggr == 'powermean':
aggr_kwargs = dict(p=p, learn=learn_p)
super().__init__(aggr=aggr, aggr_kwargs=aggr_kwargs, **kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.eps = eps
channels = [in_channels]
for i in range(num_layers - 1):
channels.append(in_channels * 2)
channels.append(out_channels)
self.mlp = MLP(channels, norm=norm)
self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None
[docs] def reset_parameters(self):
reset(self.mlp)
self.aggr_module.reset_parameters()
if self.msg_norm is not None:
self.msg_norm.reset_parameters()
[docs] def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
edge_attr: OptTensor = None, size: Size = None) -> Tensor:
""""""
if isinstance(x, Tensor):
x: OptPairTensor = (x, x)
# Node and edge feature dimensionalites need to match.
if isinstance(edge_index, Tensor):
if edge_attr is not None:
assert x[0].size(-1) == edge_attr.size(-1)
elif isinstance(edge_index, SparseTensor):
edge_attr = edge_index.storage.value()
if edge_attr is not None:
assert x[0].size(-1) == edge_attr.size(-1)
# propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
if self.msg_norm is not None:
out = self.msg_norm(x[0], out)
x_r = x[1]
if x_r is not None:
out += x_r
return self.mlp(out)
def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
msg = x_j if edge_attr is None else x_j + edge_attr
return msg.relu() + self.eps
def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, aggr={self.aggr})')