# Source code for torch_geometric.nn.conv.rgcn_conv

import torch
from torch.nn import Parameter as Param
from torch_geometric.nn.conv import MessagePassing

from ..inits import uniform


class RGCNConv(MessagePassing):
    r"""The relational graph convolutional operator from the `"Modeling
    Relational Data with Graph Convolutional Networks"
    <https://arxiv.org/abs/1703.06103>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot
        \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)}
        \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,

    where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types.
    Edge type needs to be a one-dimensional :obj:`torch.long` tensor which
    stores a relation identifier
    :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` for each edge.

    Each relation weight is built by basis-decomposition from
    :obj:`num_bases` shared basis matrices :math:`\mathbf{V}_b`,
    :math:`\mathbf{\Theta}_r = \sum_b a_{rb} \mathbf{V}_b`, where the
    coefficients :math:`a_{rb}` are learned per relation.

    .. note::
        This layer does not compute the normalization constant
        :math:`\frac{1}{|\mathcal{N}_r(i)|}` itself. Messages are rescaled
        only when a per-edge :obj:`edge_norm` tensor is passed to
        :meth:`forward`; otherwise every message has weight :math:`1`.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        num_relations (int): Number of relations.
        num_bases (int): Number of bases used for basis-decomposition.
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self, in_channels, out_channels, num_relations, num_bases,
                 root_weight=True, bias=True, **kwargs):
        super(RGCNConv, self).__init__(aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases

        # Shared basis matrices V_b: (num_bases, in_channels, out_channels).
        self.basis = Param(torch.Tensor(num_bases, in_channels, out_channels))
        # Per-relation mixing coefficients a_{rb}: (num_relations, num_bases).
        self.att = Param(torch.Tensor(num_relations, num_bases))

        if root_weight:
            self.root = Param(torch.Tensor(in_channels, out_channels))
        else:
            # Keep the attribute defined (as None) so `update` can test it.
            self.register_parameter('root', None)

        if bias:
            self.bias = Param(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize all learnable parameters in place.

        Each existing parameter (``basis``, ``att``, ``root``, ``bias``, in
        that order) is passed to ``uniform`` with ``num_bases * in_channels``
        as its ``size`` argument (presumably the fan-in used to derive the
        sampling bound -- see ``..inits``). Disabled parameters (``None``) are
        skipped explicitly instead of relying on ``uniform`` to ignore them.
        """
        size = self.num_bases * self.in_channels
        for param in (self.basis, self.att, self.root, self.bias):
            if param is not None:
                uniform(size, param)

    def forward(self, x, edge_index, edge_type, edge_norm=None, size=None):
        """Run one relational graph convolution.

        Args:
            x (Tensor or None): Node feature matrix whose last dimension is
                ``in_channels``. If ``None``, the layer works without node
                features: neighbor ``j`` under relation ``r`` contributes row
                ``j`` of :math:`\\Theta_r`, like an embedding lookup. This
                requires every neighbor index to lie in ``[0, in_channels)``.
            edge_index (LongTensor): Graph connectivity, forwarded to
                :meth:`propagate`.
            edge_type (LongTensor): One relation id per edge, in
                ``[0, num_relations)``.
            edge_norm (Tensor, optional): One scaling factor per edge,
                multiplied into that edge's message (e.g. the
                :math:`1 / |\\mathcal{N}_r(i)|` normalization). If ``None``,
                messages are not rescaled. (default: ``None``)
            size (tuple, optional): Forwarded unchanged to :meth:`propagate`.

        Returns:
            Tensor: Updated node representations with ``out_channels``
            features per node.
        """
        return self.propagate(edge_index, size=size, x=x, edge_type=edge_type,
                              edge_norm=edge_norm)

    def message(self, x_j, edge_index_j, edge_type, edge_norm):
        # NOTE(review): do not rename these arguments. `propagate` appears to
        # fill them by name from the keyword arguments given in `forward`
        # (x -> x_j, edge_type, edge_norm) -- verify against MessagePassing.

        # Build every relation weight Theta_r = sum_b att[r, b] * basis[b] as
        # one row of a (num_relations, in_channels * out_channels) matrix.
        w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))

        if x_j is None:
            # No node features: treat neighbor j as a one-hot input, i.e. pick
            # row j of Theta_{edge_type}. The flat index only stays inside the
            # block of relation `edge_type` when 0 <= j < in_channels.
            w = w.view(-1, self.out_channels)
            index = edge_type * self.in_channels + edge_index_j
            out = torch.index_select(w, 0, index)
        else:
            # Gather one (in_channels, out_channels) weight per edge and apply
            # it to that edge's x_j: (E, 1, in) @ (E, in, out) -> (E, out).
            # This materializes an (E, in, out) tensor, so memory grows with
            # the number of edges.
            w = w.view(self.num_relations, self.in_channels, self.out_channels)
            w = torch.index_select(w, 0, edge_type)
            out = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)

        if edge_norm is None:
            return out
        # One scalar per edge, broadcast across the output channels.
        return out * edge_norm.view(-1, 1)

    def update(self, aggr_out, x):
        # Add the root (self) transform and the bias to the summed messages.
        if self.root is not None:
            if x is None:
                # Featureless mode: node i's root term is row i of `root`
                # (shape (in_channels, out_channels)). The broadcast add
                # presumably assumes the number of nodes equals in_channels.
                aggr_out = aggr_out + self.root
            else:
                aggr_out = aggr_out + torch.matmul(x, self.root)
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {}, num_relations={})'.format(
            self.__class__.__name__, self.in_channels, self.out_channels,
            self.num_relations)