from torch_geometric.typing import Adj, OptTensor
import torch
from torch import Tensor
from torch.nn import Linear
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
class TAGConv(MessagePassing):
    r"""The topology adaptive graph convolutional networks operator from the
    `"Topology Adaptive Graph Convolutional Networks"
    <https://arxiv.org/abs/1710.10370>`_ paper

    .. math::
        \mathbf{X}^{\prime} = \sum_{k=0}^K \left( \mathbf{D}^{-1/2} \mathbf{A}
        \mathbf{D}^{-1/2} \right)^k \mathbf{X} \mathbf{\Theta}_{k},

    where :math:`\mathbf{A}` denotes the adjacency matrix and
    :math:`D_{ii} = \sum_{j=0} A_{ij}` its diagonal degree matrix.

    The per-hop weights :math:`\mathbf{\Theta}_{k}` are realized as a single
    linear layer applied to the concatenation of all :math:`K + 1` hop
    representations.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        K (int, optional): Number of hops :math:`K`. (default: :obj:`3`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        normalize (bool, optional): Whether to apply symmetric normalization.
            (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self, in_channels: int, out_channels: int, K: int = 3,
                 bias: bool = True, normalize: bool = True, **kwargs):
        # Default to sum aggregation, but do not collide with an ``aggr``
        # supplied by the caller through ``**kwargs`` (which previously raised
        # a duplicate-keyword ``TypeError``).
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.normalize = normalize

        # One linear map over the concatenation [X, A X, ..., A^K X].
        self.lin = Linear(in_channels * (self.K + 1), out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the parameters of the linear layer."""
        self.lin.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """Run ``K`` rounds of propagation and mix all hops linearly.

        Args:
            x (Tensor): Node feature matrix.
            edge_index (Adj): Graph connectivity, either an edge index
                :class:`~torch.Tensor` or a :class:`SparseTensor`.
            edge_weight (OptTensor, optional): Edge weights. May be
                :obj:`None`, in which case messages are passed unweighted
                unless normalization produces weights.
                (default: :obj:`None`)

        Returns:
            Tensor: Output of the linear layer applied to the concatenation
            (along the last dimension) of the input and its ``K`` propagated
            versions.
        """
        if self.normalize:
            if isinstance(edge_index, Tensor):
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    improved=False, add_self_loops=False, dtype=x.dtype)

            elif isinstance(edge_index, SparseTensor):
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    add_self_loops=False, dtype=x.dtype)

        # xs[k] holds the features after k propagation steps; xs[0] is the
        # untouched input.
        xs = [x]
        for _ in range(self.K):
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            out = self.propagate(edge_index, x=xs[-1], edge_weight=edge_weight,
                                 size=None)
            xs.append(out)

        return self.lin(torch.cat(xs, dim=-1))

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        """Scale each neighbor feature by its edge weight, if one is given."""
        # ``edge_weight`` is None when normalization is disabled and the
        # caller supplied no weights; pass the features through unweighted
        # instead of failing on ``None.view``.
        if edge_weight is None:
            return x_j
        return edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        """Fused message + aggregation via a sparse matrix product."""
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, K={self.K})')