from typing import Union
from torch_geometric.typing import PairTensor, Adj
import torch
from torch import Tensor
from torch_geometric.nn.dense.linear import Linear
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing
class SignedConv(MessagePassing):
    r"""The signed graph convolutional operator from the `"Signed Graph
    Convolutional Network" <https://arxiv.org/abs/1808.06354>`_ paper

    .. math::
        \mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})}
        \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)}
        \mathbf{x}_w , \mathbf{x}_v \right]

        \mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{neg})}
        \left[ \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)}
        \mathbf{x}_w , \mathbf{x}_v \right]

    if :obj:`first_aggr` is set to :obj:`True`, and

    .. math::
        \mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})}
        \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)}
        \mathbf{x}_w^{(\textrm{pos})}, \frac{1}{|\mathcal{N}^{-}(v)|}
        \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{neg})},
        \mathbf{x}_v^{(\textrm{pos})} \right]

        \mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{neg})}
        \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)}
        \mathbf{x}_w^{(\textrm{neg})}, \frac{1}{|\mathcal{N}^{-}(v)|}
        \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{pos})},
        \mathbf{x}_v^{(\textrm{neg})} \right]

    otherwise.
    In case :obj:`first_aggr` is :obj:`False`, the layer expects :obj:`x` to be
    a tensor where :obj:`x[:, :in_channels]` denotes the positive node features
    :math:`\mathbf{X}^{(\textrm{pos})}` and :obj:`x[:, in_channels:]` denotes
    the negative node features :math:`\mathbf{X}^{(\textrm{neg})}`.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        first_aggr (bool): Denotes which aggregation formula to use.
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self, in_channels: int, out_channels: int, first_aggr: bool,
                 bias: bool = True, **kwargs):
        # Mean aggregation is the default; callers may still override it.
        kwargs.setdefault('aggr', 'mean')
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.first_aggr = first_aggr

        # The "left" layers transform aggregated neighbor features and the
        # "right" layers transform the node's own features. Only the "right"
        # layers carry the (optional) bias.
        if first_aggr:
            self.lin_pos_l = Linear(in_channels, out_channels, False)
            self.lin_pos_r = Linear(in_channels, out_channels, bias)
            self.lin_neg_l = Linear(in_channels, out_channels, False)
            self.lin_neg_r = Linear(in_channels, out_channels, bias)
        else:
            # Two aggregated halves are concatenated before the "left"
            # layers, hence the doubled input width.
            self.lin_pos_l = Linear(2 * in_channels, out_channels, False)
            self.lin_pos_r = Linear(in_channels, out_channels, bias)
            self.lin_neg_l = Linear(2 * in_channels, out_channels, False)
            self.lin_neg_r = Linear(in_channels, out_channels, bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the parameters of all four linear layers."""
        self.lin_pos_l.reset_parameters()
        self.lin_pos_r.reset_parameters()
        self.lin_neg_l.reset_parameters()
        self.lin_neg_r.reset_parameters()

    def forward(self, x: Union[Tensor, PairTensor], pos_edge_index: Adj,
                neg_edge_index: Adj):
        """Compute the positive and negative embeddings of every node.

        Args:
            x (Tensor or PairTensor): Node features. A single tensor is
                used for both entries of the pair; otherwise :obj:`x[0]`
                feeds the neighbor aggregation and :obj:`x[1]` feeds the
                root ("right") transformation.
            pos_edge_index (Adj): Connectivity of the positive edges.
            neg_edge_index (Adj): Connectivity of the negative edges.

        Returns:
            Tensor: The positive embedding followed by the negative
            embedding, concatenated along the last dimension.
        """
        if isinstance(x, Tensor):
            x: PairTensor = (x, x)

        # propagate_type: (x: PairTensor)
        if self.first_aggr:
            out_pos = self.propagate(pos_edge_index, x=x, size=None)
            out_pos = self.lin_pos_l(out_pos)
            out_pos += self.lin_pos_r(x[1])

            out_neg = self.propagate(neg_edge_index, x=x, size=None)
            out_neg = self.lin_neg_l(out_neg)
            out_neg += self.lin_neg_r(x[1])

            return torch.cat([out_pos, out_neg], dim=-1)

        else:
            F_in = self.in_channels
            if F_in <= 0:
                # `in_channels` was left to be inferred (e.g. `-1`). Slicing
                # with a non-positive width would split the features at the
                # wrong column, so split the incoming features in half.
                F_in = x[0].size(-1) // 2

            # Positive / negative halves of the (source, target) features.
            x_pos = (x[0][..., :F_in], x[1][..., :F_in])
            x_neg = (x[0][..., F_in:], x[1][..., F_in:])

            # Positive embedding: positive features over positive edges and
            # negative features over negative edges.
            out_pos1 = self.propagate(pos_edge_index, size=None, x=x_pos)
            out_pos2 = self.propagate(neg_edge_index, size=None, x=x_neg)
            out_pos = torch.cat([out_pos1, out_pos2], dim=-1)
            out_pos = self.lin_pos_l(out_pos)
            out_pos += self.lin_pos_r(x_pos[1])

            # Negative embedding: negative features over positive edges and
            # positive features over negative edges.
            out_neg1 = self.propagate(pos_edge_index, size=None, x=x_neg)
            out_neg2 = self.propagate(neg_edge_index, size=None, x=x_pos)
            out_neg = torch.cat([out_neg1, out_neg2], dim=-1)
            out_neg = self.lin_neg_l(out_neg)
            out_neg += self.lin_neg_r(x_neg[1])

            return torch.cat([out_pos, out_neg], dim=-1)

    def message(self, x_j: Tensor) -> Tensor:
        """Return the per-edge features :obj:`x_j` unchanged."""
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: PairTensor) -> Tensor:
        """Aggregate :obj:`x[0]` with a single sparse matrix product.

        Any values stored in :obj:`adj_t` are dropped first, so only its
        sparsity pattern is used; the reduction is :obj:`self.aggr`.
        """
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x[0], reduce=self.aggr)

    def __repr__(self):
        """Return ``ClassName(in_channels, out_channels, first_aggr=...)``."""
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, first_aggr={self.first_aggr})')