from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import ModuleList, Sequential
from torch.utils.data import DataLoader

from torch_geometric.nn.aggr import DegreeScalerAggregation
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import reset
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import degree

class PNAConv(MessagePassing):
    r"""The Principal Neighbourhood Aggregation graph convolution operator
    from the `"Principal Neighbourhood Aggregation for Graph Nets"
    <https://arxiv.org/abs/2004.05718>`_ paper.

    .. math::
        \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left(
        \mathbf{x}_i, \underset{j \in \mathcal{N}(i)}{\bigoplus}
        h_{\mathbf{\Theta}} \left( \mathbf{x}_i, \mathbf{x}_j \right)
        \right)

    with

    .. math::
        \bigoplus = \underbrace{\begin{bmatrix}
            1 \\
            S(\mathbf{D}, \alpha=1) \\
            S(\mathbf{D}, \alpha=-1)
        \end{bmatrix} }_{\text{scalers}}
        \otimes \underbrace{\begin{bmatrix}
            \mu \\
            \sigma \\
            \max \\
            \min
        \end{bmatrix}}_{\text{aggregators}},

    where :math:`\gamma_{\mathbf{\Theta}}` and :math:`h_{\mathbf{\Theta}}`
    denote MLPs.

    .. note::

        For an example of using :obj:`PNAConv`, see `examples/pna.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/
        examples/pna.py>`_.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        aggregators (List[str]): Set of aggregation function identifiers,
            namely :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
            :obj:`"var"` and :obj:`"std"`.
        scalers (List[str]): Set of scaling function identifiers, namely
            :obj:`"identity"`, :obj:`"amplification"`,
            :obj:`"attenuation"`, :obj:`"linear"` and
            :obj:`"inverse_linear"`.
        deg (torch.Tensor): Histogram of in-degrees of nodes in the training
            set, used by scalers to normalize.
        edge_dim (int, optional): Edge feature dimensionality (in case
            there are any). (default :obj:`None`)
        towers (int, optional): Number of towers (default: :obj:`1`).
        pre_layers (int, optional): Number of transformation layers before
            aggregation (default: :obj:`1`).
        post_layers (int, optional): Number of transformation layers after
            aggregation (default: :obj:`1`).
        divide_input (bool, optional): Whether the input features should
            be split between towers or not (default: :obj:`False`).
        act (str or callable, optional): Pre- and post-layer activation
            function to use. (default: :obj:`"relu"`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        train_norm (bool, optional): Whether normalization parameters
            are trainable. (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge features :math:`(|\mathcal{E}|, D)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        aggregators: List[str],
        scalers: List[str],
        deg: Tensor,
        edge_dim: Optional[int] = None,
        towers: int = 1,
        pre_layers: int = 1,
        post_layers: int = 1,
        divide_input: bool = False,
        act: Union[str, Callable, None] = "relu",
        act_kwargs: Optional[Dict[str, Any]] = None,
        train_norm: bool = False,
        **kwargs,
    ):
        aggr = DegreeScalerAggregation(aggregators, scalers, deg, train_norm)
        super().__init__(aggr=aggr, node_dim=0, **kwargs)

        if divide_input:
            assert in_channels % towers == 0, (
                "'in_channels' must be divisible by 'towers' when "
                "'divide_input' is set")
        # Checked unconditionally: the per-tower outputs (F_out each) are
        # concatenated and fed to `self.lin`, which expects `out_channels`.
        assert out_channels % towers == 0, (
            "'out_channels' must be divisible by 'towers'")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_dim = edge_dim
        self.towers = towers
        self.divide_input = divide_input

        # Per-tower input/output widths.
        self.F_in = in_channels // towers if divide_input else in_channels
        self.F_out = self.out_channels // towers

        if self.edge_dim is not None:
            self.edge_encoder = Linear(edge_dim, self.F_in)

        self.pre_nns = ModuleList()
        self.post_nns = ModuleList()
        for _ in range(towers):
            # Messages concatenate x_i and x_j, plus the encoded edge
            # features when present. NOTE: this is a truthiness check, so
            # `edge_dim=0` is sized like the no-edge-feature case here.
            modules = [Linear((3 if edge_dim else 2) * self.F_in, self.F_in)]
            for _ in range(pre_layers - 1):
                modules += [activation_resolver(act, **(act_kwargs or {}))]
                modules += [Linear(self.F_in, self.F_in)]
            self.pre_nns.append(Sequential(*modules))

            # One F_in-wide slice per (aggregator, scaler) pair, plus the
            # node's own features that `forward` concatenates in front.
            post_in_channels = (
                (len(aggregators) * len(scalers) + 1) * self.F_in)
            modules = [Linear(post_in_channels, self.F_out)]
            for _ in range(post_layers - 1):
                modules += [activation_resolver(act, **(act_kwargs or {}))]
                modules += [Linear(self.F_out, self.F_out)]
            self.post_nns.append(Sequential(*modules))

        self.lin = Linear(out_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets the parameters of the edge encoder (if any), all pre- and
        post-aggregation networks and the final linear layer.
        """
        super().reset_parameters()
        if self.edge_dim is not None:
            self.edge_encoder.reset_parameters()
        for pre_nn in self.pre_nns:
            reset(pre_nn)
        for post_nn in self.post_nns:
            reset(post_nn)
        self.lin.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj,
                edge_attr: OptTensor = None) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)
        """
        if self.divide_input:
            # Split the feature dimension across towers.
            x = x.view(-1, self.towers, self.F_in)
        else:
            # Every tower sees a copy of the full feature vector.
            x = x.view(-1, 1, self.F_in).repeat(1, self.towers, 1)

        # propagate_type: (x: Tensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr)

        # Prepend each node's own features to its aggregated messages.
        out = torch.cat([x, out], dim=-1)
        outs = [
            post_nn(out[:, i]) for i, post_nn in enumerate(self.post_nns)
        ]
        out = torch.cat(outs, dim=1)

        return self.lin(out)

    def message(self, x_i: Tensor, x_j: Tensor,
                edge_attr: OptTensor) -> Tensor:
        r"""Builds per-tower messages from the concatenation of target
        features, source features and (optionally) encoded edge features.
        """
        h: Tensor = x_i  # Dummy.
        if edge_attr is not None:
            edge_attr = self.edge_encoder(edge_attr)
            edge_attr = edge_attr.view(-1, 1, self.F_in)
            edge_attr = edge_attr.repeat(1, self.towers, 1)
            h = torch.cat([x_i, x_j, edge_attr], dim=-1)
        else:
            h = torch.cat([x_i, x_j], dim=-1)

        hs = [pre_nn(h[:, i]) for i, pre_nn in enumerate(self.pre_nns)]
        return torch.stack(hs, dim=1)

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, towers={self.towers}, '
                f'edge_dim={self.edge_dim})')

    @staticmethod
    def get_degree_histogram(loader: DataLoader) -> Tensor:
        r"""Returns the degree histogram to be used as input for the
        :obj:`deg` argument in :class:`PNAConv`.
        """
        deg_histogram = torch.zeros(1, dtype=torch.long)
        for data in loader:
            deg = degree(data.edge_index[1], num_nodes=data.num_nodes,
                         dtype=torch.long)
            deg_bincount = torch.bincount(deg,
                                          minlength=deg_histogram.numel())
            # Keep the running histogram on the same device as the batch
            # so that the additions below are valid.
            deg_histogram = deg_histogram.to(deg_bincount.device)
            if deg_bincount.numel() > deg_histogram.numel():
                # A larger maximum degree was seen: fold the running counts
                # into the longer bincount and continue with that.
                deg_bincount[:deg_histogram.size(0)] += deg_histogram
                deg_histogram = deg_bincount
            else:
                # `minlength` guarantees the bincount is never shorter.
                assert deg_bincount.numel() == deg_histogram.numel()
                deg_histogram += deg_bincount

        return deg_histogram