# Source code for torch_geometric.nn.conv.signed_conv

from typing import Union
from torch_geometric.typing import PairTensor, Adj

import torch
from torch import Tensor
from torch.nn import Linear
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing


class SignedConv(MessagePassing):
    r"""The signed graph convolutional operator from the `"Signed Graph
    Convolutional Network" <https://arxiv.org/abs/1808.06354>`_ paper

    .. math::
        \mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})}
        \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)}
        \mathbf{x}_w , \mathbf{x}_v \right]

        \mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{neg})}
        \left[ \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)}
        \mathbf{x}_w , \mathbf{x}_v \right]

    if :obj:`first_aggr` is set to :obj:`True`, and

    .. math::
        \mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})}
        \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)}
        \mathbf{x}_w^{(\textrm{pos})}, \frac{1}{|\mathcal{N}^{-}(v)|}
        \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{neg})},
        \mathbf{x}_v^{(\textrm{pos})} \right]

        \mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{neg})}
        \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)}
        \mathbf{x}_w^{(\textrm{neg})}, \frac{1}{|\mathcal{N}^{-}(v)|}
        \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{pos})},
        \mathbf{x}_v^{(\textrm{neg})} \right]

    otherwise.
    In case :obj:`first_aggr` is :obj:`False`, the layer expects :obj:`x` to be
    a tensor where :obj:`x[:, :in_channels]` denotes the positive node features
    :math:`\mathbf{X}^{(\textrm{pos})}` and :obj:`x[:, in_channels:]` denotes
    the negative node features :math:`\mathbf{X}^{(\textrm{neg})}`.

    In both modes the layer returns the positive and negative representations
    concatenated along the last dimension, i.e. an output with
    :obj:`2 * out_channels` features per node.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        first_aggr (bool): Denotes which aggregation formula to use.
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self, in_channels: int, out_channels: int, first_aggr: bool,
                 bias: bool = True, **kwargs):
        # Neighborhoods are averaged unless the caller explicitly overrides it.
        kwargs.setdefault('aggr', 'mean')
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.first_aggr = first_aggr

        # The "l" (neighbor) transforms consume one aggregated feature block in
        # the first layer, and two concatenated blocks (one per edge sign) in
        # deeper layers. The "r" (root) transforms always see `in_channels`
        # features and are the only ones carrying the optional bias.
        neighbor_channels = in_channels if first_aggr else 2 * in_channels

        # Creation order is kept as lin_pos_l, lin_pos_r, lin_neg_l, lin_neg_r
        # so parameter registration order stays stable.
        self.lin_pos_l = Linear(neighbor_channels, out_channels, bias=False)
        self.lin_pos_r = Linear(in_channels, out_channels, bias=bias)
        self.lin_neg_l = Linear(neighbor_channels, out_channels, bias=False)
        self.lin_neg_r = Linear(in_channels, out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initializes the parameters of all four linear transforms."""
        self.lin_pos_l.reset_parameters()
        self.lin_pos_r.reset_parameters()
        self.lin_neg_l.reset_parameters()
        self.lin_neg_r.reset_parameters()

    def forward(self, x: Union[Tensor, PairTensor], pos_edge_index: Adj,
                neg_edge_index: Adj) -> Tensor:
        r"""Computes the positive and negative node representations.

        Args:
            x (Tensor or PairTensor): Node features. A single tensor is
                expanded to the pair :obj:`(x, x)`; for a pair, the first
                entry feeds the neighborhood aggregation and the second entry
                feeds the root transforms. If :obj:`first_aggr` is
                :obj:`False`, the last dimension must hold
                :obj:`2 * in_channels` features (positive block first,
                negative block second).
            pos_edge_index (Adj): Connectivity of the positive edges.
            neg_edge_index (Adj): Connectivity of the negative edges.

        Returns:
            Tensor: The positive and negative representations concatenated
            along the last dimension (:obj:`2 * out_channels` features).
        """
        if isinstance(x, Tensor):
            x: PairTensor = (x, x)

        # NOTE(review): the structured comment below looks like a type hint
        # consumed by the MessagePassing tooling -- keep it intact.
        # propagate_type: (x: PairTensor)

        if self.first_aggr:
            # First layer: aggregate the raw features separately over positive
            # and negative edges, then add the transformed root features.
            out_pos = self.propagate(pos_edge_index, x=x, size=None)
            out_pos = self.lin_pos_l(out_pos)
            out_pos += self.lin_pos_r(x[1])

            out_neg = self.propagate(neg_edge_index, x=x, size=None)
            out_neg = self.lin_neg_l(out_neg)
            out_neg += self.lin_neg_r(x[1])

            return torch.cat([out_pos, out_neg], dim=-1)

        F_in = self.in_channels

        # Split both pair entries into their positive / negative blocks once;
        # every propagate call below reuses these views.
        x_pos = (x[0][..., :F_in], x[1][..., :F_in])
        x_neg = (x[0][..., F_in:], x[1][..., F_in:])

        # Positive output: positive features over positive edges together with
        # negative features over negative edges.
        out_pos1 = self.propagate(pos_edge_index, size=None, x=x_pos)
        out_pos2 = self.propagate(neg_edge_index, size=None, x=x_neg)
        out_pos = torch.cat([out_pos1, out_pos2], dim=-1)
        out_pos = self.lin_pos_l(out_pos)
        out_pos += self.lin_pos_r(x_pos[1])

        # Negative output: negative features over positive edges together with
        # positive features over negative edges.
        out_neg1 = self.propagate(pos_edge_index, size=None, x=x_neg)
        out_neg2 = self.propagate(neg_edge_index, size=None, x=x_pos)
        out_neg = torch.cat([out_neg1, out_neg2], dim=-1)
        out_neg = self.lin_neg_l(out_neg)
        out_neg += self.lin_neg_r(x_neg[1])

        return torch.cat([out_pos, out_neg], dim=-1)

    def message(self, x_j: Tensor) -> Tensor:
        """Passes the neighbor features through unchanged."""
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: PairTensor) -> Tensor:
        """Fused message + aggregation step for sparse adjacency input.

        Multiplies the adjacency with the first entry of ``x`` using the
        layer's aggregation scheme as the reduction.
        """
        # Clear the stored values so that only the sparsity pattern of the
        # adjacency takes part in the reduction.
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x[0], reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, first_aggr={self.first_aggr})')