from typing import Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Parameter

import torch_geometric.backend
import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric.index import index2ptr
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import (
    Adj,
    OptTensor,
    SparseTensor,
    pyg_lib,
    torch_sparse,
)
from torch_geometric.utils import index_sort, one_hot, scatter, spmm


def masked_edge_index(edge_index: Adj, edge_mask: Tensor) -> Adj:
    """Return only the edges of ``edge_index`` that ``edge_mask`` selects.

    Args:
        edge_index: Either a dense index tensor or a sparse adjacency object.
        edge_mask: Boolean mask with one entry per edge.

    Returns:
        The filtered connectivity, in the same representation as the input
        for dense tensors, or as a COO-layout sparse object otherwise.
    """
    if not isinstance(edge_index, Tensor):
        # Sparse adjacency: keep the non-zeros flagged by the mask and hand
        # the result back in COO layout.
        return torch_sparse.masked_select_nnz(edge_index, edge_mask,
                                              layout='coo')
    # Dense index tensor: apply the mask along dimension 1 (one column per
    # edge).
    return edge_index[:, edge_mask]


class RGCNConv(MessagePassing):
    r"""The relational graph convolutional operator from the `"Modeling
    Relational Data with Graph Convolutional Networks"
    <https://arxiv.org/abs/1703.06103>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot
        \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)}
        \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,

    where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types.
    The normalization :math:`\frac{1}{|\mathcal{N}_r(i)|}` corresponds to the
    default :obj:`aggr="mean"`; aggregation is always performed separately
    per relation before the relation results are summed.
    Edge type needs to be a one-dimensional :obj:`torch.long` tensor which
    stores a relation identifier
    :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` for each edge.

    .. note::
        This implementation is as memory-efficient as possible by iterating
        over each individual relation type.
        Therefore, it may result in low GPU utilization in case the graph has a
        large number of relations.
        As an alternative approach, :class:`FastRGCNConv` does not iterate over
        each individual type, but may consume a large amount of memory to
        compensate.
        We advise to check out both implementations to see which one fits your
        needs.

    .. note::
        :class:`RGCNConv` can use dynamic shapes, which means that the shape
        of the interim tensors can be determined at runtime.
        If your device doesn't support dynamic shapes, use
        :class:`FastRGCNConv` instead.

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
            In case no input features are given, this argument should
            correspond to the number of nodes in your graph.
        out_channels (int): Size of each output sample.
        num_relations (int): Number of relations.
        num_bases (int, optional): If set, this layer will use the
            basis-decomposition regularization scheme where :obj:`num_bases`
            denotes the number of bases to use. (default: :obj:`None`)
        num_blocks (int, optional): If set, this layer will use the
            block-diagonal-decomposition regularization scheme where
            :obj:`num_blocks` denotes the number of blocks to use.
            (default: :obj:`None`)
        aggr (str, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"mean"`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        is_sorted (bool, optional): If set to :obj:`True`, assumes that
            :obj:`edge_index` is sorted by :obj:`edge_type`. This avoids
            internal re-sorting of the data and can improve runtime and memory
            efficiency. (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Raises:
        ValueError: If both :obj:`num_bases` and :obj:`num_blocks` are set, or
            if :obj:`num_blocks` does not divide both the source input size
            and :obj:`out_channels`.
    """
    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        num_relations: int,
        num_bases: Optional[int] = None,
        num_blocks: Optional[int] = None,
        aggr: str = 'mean',
        root_weight: bool = True,
        is_sorted: bool = False,
        bias: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', aggr)
        super().__init__(node_dim=0, **kwargs)

        if num_bases is not None and num_blocks is not None:
            raise ValueError('Can not apply both basis-decomposition and '
                             'block-diagonal-decomposition at the same time.')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases
        self.num_blocks = num_blocks
        self.is_sorted = is_sorted

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)
        self.in_channels_l = in_channels[0]

        # Last value returned by the `segment_matmul` heuristic (if it ran).
        self._use_segment_matmul_heuristic_output: Optional[bool] = None

        if num_bases is not None:
            # Every relation weight is a learned mix (`comp`) of shared bases.
            self.weight = Parameter(
                torch.empty(num_bases, in_channels[0], out_channels))
            self.comp = Parameter(torch.empty(num_relations, num_bases))

        elif num_blocks is not None:
            # An explicit check instead of `assert`, which `python -O`
            # strips; an indivisible size would otherwise silently truncate
            # channels in the floor divisions below.
            if (in_channels[0] % num_blocks != 0
                    or out_channels % num_blocks != 0):
                raise ValueError('Both the source input size and '
                                 '`out_channels` must be divisible by '
                                 '`num_blocks`.')
            self.weight = Parameter(
                torch.empty(num_relations, num_blocks,
                            in_channels[0] // num_blocks,
                            out_channels // num_blocks))
            self.register_parameter('comp', None)

        else:
            self.weight = Parameter(
                torch.empty(num_relations, in_channels[0], out_channels))
            self.register_parameter('comp', None)

        if root_weight:
            self.root = Parameter(torch.empty(in_channels[1], out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        # `torch.empty` leaves memory uninitialized; initialize it now.
        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        super().reset_parameters()
        glorot(self.weight)
        glorot(self.comp)
        glorot(self.root)
        zeros(self.bias)

    def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
                edge_index: Adj, edge_type: OptTensor = None):
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor or tuple, optional): The input node features.
                Can be either a :obj:`[num_nodes, in_channels]` node feature
                matrix, or an optional one-dimensional node index tensor (in
                which case input features are treated as trainable node
                embeddings).
                Furthermore, :obj:`x` can be of type :obj:`tuple` denoting
                source and destination node features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_type (torch.Tensor, optional): The one-dimensional relation
                type/index for each edge in :obj:`edge_index`.
                Should be only :obj:`None` in case :obj:`edge_index` is of type
                :class:`torch_sparse.SparseTensor`. (default: :obj:`None`)

        Returns:
            torch.Tensor: Output of shape
            :obj:`[num_target_nodes, out_channels]`.

        Raises:
            ValueError: If no relation types are available, or if
                block-diagonal decomposition is used with non-continuous
                source features.
        """
        # Convert input features to a pair of node features or node indices.
        x_l: OptTensor = None
        if isinstance(x, tuple):
            x_l = x[0]
        else:
            x_l = x
        if x_l is None:
            # No features given: treat every node as its own index.
            x_l = torch.arange(self.in_channels_l, device=self.weight.device)

        x_r: Tensor = x_l
        if isinstance(x, tuple):
            x_r = x[1]

        size = (x_l.size(0), x_r.size(0))
        if isinstance(edge_index, SparseTensor):
            # The relation type of each edge is stored as its sparse value.
            edge_type = edge_index.storage.value()
        if edge_type is None:
            raise ValueError('`edge_type` must be given unless `edge_index` '
                             'is a `SparseTensor` holding edge types.')

        # propagate_type: (x: Tensor, edge_type_ptr: OptTensor)
        out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)

        weight = self.weight
        if self.num_bases is not None:  # Basis-decomposition =================
            weight = (self.comp @ weight.view(self.num_bases, -1)).view(
                self.num_relations, self.in_channels_l, self.out_channels)

        if self.num_blocks is not None:  # Block-diagonal-decomposition =====
            # The aggregated source features are reshaped into blocks, which
            # only makes sense for real-valued features (not node indices).
            if not torch.is_floating_point(x_l):
                raise ValueError('Block-diagonal decomposition not supported '
                                 'for non-continuous input features.')

            for i in range(self.num_relations):
                tmp = masked_edge_index(edge_index, edge_type == i)
                h = self.propagate(tmp, x=x_l, edge_type_ptr=None, size=size)
                h = h.view(-1, weight.size(1), weight.size(2))
                h = torch.einsum('abc,bcd->abd', h, weight[i])
                out = out + h.contiguous().view(-1, self.out_channels)

        else:  # No regularization/Basis-decomposition ========================
            # The fused `segment_matmul` path aggregates the messages of all
            # relations jointly. That matches the per-relation loop below
            # only for sum aggregation (for `mean`/`max` the per-relation
            # normalization would be lost), so it is restricted to it.
            can_use_segment_matmul = (
                torch_geometric.typing.WITH_SEGMM and not is_compiling()
                and self.num_bases is None and x_l.is_floating_point()
                and isinstance(edge_index, Tensor)
                and self.aggr in ['add', 'sum'])

            use_segment_matmul = torch_geometric.backend.use_segment_matmul
            # If `use_segment_matmul` is not specified, use a simple heuristic
            # to determine whether `segment_matmul` can speed up computation
            # given the observed input sizes (only when the path is usable):
            if can_use_segment_matmul and use_segment_matmul is None:
                segment_count = scatter(torch.ones_like(edge_type), edge_type,
                                        dim_size=self.num_relations)

                self._use_segment_matmul_heuristic_output = (
                    torch_geometric.backend.use_segment_matmul_heuristic(
                        num_segments=self.num_relations,
                        max_segment_size=int(segment_count.max()),
                        in_channels=self.weight.size(1),
                        out_channels=self.weight.size(2),
                    ))

                assert self._use_segment_matmul_heuristic_output is not None
                use_segment_matmul = self._use_segment_matmul_heuristic_output

            if can_use_segment_matmul and use_segment_matmul:
                # `segment_matmul` needs edges grouped by relation type.
                if not self.is_sorted:
                    if (edge_type[1:] < edge_type[:-1]).any():
                        edge_type, perm = index_sort(
                            edge_type, max_value=self.num_relations)
                        edge_index = edge_index[:, perm]
                edge_type_ptr = index2ptr(edge_type, self.num_relations)
                out = self.propagate(edge_index, x=x_l,
                                     edge_type_ptr=edge_type_ptr, size=size)
            else:
                for i in range(self.num_relations):
                    tmp = masked_edge_index(edge_index, edge_type == i)

                    if not torch.is_floating_point(x_l):
                        # `x_l` holds node indices: look up the per-node rows
                        # of this relation's weight and aggregate those.
                        out = out + self.propagate(
                            tmp,
                            x=weight[i, x_l],
                            edge_type_ptr=None,
                            size=size,
                        )
                    else:
                        h = self.propagate(tmp, x=x_l, edge_type_ptr=None,
                                           size=size)
                        out = out + (h @ weight[i])

        root = self.root
        if root is not None:
            if not torch.is_floating_point(x_r):
                out = out + root[x_r]
            else:
                out = out + x_r @ root

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(self, x_j: Tensor, edge_type_ptr: OptTensor) -> Tensor:
        # Fused path: edges are sorted by relation and `edge_type_ptr` marks
        # the segment boundaries, so every segment of messages is multiplied
        # by its relation's weight in one call. Otherwise messages pass
        # through untransformed and `forward` applies the relation weight
        # after aggregation.
        if (torch_geometric.typing.WITH_SEGMM and not is_compiling()
                and edge_type_ptr is not None):
            # TODO Re-weight according to edge type degree for `aggr=mean`.
            # (`forward` currently only takes this path for sum aggregation.)
            return pyg_lib.ops.segment_matmul(x_j, edge_type_ptr, self.weight)

        return x_j

    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
        # Drop the stored values (the relation types) so that they are not
        # used as edge weights in the sparse matrix product.
        if isinstance(adj_t, SparseTensor):
            adj_t = adj_t.set_value(None)
        return spmm(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, num_relations={self.num_relations})')


class FastRGCNConv(RGCNConv):
    r"""See :class:`RGCNConv`.

    Computes the same operator, but instead of looping over relations it
    gathers one weight matrix per edge (:obj:`weight[edge_type]`) and
    transforms all messages in a single batched product. This trades memory
    for speed. Only :obj:`"add"`, :obj:`"sum"` and :obj:`"mean"` aggregation
    are supported.
    """
    def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
                edge_index: Adj, edge_type: OptTensor = None):
        r"""Runs the forward pass of the module.

        Arguments are the same as for :meth:`RGCNConv.forward`.

        Raises:
            ValueError: If the aggregation scheme is unsupported or no
                relation types are available.
        """
        # Always run the explicit message -> aggregate pipeline: the fused
        # `message_and_aggregate` inherited from `RGCNConv` ignores the
        # per-edge relation types this class relies on.
        self.fuse = False
        if self.aggr not in ['add', 'sum', 'mean']:
            raise ValueError('FastRGCNConv only supports the "add", "sum" '
                             'and "mean" aggregation schemes.')

        # Convert input features to a pair of node features or node indices.
        x_l: OptTensor = None
        if isinstance(x, tuple):
            x_l = x[0]
        else:
            x_l = x
        if x_l is None:
            # No features given: treat every node as its own index.
            x_l = torch.arange(self.in_channels_l, device=self.weight.device)

        x_r: Tensor = x_l
        if isinstance(x, tuple):
            x_r = x[1]

        size = (x_l.size(0), x_r.size(0))
        if isinstance(edge_index, SparseTensor):
            # As in `RGCNConv`, relation types are the sparse values.
            edge_type = edge_index.storage.value()
        if edge_type is None:
            raise ValueError('`edge_type` must be given unless `edge_index` '
                             'is a `SparseTensor` holding edge types.')

        # propagate_type: (x: Tensor, edge_type: OptTensor)
        out = self.propagate(edge_index, x=x_l, edge_type=edge_type, size=size)

        root = self.root
        if root is not None:
            if not torch.is_floating_point(x_r):
                out = out + root[x_r]
            else:
                out = out + x_r @ root

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(self, x_j: Tensor, edge_type: Tensor,
                edge_index_j: Tensor) -> Tensor:
        # `edge_index_j` is no longer read (see the index branch below) but
        # is kept so the method signature stays backward-compatible.
        weight = self.weight
        if self.num_bases is not None:  # Basis-decomposition =================
            weight = (self.comp @ weight.view(self.num_bases, -1)).view(
                self.num_relations, self.in_channels_l, self.out_channels)

        if self.num_blocks is not None:  # Block-diagonal-decomposition =======
            if not torch.is_floating_point(x_j):
                raise ValueError('Block-diagonal decomposition not supported '
                                 'for non-continuous input features.')

            # One block matrix per (edge, block) pair; multiply each block of
            # the source features by its matching block of weights.
            weight = weight[edge_type].view(-1, weight.size(2), weight.size(3))
            x_j = x_j.view(-1, 1, weight.size(1))
            return torch.bmm(x_j, weight).view(-1, self.out_channels)

        # No regularization/Basis-decomposition ===============================
        if not torch.is_floating_point(x_j):
            # `x_j` holds the source node indices, so each message is the row
            # `weight[edge_type, x_j]` of the flattened weight. Using `x_j`
            # (rather than edge positions) matches `RGCNConv`, which indexes
            # `weight[i, x_l]`, and the root term's `root[x_r]`.
            weight_index = edge_type * weight.size(1) + x_j
            return weight.view(-1, self.out_channels)[weight_index]

        return torch.bmm(x_j.unsqueeze(-2), weight[edge_type]).squeeze(-2)

    def aggregate(self, inputs: Tensor, edge_type: Tensor, index: Tensor,
                  dim_size: Optional[int] = None) -> Tensor:
        # Compute normalization in separation for each `edge_type`: for mean
        # aggregation, divide every message by the number of incoming edges
        # of the same relation at its target node, then sum everything.
        if self.aggr == 'mean':
            norm = one_hot(edge_type, self.num_relations, dtype=inputs.dtype)
            norm = scatter(norm, index, dim=0, dim_size=dim_size)[index]
            norm = torch.gather(norm, 1, edge_type.view(-1, 1))
            norm = 1. / norm.clamp_(1.)
            inputs = norm * inputs

        return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size)