Source code for torch_geometric.nn.conv.cugraph.sage_conv

from typing import Optional

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear

from torch_geometric import EdgeIndex
from torch_geometric.nn.conv.cugraph import CuGraphModule
from torch_geometric.nn.conv.cugraph.base import LEGACY_MODE

# Bind ``SAGEConvAgg`` to the ``agg_concat_n2n`` operator of
# ``pylibcugraphops``; the module it is imported from depends on
# ``LEGACY_MODE``.
try:
    if LEGACY_MODE:
        from pylibcugraphops.torch.autograd import \
            agg_concat_n2n as SAGEConvAgg
    else:
        from pylibcugraphops.pytorch.operators import \
            agg_concat_n2n as SAGEConvAgg
except ImportError:
    # A failed import is ignored here, which leaves ``SAGEConvAgg`` undefined
    # in this module.
    # NOTE(review): presumably the missing dependency is reported elsewhere
    # (e.g. by ``CuGraphModule``) before ``forward`` is reached -- confirm.
    pass


class CuGraphSAGEConv(CuGraphModule):  # pragma: no cover
    r"""GraphSAGE convolution accelerated by :obj:`cugraph-ops`.

    Implements the operator from the `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper.
    :class:`CuGraphSAGEConv` is an optimized counterpart of
    :class:`~torch_geometric.nn.conv.SAGEConv` that relies on the
    :obj:`cugraph-ops` package to fuse the message passing computation for
    accelerated execution and a lower memory footprint.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        aggr (str, optional): The aggregation function; one of
            :obj:`"mean"`, :obj:`"sum"`, :obj:`"min"` or :obj:`"max"`.
            (default: :obj:`"mean"`)
        normalize (bool, optional): If set to :obj:`True`, the output is
            :math:`\ell_2`-normalized along its last dimension.
            (default: :obj:`False`)
        root_weight (bool, optional): If set to :obj:`True`, the final linear
            layer consumes all :obj:`2 * in_channels` columns of the
            aggregation result; otherwise only its first :obj:`in_channels`
            columns are used. (default: :obj:`True`)
        project (bool, optional): If set to :obj:`True`, the input is passed
            through a linear layer followed by a ReLU before aggregation.
            (default: :obj:`False`)
        bias (bool, optional): Whether the final linear layer learns an
            additive bias. (default: :obj:`True`)

    Raises:
        ValueError: If :obj:`aggr` is not one of the supported aggregations.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        aggr: str = 'mean',
        normalize: bool = False,
        root_weight: bool = True,
        project: bool = False,
        bias: bool = True,
    ):
        super().__init__()

        supported_aggrs = ('mean', 'sum', 'min', 'max')
        if aggr not in supported_aggrs:
            raise ValueError(
                "Aggregation function must be either 'mean', 'sum', "
                f"'min' or 'max' (got '{aggr}')")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aggr = aggr
        self.normalize = normalize
        self.root_weight = root_weight
        self.project = project

        # Optional pre-aggregation projection; it keeps the feature width and
        # always carries a bias, independent of the `bias` argument.
        if project:
            self.pre_lin = Linear(in_channels, in_channels, bias=True)

        # The output layer is twice as wide on the input side when the full
        # aggregation result (see `forward`) is fed into it.
        lin_in_channels = 2 * in_channels if root_weight else in_channels
        self.lin = Linear(lin_in_channels, out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets the parameters of all linear layers owned by the module."""
        layers = [self.pre_lin, self.lin] if self.project else [self.lin]
        for layer in layers:
            layer.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: EdgeIndex,
        max_num_neighbors: Optional[int] = None,
    ) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor): The node features.
            edge_index (EdgeIndex): The edge indices, forwarded to
                :meth:`get_cugraph`.
            max_num_neighbors (int, optional): Forwarded unchanged to
                :meth:`get_cugraph`. (default: :obj:`None`)

        Returns:
            torch.Tensor: The output of the final linear layer, optionally
            :math:`\ell_2`-normalized along the last dimension.
        """
        cugraph = self.get_cugraph(edge_index, max_num_neighbors)

        feats = self.pre_lin(x).relu() if self.project else x
        agg = SAGEConvAgg(feats, cugraph, self.aggr)

        # Without a root weight only the first `in_channels` columns are kept.
        # NOTE(review): presumably these hold the aggregated neighbor features
        # and the remaining columns the root features -- confirm against the
        # `agg_concat_n2n` documentation.
        if not self.root_weight:
            agg = agg[:, :self.in_channels]
        out = self.lin(agg)

        if self.normalize:
            return F.normalize(out, p=2., dim=-1)
        return out

    def __repr__(self) -> str:
        args = f'{self.in_channels}, {self.out_channels}, aggr={self.aggr}'
        return f'{self.__class__.__name__}({args})'