# Source code for torch_geometric.nn.models.linkx

from torch_geometric.typing import OptTensor, Adj

import math

import torch
from torch import Tensor
from torch.nn import BatchNorm1d, Parameter
from torch_sparse import SparseTensor, matmul

from torch_geometric.nn import inits
from torch_geometric.nn.models import MLP
from torch_geometric.nn.conv import MessagePassing


class SparseLinear(MessagePassing):
    r"""Linear layer whose "input" is the graph's adjacency structure.

    Rather than multiplying a dense feature matrix with :obj:`weight`, each
    row of :obj:`weight` is treated as the learnable embedding of one node
    index. Messages are those rows gathered per edge (optionally scaled by an
    edge weight) and summed (:obj:`aggr='add'`); for :class:`SparseTensor`
    inputs this is done as a single sparse-dense :func:`matmul` of the
    adjacency with :obj:`weight`. An optional bias is added afterwards.

    Args:
        in_channels (int): Number of rows of :obj:`weight`. Rows are looked
            up by node index, so this is expected to match the number of
            nodes in the graph.
        out_channels (int): Size of each output embedding.
        bias (bool, optional): If :obj:`False`, no additive bias is learned.
            (default: :obj:`True`)
    """
    def __init__(self, in_channels: int, out_channels: int, bias: bool = True):
        super().__init__(aggr='add')
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Storage is allocated uninitialized; reset_parameters() fills it.
        self.weight = Parameter(torch.Tensor(in_channels, out_channels))
        if not bias:
            # Register the slot explicitly so `self.bias` exists and is None.
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-initializes :obj:`weight` with Kaiming-uniform and
        :obj:`bias` with :func:`inits.uniform`, both using
        :obj:`in_channels` as the fan."""
        fan_in = self.in_channels
        inits.kaiming_uniform(self.weight, fan=fan_in, a=math.sqrt(5))
        inits.uniform(fan_in, self.bias)

    def forward(self, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        r"""Sums (edge-weighted) rows of :obj:`weight` over the graph given
        by :obj:`edge_index` and adds the bias, if one is registered.

        Args:
            edge_index (Adj): Graph connectivity, either an index tensor or
                a :class:`SparseTensor`.
            edge_weight (Tensor, optional): Per-edge scaling factors; when
                :obj:`None` every message is used unscaled.
        """
        # propagate_type: (weight: Tensor, edge_weight: OptTensor)
        result = self.propagate(edge_index, weight=self.weight,
                                edge_weight=edge_weight, size=None)
        if self.bias is None:
            return result
        result += self.bias
        return result

    def message(self, weight_j: Tensor, edge_weight: OptTensor) -> Tensor:
        # `weight_j` holds the row of `weight` gathered for each edge;
        # scale it per edge only when weights were supplied.
        if edge_weight is not None:
            return edge_weight.view(-1, 1) * weight_j
        return weight_j

    def message_and_aggregate(self, adj_t: SparseTensor,
                              weight: Tensor) -> Tensor:
        # Fused path for SparseTensor adjacency: one sparse-dense product
        # replaces the per-edge gather of message() plus the scatter-sum.
        return matmul(adj_t, weight, reduce=self.aggr)


class LINKX(torch.nn.Module):
    r"""The LINKX model from the `"Large Scale Learning on Non-Homophilous
    Graphs: New Benchmarks and Strong Simple Methods"
    <https://arxiv.org/abs/2110.14446>`_ paper

    .. math::
        \mathbf{H}_{\mathbf{A}} &= \textrm{MLP}_{\mathbf{A}}(\mathbf{A})

        \mathbf{H}_{\mathbf{X}} &= \textrm{MLP}_{\mathbf{X}}(\mathbf{X})

        \mathbf{Y} &= \textrm{MLP}_{f} \left( \sigma \left( \mathbf{W}
        [\mathbf{H}_{\mathbf{A}}, \mathbf{H}_{\mathbf{X}}] +
        \mathbf{H}_{\mathbf{A}} + \mathbf{H}_{\mathbf{X}} \right) \right)

    .. note::

        For an example of using LINKX, see `examples/linkx.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        linkx.py>`_.

    Args:
        num_nodes (int): The number of nodes in the graph.
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        hidden_channels (int): Size of each hidden sample.
        out_channels (int): Size of each output sample.
        num_layers (int): Number of layers of :math:`\textrm{MLP}_{f}`.
        num_edge_layers (int): Number of layers of
            :math:`\textrm{MLP}_{\mathbf{A}}`. (default: :obj:`1`)
        num_node_layers (int): Number of layers of
            :math:`\textrm{MLP}_{\mathbf{X}}`. (default: :obj:`1`)
        dropout (float, optional): Dropout probability of each hidden
            embedding; it is applied inside :math:`\textrm{MLP}_{f}` only.
            (default: :obj:`0.`)
    """
    def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
                 out_channels: int, num_layers: int, num_edge_layers: int = 1,
                 num_node_layers: int = 1, dropout: float = 0.):
        super().__init__()

        self.num_nodes = num_nodes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_edge_layers = num_edge_layers

        # MLP_A: its first layer consumes the adjacency directly
        # (SparseLinear over node ids); any further layers form a plain MLP.
        self.edge_lin = SparseLinear(num_nodes, hidden_channels)
        if self.num_edge_layers > 1:
            self.edge_norm = BatchNorm1d(hidden_channels)
            channels = [hidden_channels] * num_edge_layers
            self.edge_mlp = MLP(channels, dropout=0., relu_first=True)
        else:
            # Always define these attributes so callers can test for None
            # instead of hitting an AttributeError.
            self.edge_norm = None
            self.edge_mlp = None

        # MLP_X over the node features.
        channels = [in_channels] + [hidden_channels] * num_node_layers
        self.node_mlp = MLP(channels, dropout=0., relu_first=True)

        # The W [H_A, H_X] term is split into one linear map per branch, so
        # the node-feature branch can be skipped entirely when x is None.
        self.cat_lin1 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.cat_lin2 = torch.nn.Linear(hidden_channels, hidden_channels)

        # MLP_f. Pass `dropout` by keyword, as for the MLPs above, instead of
        # relying on MLP's positional parameter order.
        channels = [hidden_channels] * num_layers + [out_channels]
        self.final_mlp = MLP(channels, dropout=dropout, relu_first=True)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.edge_lin.reset_parameters()
        if self.edge_norm is not None:
            self.edge_norm.reset_parameters()
        if self.edge_mlp is not None:
            self.edge_mlp.reset_parameters()
        self.node_mlp.reset_parameters()
        self.cat_lin1.reset_parameters()
        self.cat_lin2.reset_parameters()
        self.final_mlp.reset_parameters()

    def forward(self, x: OptTensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        r"""Runs the LINKX model.

        Args:
            x (Tensor, optional): Node features fed into
                :math:`\textrm{MLP}_{\mathbf{X}}`. If :obj:`None`, the
                node-feature branch is skipped and only the adjacency branch
                contributes.
            edge_index (Adj): Graph connectivity passed to the adjacency
                branch.
            edge_weight (Tensor, optional): Edge weights passed to the
                adjacency branch. (default: :obj:`None`)

        Returns:
            The output of :math:`\textrm{MLP}_{f}`.
        """
        # H_A: adjacency branch.
        out = self.edge_lin(edge_index, edge_weight)
        if self.edge_norm is not None and self.edge_mlp is not None:
            out = out.relu_()
            out = self.edge_norm(out)
            out = self.edge_mlp(out)

        # H_A + W_1 H_A
        out = out + self.cat_lin1(out)

        if x is not None:
            # Node-feature branch contributes H_X + W_2 H_X.
            x = self.node_mlp(x)
            out = out + x
            out = out + self.cat_lin2(x)

        return self.final_mlp(out.relu_())

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(num_nodes={self.num_nodes}, '
                f'in_channels={self.in_channels}, '
                f'out_channels={self.out_channels})')