torch_geometric.nn.conv.CuGraphSAGEConv
- class CuGraphSAGEConv(in_channels: int, out_channels: int, aggr: str = 'mean', normalize: bool = False, root_weight: bool = True, project: bool = False, bias: bool = True)
Bases: CuGraphModule
The GraphSAGE operator from the “Inductive Representation Learning on Large Graphs” paper.
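For intuition, here is a dependency-free sketch of the update rule this operator computes with the default aggr='mean'. It assumes the rule documented for SAGEConv, x'_i = W_1 x_i + W_2 · mean_{j ∈ N(i)} x_j, where root_weight toggles the W_1 x_i term and normalize applies ℓ2 normalization to each output row. It illustrates the math only, not the fused cugraph-ops kernel. The names sage_mean_layer, w_neigh and w_root are hypothetical, bias and project are omitted, and edges are assumed to be (source, target) pairs, matching PyG's default source-to-target message flow.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # shape: [out_channels][in_channels]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Multiply a weight matrix by a feature vector."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def sage_mean_layer(
    x: List[Vector],
    edges: List[Tuple[int, int]],  # hypothetical input: (source, target) pairs
    w_neigh: Matrix,               # plays the role of W_2 (neighbor weight)
    w_root: Matrix,                # plays the role of W_1 (root weight)
    root_weight: bool = True,
    normalize: bool = False,
) -> List[Vector]:
    num_nodes, in_channels = len(x), len(x[0])
    # Sum incoming neighbor features and count neighbors per target node.
    sums = [[0.0] * in_channels for _ in range(num_nodes)]
    counts = [0] * num_nodes
    for src, dst in edges:
        counts[dst] += 1
        for k in range(in_channels):
            sums[dst][k] += x[src][k]

    out = []
    for i in range(num_nodes):
        # Mean aggregation; a node with no neighbors gets a zero summary.
        mean = [s / counts[i] if counts[i] else 0.0 for s in sums[i]]
        h = matvec(w_neigh, mean)
        if root_weight:
            # Add the transformed features of the node itself.
            h = [a + b for a, b in zip(h, matvec(w_root, x[i]))]
        if normalize:
            # Optional L2 normalization of the output row.
            norm = math.sqrt(sum(v * v for v in h))
            if norm > 0:
                h = [v / norm for v in h]
        out.append(h)
    return out


# Usage: three nodes, identity weights for readability.
identity = [[1.0, 0.0], [0.0, 1.0]]
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edges = [(0, 2), (1, 2), (2, 0)]
print(sage_mean_layer(x, edges, identity, identity))
# → [[2.0, 1.0], [0.0, 1.0], [1.5, 1.5]]
```

Node 2 averages nodes 0 and 1 to [0.5, 0.5] and adds its own features; node 1 has no incoming edges, so only its root term survives.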
CuGraphSAGEConv is an optimized version of SAGEConv based on the cugraph-ops package, which fuses the message-passing computation for faster execution and a lower memory footprint.
- forward(x: Tensor, edge_index: EdgeIndex, max_num_neighbors: Optional[int] = None) → Tensor
Runs the forward pass of the module.
- Parameters:
x (torch.Tensor) – The node features.
edge_index (EdgeIndex) – The edge indices.
max_num_neighbors (int, optional) – The maximum number of neighbors of any target node. It only takes effect when operating on a bipartite graph. If not given, the value is computed on-the-fly, which slightly reduces performance. (default: None)
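The quantity that max_num_neighbors describes is the largest in-degree among target nodes. The sketch below shows two ways to compute it in plain Python, which is roughly the work saved when the value is passed in explicitly. The helper names are hypothetical, and the layouts are assumptions: the first helper takes the target row of a COO edge list (edge_index[1] under PyG's default flow), the second a CSC-style column-pointer array in which target i owns edges colptr[i] to colptr[i + 1].

```python
from collections import Counter
from typing import Sequence


def max_neighbors_from_targets(dst: Sequence[int]) -> int:
    # dst holds the target node of every edge; count edges per target.
    return max(Counter(dst).values(), default=0)


def max_neighbors_from_colptr(colptr: Sequence[int]) -> int:
    # Neighbor count of target i is the width of its slice in the pointer array.
    return max((colptr[i + 1] - colptr[i] for i in range(len(colptr) - 1)), default=0)


# Usage: target 0 receives three edges, targets 1 and 2 one each.
print(max_neighbors_from_targets([0, 0, 1, 0, 2]))  # → 3
print(max_neighbors_from_colptr([0, 3, 4, 5]))      # → 3
```

Both helpers return 0 for a graph with no edges.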