Source code for torch_geometric.nn.conv.tag_conv

from torch.nn import Linear
from torch_scatter import scatter_add
from torch_geometric.nn.conv import MessagePassing
import torch


class TAGConv(MessagePassing):
    r"""The topology adaptive graph convolutional networks operator from the
    `"Topology Adaptive Graph Convolutional Networks"
    <https://arxiv.org/abs/1710.10370>`_ paper

    .. math::
        \mathbf{X}^{\prime} = \sum_{k=0}^K \left( \mathbf{D}^{-1/2} \mathbf{A}
        \mathbf{D}^{-1/2} \right)^k \mathbf{X} \mathbf{\Theta}_{k},

    where :math:`\mathbf{A}` denotes the adjacency matrix and
    :math:`D_{ii} = \sum_{j=0} A_{ij}` its diagonal degree matrix.

    The per-hop weights :math:`\mathbf{\Theta}_{k}` are stored jointly in a
    single linear layer that is applied to the concatenation of all hop
    outputs.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        K (int, optional): Number of hops :math:`K`. (default: :obj:`3`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self, in_channels, out_channels, K=3, bias=True, **kwargs):
        # Sum aggregation is the default, but an explicit ``aggr`` supplied by
        # the caller must not collide with a hard-coded keyword argument.
        kwargs.setdefault('aggr', 'add')
        super(TAGConv, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K

        # One weight block per hop (0..K), applied to the concatenated hops.
        self.lin = Linear(in_channels * (self.K + 1), out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the parameters of the underlying linear layer."""
        self.lin.reset_parameters()

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight=None, dtype=None):
        """Compute symmetrically normalized edge weights.

        Args:
            edge_index (Tensor): Edge indices, unpacked into a ``row`` and a
                ``col`` index vector with one entry per edge.
            num_nodes (int): Number of nodes, used as the size of the degree
                vector.
            edge_weight (Tensor, optional): One weight per edge. If
                :obj:`None`, every edge gets weight one.
            dtype (torch.dtype, optional): Data type of the generated unit
                weights; only used when ``edge_weight`` is :obj:`None`.

        Returns:
            (Tensor, Tensor): The unchanged ``edge_index`` and the per-edge
            weights ``deg[row]^{-1/2} * edge_weight * deg[col]^{-1/2}``.
        """
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                     device=edge_index.device)

        row, col = edge_index
        # Degree of a node is the sum of the weights of edges indexed by `row`.
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        # Nodes with zero degree yield `inf`; zero them so they contribute
        # nothing instead of propagating infinities.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight=None):
        """Apply the layer.

        Args:
            x (Tensor): Node features; the first dimension is the number of
                nodes and the last dimension must equal ``in_channels``.
            edge_index (Tensor): Edge indices as accepted by :meth:`norm`.
            edge_weight (Tensor, optional): One weight per edge.

        Returns:
            Tensor: The linear layer applied to the features of hops
            ``0..K`` concatenated along the last dimension.
        """
        edge_index, norm = self.norm(edge_index, x.size(0), edge_weight,
                                     dtype=x.dtype)

        # xs[k] holds the features after k rounds of normalized propagation.
        xs = [x]
        for _ in range(self.K):
            xs.append(self.propagate(edge_index, x=xs[-1], norm=norm))
        return self.lin(torch.cat(xs, dim=-1))

    def message(self, x_j, norm):
        """Scale each source-node feature row by its edge's normalized weight."""
        return norm.view(-1, 1) * x_j

    def __repr__(self):
        return '{}({}, {}, K={})'.format(self.__class__.__name__,
                                         self.in_channels, self.out_channels,
                                         self.K)