Source code for torch_geometric.nn.conv.cugraph.rgcn_conv

from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric import EdgeIndex
from torch_geometric.nn.conv.cugraph import CuGraphModule
from torch_geometric.nn.conv.cugraph.base import LEGACY_MODE
from torch_geometric.nn.inits import glorot, zeros

# Bind ``RGCNConvAgg`` to the ``agg_hg_basis_n2n_post`` operator. The module
# that provides it differs depending on ``LEGACY_MODE``.
try:
    if LEGACY_MODE:
        from pylibcugraphops.torch.autograd import \
            agg_hg_basis_n2n_post as RGCNConvAgg
    else:
        from pylibcugraphops.pytorch.operators import \
            agg_hg_basis_n2n_post as RGCNConvAgg
except ImportError:
    # The import failure is ignored here, which leaves ``RGCNConvAgg``
    # undefined in this module.
    # NOTE(review): presumably the missing dependency is reported elsewhere
    # (e.g. by ``CuGraphModule``) -- confirm.
    pass


class CuGraphRGCNConv(CuGraphModule):  # pragma: no cover
    r"""The relational graph convolutional operator from the `"Modeling
    Relational Data with Graph Convolutional Networks"
    <https://arxiv.org/abs/1703.06103>`_ paper.

    :class:`CuGraphRGCNConv` is an optimized version of
    :class:`~torch_geometric.nn.conv.RGCNConv` based on the :obj:`cugraph-ops`
    package that fuses message passing computation for accelerated execution
    and lower memory footprint.

    Args:
        in_channels (int): Second dimension of every weight matrix.
        out_channels (int): Last dimension of every weight matrix, and the
            size of the bias vector.
        num_relations (int): Number of relations. It sizes :obj:`weight`
            when no bases are used, and :obj:`comp` otherwise.
        num_bases (int, optional): If given, :obj:`weight` holds
            :obj:`num_bases` matrices and a :obj:`comp` parameter of shape
            :obj:`[num_relations, num_bases]` is learned. If :obj:`None`,
            :obj:`weight` holds one matrix per relation and :obj:`comp` is
            registered as :obj:`None`. (default: :obj:`None`)
        aggr (str, optional): One of :obj:`"sum"`, :obj:`"add"` or
            :obj:`"mean"`. :obj:`"mean"` enables :obj:`norm_by_out_degree`
            in the aggregation call. (default: :obj:`"mean"`)
        root_weight (bool, optional): If truthy, one extra matrix is
            appended to :obj:`weight` and :obj:`concat_own` is passed to
            the aggregation call. (default: :obj:`True`)
        bias (bool, optional): If truthy, a bias of size
            :obj:`out_channels` is learned and added to the output.
            (default: :obj:`True`)

    Raises:
        ValueError: If :obj:`aggr` is not one of the accepted values.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_relations: int,
        num_bases: Optional[int] = None,
        aggr: str = 'mean',
        root_weight: bool = True,
        bias: bool = True,
    ):
        super().__init__()

        # Guard clause: reject unsupported aggregations up front.
        if aggr not in ('sum', 'add', 'mean'):
            raise ValueError("Aggregation function must be either 'mean' "
                             f"or 'sum' (got '{aggr}')")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases
        self.aggr = aggr
        self.root_weight = root_weight

        # One matrix per basis (or per relation when no bases are used), plus
        # one trailing matrix when the root weight is enabled.
        uses_bases = num_bases is not None
        num_blocks = num_bases if uses_bases else num_relations
        num_blocks = num_blocks + (1 if root_weight else 0)
        self.weight = Parameter(
            torch.empty(num_blocks, in_channels, out_channels))

        if uses_bases:
            self.comp = Parameter(torch.empty(num_relations, num_bases))
        else:
            self.register_parameter('comp', None)

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module.

        The relation (or basis) matrices, :obj:`comp` and the trailing root
        matrix (if any) are passed to :func:`glorot`, in that order.
        :obj:`bias` is passed to :func:`zeros`.
        """
        # With a root weight, the last matrix is initialized on its own.
        if self.root_weight:
            relation_weight = self.weight[:-1]
        else:
            relation_weight = self.weight[:]

        glorot(relation_weight)
        glorot(self.comp)
        if self.root_weight:
            glorot(self.weight[-1])
        zeros(self.bias)

    def forward(
        self,
        x: Tensor,
        edge_index: EdgeIndex,
        edge_type: Tensor,
        max_num_neighbors: Optional[int] = None,
    ) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor): The node features.
            edge_index (EdgeIndex): The edge indices.
            edge_type (torch.Tensor): The edge type.
            max_num_neighbors (int, optional): The maximum number of neighbors
                of a target node. It is only effective when operating in a
                bipartite graph. When not given, the value will be computed
                on-the-fly, leading to slightly worse performance.
                (default: :obj:`None`)

        Returns:
            torch.Tensor: The aggregated features multiplied by the
            flattened weight, plus the bias when one is present.
        """
        typed_graph = self.get_typed_cugraph(
            edge_index,
            edge_type,
            self.num_relations,
            max_num_neighbors,
        )

        aggregated = RGCNConvAgg(
            x,
            self.comp,
            typed_graph,
            concat_own=self.root_weight,
            norm_by_out_degree=self.aggr == 'mean',
        )

        # Collapse the stacked matrices into a single 2D matrix so that one
        # matmul applies all of them.
        flat_weight = self.weight.view(-1, self.out_channels)
        out = aggregated @ flat_weight

        if self.bias is None:
            return out
        return out + self.bias

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return (f'{cls_name}({self.in_channels}, '
                f'{self.out_channels}, num_relations={self.num_relations})')