from typing import List, Optional, Union
import torch
from torch import Tensor
from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import (
Adj,
OptPairTensor,
OptTensor,
Size,
SparseTensor,
torch_sparse,
)
from torch_geometric.utils import add_self_loops, spmm
[docs]class SimpleConv(MessagePassing):
r"""A simple message passing operator that performs (non-trainable)
propagation.
.. math::
\mathbf{x}^{\prime}_i = \bigoplus_{j \in \mathcal{N(i)}} e_{ji} \cdot
\mathbf{x}_j
where :math:`\bigoplus` defines a custom aggregation scheme.
Args:
aggr (str or [str] or Aggregation, optional): The aggregation scheme
to use, *e.g.*, :obj:`"add"`, :obj:`"sum"` :obj:`"mean"`,
:obj:`"min"`, :obj:`"max"` or :obj:`"mul"`.
In addition, can be any
:class:`~torch_geometric.nn.aggr.Aggregation` module (or any string
that automatically resolves to it). (default: :obj:`"sum"`)
combine_root (str, optional): Specifies whether or how to combine the
central node representation (one of :obj:`"sum"`, :obj:`"cat"`,
:obj:`"self_loop"`, :obj:`None`). (default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
Shapes:
- **inputs:**
node features :math:`(|\mathcal{V}|, F)` or
:math:`((|\mathcal{V_s}|, F), (|\mathcal{V_t}|, *))`
if bipartite,
edge indices :math:`(2, |\mathcal{E}|)`
- **outputs:** node features :math:`(|\mathcal{V}|, F)` or
:math:`(|\mathcal{V_t}|, F)` if bipartite
"""
def __init__(
self,
aggr: Optional[Union[str, List[str], Aggregation]] = "sum",
combine_root: Optional[str] = None,
**kwargs,
):
if combine_root not in ['sum', 'cat', 'self_loop', None]:
raise ValueError(f"Received invalid value for 'combine_root' "
f"(got '{combine_root}')")
super().__init__(aggr, **kwargs)
self.combine_root = combine_root
[docs] def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
edge_weight: OptTensor = None, size: Size = None) -> Tensor:
if self.combine_root is not None:
if self.combine_root == 'self_loop':
if not isinstance(x, Tensor) or (size is not None
and size[0] != size[1]):
raise ValueError("Cannot use `combine_root='self_loop'` "
"for bipartite message passing")
if isinstance(edge_index, Tensor):
edge_index, edge_weight = add_self_loops(
edge_index, edge_weight, num_nodes=x.size(0))
elif isinstance(edge_index, SparseTensor):
edge_index = torch_sparse.set_diag(edge_index)
if isinstance(x, Tensor):
x = (x, x)
# propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
size=size)
x_dst = x[1]
if x_dst is not None and self.combine_root is not None:
if self.combine_root == 'sum':
out = out + x_dst
elif self.combine_root == 'cat':
out = torch.cat([x_dst, out], dim=-1)
return out
def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:
assert isinstance(self.aggr, str)
return spmm(adj_t, x[0], reduce=self.aggr)