Source code for torch_geometric.nn.conv.signed_conv

from typing import Union

import torch
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, PairTensor, SparseTensor
from torch_geometric.utils import spmm


class SignedConv(MessagePassing):
    r"""The signed graph convolutional operator from the `"Signed Graph
    Convolutional Network" <https://arxiv.org/abs/1808.06354>`_ paper.

    .. math::
        \mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})}
        \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)}
        \mathbf{x}_w , \mathbf{x}_v \right]

        \mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{neg})}
        \left[ \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)}
        \mathbf{x}_w , \mathbf{x}_v \right]

    if :obj:`first_aggr` is set to :obj:`True`, and

    .. math::
        \mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})}
        \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)}
        \mathbf{x}_w^{(\textrm{pos})}, \frac{1}{|\mathcal{N}^{-}(v)|}
        \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{neg})},
        \mathbf{x}_v^{(\textrm{pos})} \right]

        \mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{neg})}
        \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)}
        \mathbf{x}_w^{(\textrm{neg})}, \frac{1}{|\mathcal{N}^{-}(v)|}
        \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{pos})},
        \mathbf{x}_v^{(\textrm{neg})} \right]

    otherwise.
    In case :obj:`first_aggr` is :obj:`False`, the layer expects :obj:`x` to be
    a tensor where :obj:`x[:, :in_channels]` denotes the positive node features
    :math:`\mathbf{X}^{(\textrm{pos})}` and :obj:`x[:, in_channels:]` denotes
    the negative node features :math:`\mathbf{X}^{(\textrm{neg})}`.
    If :obj:`in_channels` is not positive (lazy size), the split point is
    taken to be half of the last dimension of the input.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        first_aggr (bool): Denotes which aggregation formula to use.
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})` or
          :math:`((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))`
          if bipartite,
          positive edge indices :math:`(2, |\mathcal{E}^{(+)}|)`,
          negative edge indices :math:`(2, |\mathcal{E}^{(-)}|)`
        - **outputs:** node features :math:`(|\mathcal{V}|, F_{out})` or
          :math:`(|\mathcal{V_t}|, F_{out})` if bipartite
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        first_aggr: bool,
        bias: bool = True,
        **kwargs,
    ):
        # Mean aggregation is the default unless the caller overrides `aggr`.
        kwargs.setdefault('aggr', 'mean')
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.first_aggr = first_aggr

        # The "left" layers act on aggregated neighbor features and carry no
        # bias; the "right" layers act on the node's own features.
        # When `first_aggr` is False, two aggregations are concatenated before
        # the "left" layers, hence the doubled input width.
        if first_aggr:
            self.lin_pos_l = Linear(in_channels, out_channels, False)
            self.lin_pos_r = Linear(in_channels, out_channels, bias)
            self.lin_neg_l = Linear(in_channels, out_channels, False)
            self.lin_neg_r = Linear(in_channels, out_channels, bias)
        else:
            self.lin_pos_l = Linear(2 * in_channels, out_channels, False)
            self.lin_pos_r = Linear(in_channels, out_channels, bias)
            self.lin_neg_l = Linear(2 * in_channels, out_channels, False)
            self.lin_neg_r = Linear(in_channels, out_channels, bias)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the parent module and all four linear layers."""
        super().reset_parameters()
        self.lin_pos_l.reset_parameters()
        self.lin_pos_r.reset_parameters()
        self.lin_neg_l.reset_parameters()
        self.lin_neg_r.reset_parameters()

    def forward(
        self,
        x: Union[Tensor, PairTensor],
        pos_edge_index: Adj,
        neg_edge_index: Adj,
    ) -> Tensor:
        """Compute the positive and negative embeddings.

        Args:
            x: Node features, either a single tensor or a
                ``(source, target)`` pair of tensors.
            pos_edge_index: Connectivity of the positive edges.
            neg_edge_index: Connectivity of the negative edges.

        Returns:
            The positive and negative outputs concatenated along the last
            dimension.
        """
        if isinstance(x, Tensor):
            x = (x, x)

        # propagate_type: (x: PairTensor)
        if self.first_aggr:
            out_pos = self.propagate(pos_edge_index, x=x)
            out_pos = self.lin_pos_l(out_pos)
            out_pos = out_pos + self.lin_pos_r(x[1])

            out_neg = self.propagate(neg_edge_index, x=x)
            out_neg = self.lin_neg_l(out_neg)
            out_neg = out_neg + self.lin_neg_r(x[1])

            return torch.cat([out_pos, out_neg], dim=-1)

        F_in = self.in_channels
        if F_in <= 0:
            # Lazy `in_channels` (e.g. -1) must not be used as a slice bound:
            # `x[..., :-1]` would silently mis-split the features. The input
            # holds positive and negative halves side by side, so split it
            # in the middle instead.
            F_in = x[1].size(-1) // 2

        # Positive / negative halves of the source and target features.
        x_pos = (x[0][..., :F_in], x[1][..., :F_in])
        x_neg = (x[0][..., F_in:], x[1][..., F_in:])

        out_pos1 = self.propagate(pos_edge_index, x=x_pos)
        out_pos2 = self.propagate(neg_edge_index, x=x_neg)
        out_pos = torch.cat([out_pos1, out_pos2], dim=-1)
        out_pos = self.lin_pos_l(out_pos)
        out_pos = out_pos + self.lin_pos_r(x_pos[1])

        out_neg1 = self.propagate(pos_edge_index, x=x_neg)
        out_neg2 = self.propagate(neg_edge_index, x=x_pos)
        out_neg = torch.cat([out_neg1, out_neg2], dim=-1)
        out_neg = self.lin_neg_l(out_neg)
        out_neg = out_neg + self.lin_neg_r(x_neg[1])

        return torch.cat([out_pos, out_neg], dim=-1)

    def message(self, x_j: Tensor) -> Tensor:
        """Pass neighbor features through unchanged."""
        return x_j

    def message_and_aggregate(self, adj_t: Adj, x: PairTensor) -> Tensor:
        """Aggregate source features via ``spmm`` using ``self.aggr``."""
        if isinstance(adj_t, SparseTensor):
            # Clear any stored values on the sparse adjacency before the
            # sparse-dense product.
            adj_t = adj_t.set_value(None, layout=None)
        return spmm(adj_t, x[0], reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, first_aggr={self.first_aggr})')