import math

import torch
from torch import Tensor
from torch.nn import BatchNorm1d, Parameter

from torch_geometric.nn import inits
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.models import MLP
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import spmm

class SparseLinear(MessagePassing):
    r"""Linear layer evaluated by message passing over a graph.

    For every edge, the row of :obj:`weight` indexed by the source node is
    sent to the target node (optionally scaled by the edge weight). Messages
    are summed per target node and an optional bias is added.

    Args:
        in_channels (int): Number of rows of the weight matrix. Rows are
            gathered by source node index, so this must cover every source
            node index that occurs in :obj:`edge_index`.
        out_channels (int): Size of each output sample.
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
    """
    def __init__(self, in_channels: int, out_channels: int, bias: bool = True):
        # `Module.__init__` has to run before any parameter is assigned.
        # Sum aggregation is what turns the propagation into a linear map.
        super().__init__(aggr='add')

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.weight = Parameter(torch.empty(in_channels, out_channels))
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            # Keep the attribute defined so `self.bias is None` checks work.
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-initializes the weight matrix and, if present, the bias."""
        inits.kaiming_uniform(self.weight, fan=self.in_channels,
                              a=math.sqrt(5))
        inits.uniform(self.in_channels, self.bias)

    def forward(
        self,
        edge_index: Adj,
        edge_weight: OptTensor = None,
    ) -> Tensor:
        r"""Runs the layer on the given graph connectivity.

        Args:
            edge_index (Adj): Graph connectivity handed to
                :meth:`propagate`.
            edge_weight (Tensor, optional): One scalar per edge used to
                scale the corresponding message. (default: :obj:`None`)

        Returns:
            Tensor: The aggregated messages, plus the bias if one is used.
        """
        # propagate_type: (weight: Tensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, weight=self.weight,
                             edge_weight=edge_weight)

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(self, weight_j: Tensor, edge_weight: OptTensor) -> Tensor:
        r"""Builds per-edge messages from the source-indexed weight rows.

        Args:
            weight_j (Tensor): Rows of the weight matrix gathered at the
                source node of every edge.
            edge_weight (Tensor, optional): One scalar per edge, or
                :obj:`None` to pass the rows through unscaled.

        Returns:
            Tensor: One message per edge.
        """
        if edge_weight is None:
            return weight_j
        # Reshape to a column so the scalar broadcasts over the channels.
        return edge_weight.view(-1, 1) * weight_j

class LINKX(torch.nn.Module):
    r"""The LINKX model from the `"Large Scale Learning on Non-Homophilous
    Graphs: New Benchmarks and Strong Simple Methods"
    <https://arxiv.org/abs/2110.14446>`_ paper.

    .. math::
        \mathbf{H}_{\mathbf{A}} &= \textrm{MLP}_{\mathbf{A}}(\mathbf{A})

        \mathbf{H}_{\mathbf{X}} &= \textrm{MLP}_{\mathbf{X}}(\mathbf{X})

        \mathbf{Y} &= \textrm{MLP}_{f} \left( \sigma \left( \mathbf{W}
        [\mathbf{H}_{\mathbf{A}}, \mathbf{H}_{\mathbf{X}}] +
        \mathbf{H}_{\mathbf{A}} + \mathbf{H}_{\mathbf{X}} \right) \right)

    .. note::

        For an example of using LINKX, see `examples/linkx.py <https://
        github.com/pyg-team/pytorch_geometric/blob/master/examples/linkx.py>`_.

    Args:
        num_nodes (int): The number of nodes in the graph.
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        hidden_channels (int): Size of each hidden sample.
        out_channels (int): Size of each output sample.
        num_layers (int): Number of layers of :math:`\textrm{MLP}_{f}`.
        num_edge_layers (int, optional): Number of layers of
            :math:`\textrm{MLP}_{\mathbf{A}}`. (default: :obj:`1`)
        num_node_layers (int, optional): Number of layers of
            :math:`\textrm{MLP}_{\mathbf{X}}`. (default: :obj:`1`)
        dropout (float, optional): Dropout probability of each hidden
            embedding. (default: :obj:`0.0`)
    """
    def __init__(
        self,
        num_nodes: int,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int,
        num_edge_layers: int = 1,
        num_node_layers: int = 1,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.num_nodes = num_nodes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_edge_layers = num_edge_layers

        # Adjacency branch: the first layer maps the graph structure to
        # `hidden_channels`; any further layers live in a regular MLP.
        self.edge_lin = SparseLinear(num_nodes, hidden_channels)

        if self.num_edge_layers > 1:
            self.edge_norm = BatchNorm1d(hidden_channels)
            channels = [hidden_channels] * num_edge_layers
            self.edge_mlp = MLP(channels, dropout=0., act_first=True)
        else:
            self.edge_norm = None
            self.edge_mlp = None

        # Node feature branch.
        channels = [in_channels] + [hidden_channels] * num_node_layers
        self.node_mlp = MLP(channels, dropout=0., act_first=True)

        # One hidden-to-hidden layer per branch; their outputs are summed
        # in `forward` instead of applying one layer to a concatenation.
        self.cat_lin1 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.cat_lin2 = torch.nn.Linear(hidden_channels, hidden_channels)

        # Prediction head; the only part that uses `dropout`.
        channels = [hidden_channels] * num_layers + [out_channels]
        self.final_mlp = MLP(channels, dropout=dropout, act_first=True)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.edge_lin.reset_parameters()
        if self.edge_norm is not None:
            self.edge_norm.reset_parameters()
        if self.edge_mlp is not None:
            self.edge_mlp.reset_parameters()
        self.node_mlp.reset_parameters()
        self.cat_lin1.reset_parameters()
        self.cat_lin2.reset_parameters()
        self.final_mlp.reset_parameters()

    def forward(
        self,
        x: OptTensor,
        edge_index: Adj,
        edge_weight: OptTensor = None,
    ) -> Tensor:
        r"""Computes the model output for the given graph.

        Args:
            x (Tensor, optional): Node features. If :obj:`None`, the node
                feature branch is skipped and only the adjacency branch
                feeds the final MLP.
            edge_index (Adj): Graph connectivity.
            edge_weight (Tensor, optional): Edge weights forwarded to the
                adjacency branch. (default: :obj:`None`)

        Returns:
            Tensor: The output of the final MLP.
        """
        out = self.edge_lin(edge_index, edge_weight)

        # Both modules are created together when `num_edge_layers > 1`.
        if self.edge_norm is not None and self.edge_mlp is not None:
            out = out.relu_()
            out = self.edge_norm(out)
            out = self.edge_mlp(out)

        # Residual connection around the adjacency half of the mixing layer.
        out = out + self.cat_lin1(out)

        if x is not None:
            x = self.node_mlp(x)
            out = out + x
            out = out + self.cat_lin2(x)

        return self.final_mlp(out.relu_())

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(num_nodes={self.num_nodes}, '
                f'in_channels={self.in_channels}, '
                f'out_channels={self.out_channels})')