# Source code for torch_geometric.nn.conv.graph_conv

import torch
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing

from ..inits import uniform

class GraphConv(MessagePassing):
    r"""The graph neural network operator from the `"Weisfeiler and Leman Go
    Neural: Higher-order Graph Neural Networks"
    <https://arxiv.org/abs/1810.02244>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta}_1 \mathbf{x}_i +
        \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{\Theta}_2
        \mathbf{x}_j,

    where :math:`e_{j,i}` is the weight of the edge connecting node
    :math:`j` to node :math:`i` (:obj:`1` when no :obj:`edge_weight` is
    given). :math:`\mathbf{\Theta}_1` is the root transformation :obj:`lin`
    (which also carries the optional additive bias) and
    :math:`\mathbf{\Theta}_2` is the neighbor transformation :obj:`weight`.
    The sum is replaced by the chosen :obj:`aggr` scheme.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"add"`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self, in_channels, out_channels, aggr='add', bias=True,
                 **kwargs):
        super(GraphConv, self).__init__(aggr=aggr, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Theta_2: applied to every node's features before they are sent as
        # messages to neighbors. Allocated uninitialized; filled in
        # reset_parameters().
        self.weight = Parameter(torch.empty(in_channels, out_channels))
        # Theta_1 (+ optional bias): applied to each node's own features
        # and added to the aggregated neighbor messages in update().
        self.lin = torch.nn.Linear(in_channels, out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize :obj:`weight` and the root linear layer."""
        uniform(self.in_channels, self.weight)
        self.lin.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None, size=None):
        r"""Apply the layer.

        Args:
            x (Tensor): Node feature matrix whose last dimension equals
                :obj:`in_channels`.
            edge_index (LongTensor): Graph connectivity, passed unchanged to
                :meth:`propagate`.
            edge_weight (Tensor, optional): Per-edge scalar weights; when
                :obj:`None`, every edge has weight 1. (default: :obj:`None`)
            size (optional): Passed unchanged to :meth:`propagate`.
                (default: :obj:`None`)

        Returns:
            Tensor: Updated node features with last dimension
            :obj:`out_channels`.
        """
        # Apply Theta_2 per node *before* gathering, so each message is
        # Theta_2 x_j as in the formula. This matters for non-linear
        # aggregations such as "max", where max(Theta x) != Theta max(x).
        h = torch.matmul(x, self.weight)
        return self.propagate(edge_index, size=size, x=x, h=h,
                              edge_weight=edge_weight)

    def message(self, h_j, edge_weight):
        # h_j: presumably the transformed features `h` gathered per edge at
        # the neighbor node (MessagePassing `_j` convention) -- confirm.
        if edge_weight is None:
            return h_j
        # Reshape to a column so each edge's scalar weight broadcasts across
        # the feature dimension.
        return edge_weight.view(-1, 1) * h_j

    def update(self, aggr_out, x):
        # Add the root term Theta_1 x_i (+ bias) to the aggregated messages.
        return aggr_out + self.lin(x)

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)