# Source code for torch_geometric.nn.conv.graph_conv

from typing import Final, Tuple, Union

import torch
from torch import Tensor

from torch_geometric import EdgeIndex
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
from torch_geometric.utils import spmm


class GraphConv(MessagePassing):
    r"""The graph neural network operator from the `"Weisfeiler and Leman Go
    Neural: Higher-order Graph Neural Networks"
    <https://arxiv.org/abs/1810.02244>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2
        \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{x}_j

    where :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to
    target node :obj:`i` (default: :obj:`1`)

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        out_channels (int): Size of each output sample.
        aggr (str, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"add"`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})` or
          :math:`((|\mathcal{V}_s|, F_{s}), (|\mathcal{V}_t|, F_{t}))`
          if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge weights :math:`(|\mathcal{E}|)` *(optional)*
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
          :math:`(|\mathcal{V}_t|, F_{out})` if bipartite
    """
    # NOTE(review): presumably read by the MessagePassing base class to decide
    # whether `message_and_aggregate` may be used with an `EdgeIndex` input;
    # confirm against the base class.
    SUPPORTS_FUSED_EDGE_INDEX: Final[bool] = True

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        aggr: str = 'add',
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(aggr=aggr, **kwargs)

        # Store the sizes exactly as given (an int stays an int on `self`).
        self.in_channels = in_channels
        self.out_channels = out_channels

        # A single int means source and target features share one size.
        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        # `lin_rel` transforms the aggregated neighbor (source) features and
        # carries the optional bias; `lin_root` transforms the target node's
        # own features and never has a bias, so the layer has at most one.
        self.lin_rel = Linear(in_channels[0], out_channels, bias=bias)
        self.lin_root = Linear(in_channels[1], out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets the base module and both linear transformations."""
        super().reset_parameters()
        self.lin_rel.reset_parameters()
        self.lin_root.reset_parameters()

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_weight: OptTensor = None, size: Size = None) -> Tensor:
        r"""Runs the layer.

        Args:
            x (Tensor or tuple): Node features. A single tensor is used for
                both source and target nodes; a pair is interpreted as
                :obj:`(source, target)`, where the target entry may be
                :obj:`None` to skip the root term.
            edge_index (Adj): The graph connectivity.
            edge_weight (Tensor, optional): Per-edge weights forwarded to
                :meth:`propagate`. (default: :obj:`None`)
            size (tuple, optional): Forwarded unchanged to :meth:`propagate`.
                (default: :obj:`None`)

        Returns:
            Tensor: The transformed aggregated neighborhood, plus the
            transformed target features whenever those are available.
        """
        if isinstance(x, Tensor):
            x = (x, x)

        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                             size=size)
        out = self.lin_rel(out)

        # Add the root (target node) term only if target features were given.
        x_r = x[1]
        if x_r is not None:
            out = out + self.lin_root(x_r)

        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        r"""Returns :obj:`x_j`, scaled row-wise by :obj:`edge_weight` when
        edge weights are provided.
        """
        # `view(-1, 1)` turns the weights into a column so that they broadcast
        # across the feature dimension of `x_j`.
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(
        self,
        edge_index: Adj,
        x: OptPairTensor,
        edge_weight: OptTensor,
    ) -> Tensor:
        r"""Fused message computation and aggregation over the source features
        :obj:`x[0]`, reduced with :obj:`self.aggr`.
        """
        # The `EdgeIndex` fast path is skipped while scripting.
        if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
            return edge_index.matmul(
                other=x[0],
                input_value=edge_weight,
                reduce=self.aggr,
                transpose=True,
            )
        # NOTE(review): `edge_weight` is not passed on this path; presumably a
        # sparse `edge_index` already stores its edge values. Confirm against
        # the callers of this method.
        return spmm(edge_index, x[0], reduce=self.aggr)