Source code for torch_geometric.nn.conv.signed_conv

from typing import Union

import torch
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, PairTensor, SparseTensor
from torch_geometric.utils import spmm

[docs]class SignedConv(MessagePassing):
r"""The signed graph convolutional operator from the `"Signed Graph
Convolutional Network" <https://arxiv.org/abs/1808.06354>`_ paper.

.. math::
\mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})}
\left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)}
\mathbf{x}_w , \mathbf{x}_v \right]

\mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{neg})}
\left[ \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)}
\mathbf{x}_w , \mathbf{x}_v \right]

if :obj:`first_aggr` is set to :obj:`True`, and

.. math::
\mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})}
\left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)}
\mathbf{x}_w^{(\textrm{pos})}, \frac{1}{|\mathcal{N}^{-}(v)|}
\sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{neg})},
\mathbf{x}_v^{(\textrm{pos})} \right]

\mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{neg})}
\left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)}
\mathbf{x}_w^{(\textrm{neg})}, \frac{1}{|\mathcal{N}^{-}(v)|}
\sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{pos})},
\mathbf{x}_v^{(\textrm{neg})} \right]

otherwise.
In case :obj:`first_aggr` is :obj:`False`, the layer expects :obj:`x` to be
a tensor where :obj:`x[:, :in_channels]` denotes the positive node features
:math:`\mathbf{X}^{(\textrm{pos})}` and :obj:`x[:, in_channels:]` denotes
the negative node features :math:`\mathbf{X}^{(\textrm{neg})}`.

Args:
in_channels (int): Size of each input sample, or :obj:`-1` to derive
the size from the first input(s) to the forward method.
out_channels (int): Size of each output sample.
first_aggr (bool): Denotes which aggregation formula to use.
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.

Shapes:
- **input:**
node features :math:`(|\mathcal{V}|, F_{in})` or
:math:`((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))`
if bipartite,
positive edge indices :math:`(2, |\mathcal{E}^{(+)}|)`,
negative edge indices :math:`(2, |\mathcal{E}^{(-)}|)`
- **outputs:** node features :math:`(|\mathcal{V}|, F_{out})` or
:math:`(|\mathcal{V_t}|, F_{out})` if bipartite
"""
def __init__(self, in_channels: int, out_channels: int, first_aggr: bool,
             bias: bool = True, **kwargs):
    """Build the positive/negative linear transforms of the layer.

    Args:
        in_channels: Size of each input sample (:obj:`-1` for lazy init).
        out_channels: Size of each output sample.
        first_aggr: Whether this is the first aggregation layer.
        bias: If :obj:`False`, skip the additive bias on the root terms.
        **kwargs: Forwarded to :class:`MessagePassing`.
    """
    # Neighborhood aggregation defaults to a mean unless overridden.
    kwargs.setdefault('aggr', 'mean')
    super().__init__(**kwargs)

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.first_aggr = first_aggr

    # The first layer aggregates raw node features; every deeper layer
    # aggregates the concatenated [pos | neg] embeddings, so the
    # aggregation-side linear maps consume twice the input width there.
    aggr_dim = in_channels if first_aggr else 2 * in_channels

    self.lin_pos_l = Linear(aggr_dim, out_channels, False)
    self.lin_pos_r = Linear(in_channels, out_channels, bias)
    self.lin_neg_l = Linear(aggr_dim, out_channels, False)
    self.lin_neg_r = Linear(in_channels, out_channels, bias)

    self.reset_parameters()

def reset_parameters(self):
    r"""Resets all learnable parameters of the module."""
    super().reset_parameters()
    # Re-initialize each of the four linear maps in turn.
    for lin in (self.lin_pos_l, self.lin_pos_r,
                self.lin_neg_l, self.lin_neg_r):
        lin.reset_parameters()

[docs]    def forward(
self,
x: Union[Tensor, PairTensor],
):

if isinstance(x, Tensor):
x = (x, x)

# propagate_type: (x: PairTensor)
if self.first_aggr:

out_pos = self.propagate(pos_edge_index, x=x)
out_pos = self.lin_pos_l(out_pos)
out_pos = out_pos + self.lin_pos_r(x[1])

out_neg = self.propagate(neg_edge_index, x=x)
out_neg = self.lin_neg_l(out_neg)
out_neg = out_neg + self.lin_neg_r(x[1])

else:
F_in = self.in_channels

out_pos1 = self.propagate(pos_edge_index,
x=(x[0][..., :F_in], x[1][..., :F_in]))
out_pos2 = self.propagate(neg_edge_index,
x=(x[0][..., F_in:], x[1][..., F_in:]))
out_pos = torch.cat([out_pos1, out_pos2], dim=-1)
out_pos = self.lin_pos_l(out_pos)
out_pos = out_pos + self.lin_pos_r(x[1][..., :F_in])

out_neg1 = self.propagate(pos_edge_index,
x=(x[0][..., F_in:], x[1][..., F_in:]))
out_neg2 = self.propagate(neg_edge_index,
x=(x[0][..., :F_in], x[1][..., :F_in]))
out_neg = torch.cat([out_neg1, out_neg2], dim=-1)
out_neg = self.lin_neg_l(out_neg)
out_neg = out_neg + self.lin_neg_r(x[1][..., F_in:])