from torch_geometric.typing import Adj, OptTensor
import torch
from torch import Tensor
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.conv.gcn_conv import gcn_norm
class TAGConv(MessagePassing):
    r"""The topology adaptive graph convolutional networks operator from the
    `"Topology Adaptive Graph Convolutional Networks"
    <https://arxiv.org/abs/1710.10370>`_ paper

    .. math::
        \mathbf{X}^{\prime} = \sum_{k=0}^K \left( \mathbf{D}^{-1/2} \mathbf{A}
        \mathbf{D}^{-1/2} \right)^k \mathbf{X} \mathbf{\Theta}_{k},

    where :math:`\mathbf{A}` denotes the adjacency matrix and
    :math:`D_{ii} = \sum_{j=0} A_{ij}` its diagonal degree matrix.
    The adjacency matrix can include other values than :obj:`1` representing
    edge weights via the optional :obj:`edge_weight` tensor.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        K (int, optional): Number of hops :math:`K`. Must be non-negative.
            (default: :obj:`3`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. The flag is forwarded to each of the
            :math:`K + 1` per-hop linear transformations.
            (default: :obj:`True`)
        normalize (bool, optional): Whether to apply symmetric normalization.
            (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Raises:
        ValueError: If :obj:`K` is negative.
    """
    def __init__(self, in_channels: int, out_channels: int, K: int = 3,
                 bias: bool = True, normalize: bool = True, **kwargs):
        # A negative K would leave `self.lins` empty and only fail later, with
        # an opaque IndexError on `self.lins[0]` inside `forward`.
        if K < 0:
            raise ValueError(
                f"'K' must be a non-negative integer (got {K})")

        # Sum-aggregation is the default; callers may override it via kwargs.
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.normalize = normalize

        # One linear transformation per hop k = 0, ..., K. `bias` is passed
        # through so that `bias=False` really disables the additive bias;
        # with the default `bias=True` the parameter set is unchanged.
        self.lins = torch.nn.ModuleList([
            Linear(in_channels, out_channels, bias=bias)
            for _ in range(K + 1)
        ])

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets the parameters of every per-hop linear transformation."""
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        r"""Runs the layer: the sum over hops :math:`k = 0, \ldots, K` of the
        :math:`k`-times propagated input passed through the :math:`k`-th
        linear transformation.

        Args:
            x (Tensor): The input node features.
            edge_index (Tensor or SparseTensor): The graph connectivity.
            edge_weight (Tensor, optional): Optional edge weights.
                (default: :obj:`None`)
        """
        if self.normalize:
            # No self-loops are added here; `gcn_norm` is only used to
            # rescale the given edges.
            if isinstance(edge_index, Tensor):
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    improved=False, add_self_loops=False, dtype=x.dtype)

            elif isinstance(edge_index, SparseTensor):
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    add_self_loops=False, dtype=x.dtype)

        # Hop 0: the un-propagated input.
        out = self.lins[0](x)
        for lin in self.lins[1:]:
            # `x` is overwritten on every iteration, so hop k sees the input
            # propagated k times.
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            x = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                               size=None)
            # NOTE(review): `lin.forward` is called directly (as in the
            # original), which skips any hooks registered on `lin` -- confirm
            # this is intentional before switching to `lin(x)`.
            out += lin.forward(x)
        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        # Scale each message by its edge weight, if weights are given.
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        # Fused message + aggregation as a single sparse-dense matmul.
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, K={self.K})')