from typing import Any, Callable, Dict, List, Optional, Union
import torch
from torch import Tensor
from torch.nn import ModuleList, Sequential
from torch.utils.data import DataLoader
from torch_geometric.nn.aggr import DegreeScalerAggregation
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import reset
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import degree
class PNAConv(MessagePassing):
    r"""The Principal Neighbourhood Aggregation graph convolution operator
    from the `"Principal Neighbourhood Aggregation for Graph Nets"
    <https://arxiv.org/abs/2004.05718>`_ paper.

    .. math::
        \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left(
        \mathbf{x}_i, \underset{j \in \mathcal{N}(i)}{\bigoplus}
        h_{\mathbf{\Theta}} \left( \mathbf{x}_i, \mathbf{x}_j \right)
        \right)

    with

    .. math::
        \bigoplus = \underbrace{\begin{bmatrix}
            1 \\
            S(\mathbf{D}, \alpha=1) \\
            S(\mathbf{D}, \alpha=-1)
        \end{bmatrix} }_{\text{scalers}}
        \otimes \underbrace{\begin{bmatrix}
            \mu \\
            \sigma \\
            \max \\
            \min
        \end{bmatrix}}_{\text{aggregators}},

    where :math:`\gamma_{\mathbf{\Theta}}` and :math:`h_{\mathbf{\Theta}}`
    denote MLPs. If edge features are given, they are encoded by a linear
    layer and passed to :math:`h_{\mathbf{\Theta}}` alongside
    :math:`\mathbf{x}_i` and :math:`\mathbf{x}_j`.

    .. note::

        For an example of using :obj:`PNAConv`, see `examples/pna.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/
        examples/pna.py>`_.

    Args:
        in_channels (int): Size of each input sample. It must be known up
            front (lazy :obj:`-1` is not supported), since it fixes the
            per-tower feature size.
        out_channels (int): Size of each output sample. Must be divisible
            by :obj:`towers`.
        aggregators (List[str]): Set of aggregation function identifiers,
            namely :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
            :obj:`"var"` and :obj:`"std"`.
        scalers (List[str]): Set of scaling function identifiers, namely
            :obj:`"identity"`, :obj:`"amplification"`,
            :obj:`"attenuation"`, :obj:`"linear"` and
            :obj:`"inverse_linear"`.
        deg (torch.Tensor): Histogram of in-degrees of nodes in the training
            set, used by scalers to normalize
            (see :meth:`get_degree_histogram`).
        edge_dim (int, optional): Edge feature dimensionality (in case
            there are any). (default: :obj:`None`)
        towers (int, optional): Number of towers. (default: :obj:`1`)
        pre_layers (int, optional): Number of transformation layers before
            aggregation. (default: :obj:`1`)
        post_layers (int, optional): Number of transformation layers after
            aggregation. (default: :obj:`1`)
        divide_input (bool, optional): Whether the input features should
            be split between towers or not. If :obj:`True`,
            :obj:`in_channels` must be divisible by :obj:`towers`.
            (default: :obj:`False`)
        act (str or callable, optional): Pre- and post-layer activation
            function to use. When not a string, it should resolve to a
            :class:`torch.nn.Module`, because the resolved activation is
            placed inside a :class:`torch.nn.Sequential`. :obj:`None`
            disables the activation. (default: :obj:`"relu"`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        train_norm (bool, optional): Whether normalization parameters
            are trainable. (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge features :math:`(|\mathcal{E}|, D)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        aggregators: List[str],
        scalers: List[str],
        deg: Tensor,
        edge_dim: Optional[int] = None,
        towers: int = 1,
        pre_layers: int = 1,
        post_layers: int = 1,
        divide_input: bool = False,
        act: Union[str, Callable, None] = "relu",
        act_kwargs: Optional[Dict[str, Any]] = None,
        train_norm: bool = False,
        **kwargs,
    ):
        aggr = DegreeScalerAggregation(aggregators, scalers, deg, train_norm)
        super().__init__(aggr=aggr, node_dim=0, **kwargs)

        if divide_input:
            assert in_channels % towers == 0, (
                f"'in_channels' ({in_channels}) must be divisible by "
                f"'towers' ({towers}) when 'divide_input=True'")
        assert out_channels % towers == 0, (
            f"'out_channels' ({out_channels}) must be divisible by "
            f"'towers' ({towers})")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_dim = edge_dim
        self.towers = towers
        self.divide_input = divide_input

        # Per-tower feature sizes: either each tower sees a slice of the
        # input, or every tower sees the full input.
        self.F_in = in_channels // towers if divide_input else in_channels
        self.F_out = self.out_channels // towers

        if self.edge_dim is not None:
            self.edge_encoder = Linear(edge_dim, self.F_in)

        # Message-MLP input is [x_i, x_j] plus the encoded edge features.
        # NOTE(review): this uses truthiness, so ``edge_dim=0`` is treated
        # like ``None`` here even though an edge encoder is still created.
        pre_in = (3 if edge_dim else 2) * self.F_in
        # Update-MLP input is the aggregation output (presumably
        # ``len(aggregators) * len(scalers) * F_in`` wide) concatenated with
        # the node's own per-tower features in ``forward``.
        post_in = (len(aggregators) * len(scalers) + 1) * self.F_in

        self.pre_nns = ModuleList()
        self.post_nns = ModuleList()
        for _ in range(towers):
            self.pre_nns.append(
                self._build_mlp(pre_in, self.F_in, pre_layers, act,
                                act_kwargs))
            self.post_nns.append(
                self._build_mlp(post_in, self.F_out, post_layers, act,
                                act_kwargs))

        self.lin = Linear(out_channels, out_channels)

        self.reset_parameters()

    @staticmethod
    def _build_mlp(
        in_dim: int,
        out_dim: int,
        num_layers: int,
        act: Union[str, Callable, None],
        act_kwargs: Optional[Dict[str, Any]],
    ) -> Sequential:
        r"""Builds ``Linear(in_dim, out_dim)`` followed by
        ``num_layers - 1`` repetitions of ``act -> Linear(out_dim, out_dim)``.
        """
        modules = [Linear(in_dim, out_dim)]
        for _ in range(num_layers - 1):
            if act is None:
                # ``Sequential`` cannot call a ``None`` child; ``Identity``
                # means "no activation" while keeping child indices (and
                # therefore state-dict keys) aligned with the activated case.
                modules.append(torch.nn.Identity())
            else:
                modules.append(activation_resolver(act, **(act_kwargs or {})))
            modules.append(Linear(out_dim, out_dim))
        return Sequential(*modules)

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        super().reset_parameters()
        if self.edge_dim is not None:
            self.edge_encoder.reset_parameters()
        for mlp in self.pre_nns:
            reset(mlp)
        for mlp in self.post_nns:
            reset(mlp)
        self.lin.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj,
                edge_attr: OptTensor = None) -> Tensor:
        r"""Runs the forward pass.

        Args:
            x (torch.Tensor): Node features of shape
                :obj:`[num_nodes, in_channels]`.
            edge_index (Adj): Graph connectivity.
            edge_attr (torch.Tensor, optional): Edge features of shape
                :obj:`[num_edges, edge_dim]`. (default: :obj:`None`)

        Returns:
            torch.Tensor: Node features of shape
            :obj:`[num_nodes, out_channels]`.
        """
        # Reshape to [num_nodes, towers, F_in]: split the features across
        # towers, or replicate the full features for every tower.
        if self.divide_input:
            x = x.view(-1, self.towers, self.F_in)
        else:
            x = x.view(-1, 1, self.F_in).repeat(1, self.towers, 1)

        # propagate_type: (x: Tensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr)

        # Each tower's update MLP sees the aggregated messages together with
        # the node's own per-tower features.
        out = torch.cat([x, out], dim=-1)
        outs = [mlp(out[:, i]) for i, mlp in enumerate(self.post_nns)]
        out = torch.cat(outs, dim=1)

        return self.lin(out)

    def message(self, x_i: Tensor, x_j: Tensor,
                edge_attr: OptTensor) -> Tensor:
        # Builds per-edge messages of shape [num_edges, towers, F_in]: each
        # tower applies its own pre-MLP to [x_i, x_j(, encoded edge_attr)].
        h: Tensor = x_i  # Dummy; presumably lets TorchScript type ``h``.
        if edge_attr is not None:
            # NOTE(review): requires ``edge_dim`` to have been given at
            # construction, otherwise ``self.edge_encoder`` does not exist.
            edge_attr = self.edge_encoder(edge_attr)
            edge_attr = edge_attr.view(-1, 1, self.F_in)
            edge_attr = edge_attr.repeat(1, self.towers, 1)
            h = torch.cat([x_i, x_j, edge_attr], dim=-1)
        else:
            h = torch.cat([x_i, x_j], dim=-1)

        hs = [mlp(h[:, i]) for i, mlp in enumerate(self.pre_nns)]
        return torch.stack(hs, dim=1)

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, towers={self.towers}, '
                f'edge_dim={self.edge_dim})')

    @staticmethod
    def get_degree_histogram(loader: DataLoader) -> Tensor:
        r"""Returns the degree histogram to be used as input for the :obj:`deg`
        argument in :class:`PNAConv`.

        Entry :obj:`k` of the result counts the nodes, over all graphs
        yielded by :obj:`loader`, whose degree measured along
        :obj:`data.edge_index[1]` equals :obj:`k`.
        """
        deg_histogram = torch.zeros(1, dtype=torch.long)
        for data in loader:
            deg = degree(data.edge_index[1], num_nodes=data.num_nodes,
                         dtype=torch.long)
            # ``minlength`` guarantees the new counts are at least as long as
            # the running histogram, so the histogram always fits into their
            # prefix and can be merged in a single step.
            deg_bincount = torch.bincount(deg, minlength=deg_histogram.numel())
            deg_histogram = deg_histogram.to(deg_bincount.device)
            deg_bincount[:deg_histogram.numel()] += deg_histogram
            deg_histogram = deg_bincount
        return deg_histogram