torch_geometric.nn.conv.HeteroConv
- class HeteroConv(convs: Dict[Tuple[str, str, str], MessagePassing], aggr: Optional[str] = 'sum')
Bases: torch.nn.Module
A generic wrapper for computing graph convolution on heterogeneous graphs. This layer passes messages from source nodes to target nodes using the bipartite GNN layer given for each edge type. If multiple relations point to the same destination node type, their results are aggregated according to aggr. In comparison to torch_geometric.nn.to_hetero(), this layer is especially useful if you want to apply different message passing modules to different edge types.

```python
from torch_geometric.nn import GATConv, GCNConv, HeteroConv, SAGEConv

# x_dict and edge_index_dict are assumed to come from a heterogeneous
# graph, e.g. data.x_dict and data.edge_index_dict of a HeteroData object.
hetero_conv = HeteroConv({
    ('paper', 'cites', 'paper'): GCNConv(-1, 64),
    ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
    ('paper', 'written_by', 'author'): GATConv((-1, -1), 64),
}, aggr='sum')

out_dict = hetero_conv(x_dict, edge_index_dict)
print(list(out_dict.keys()))
# ['paper', 'author']
```
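Conceptually, the forward pass runs every edge type's layer on that relation's source and destination features, collects each output under the destination node type, and then aggregates the collected outputs. The plain-Python sketch below models this routing under simplifying assumptions: toy callables stand in for MessagePassing layers, each node type holds one scalar feature per node in a Python list instead of a tensor, only aggr="sum" is modelled, and the names hetero_forward and scale_layer are hypothetical rather than part of PyG.

```python
from typing import Callable, Dict, List, Tuple

EdgeType = Tuple[str, str, str]
# Toy stand-in for a bipartite MessagePassing layer: maps (source features,
# destination features) to one output value per destination node.
Layer = Callable[[List[float], List[float]], List[float]]


def hetero_forward(
    convs: Dict[EdgeType, Layer],
    x_dict: Dict[str, List[float]],
) -> Dict[str, List[float]]:
    """Hypothetical model of per-edge-type routing with aggr='sum'."""
    per_dst: Dict[str, List[List[float]]] = {}
    for (src, _rel, dst), layer in convs.items():
        # Each relation only sees its own source and destination features.
        out = layer(x_dict[src], x_dict[dst])
        per_dst.setdefault(dst, []).append(out)
    # Outputs of relations that share a destination type are summed element-wise.
    return {
        dst: [sum(vals) for vals in zip(*outs)]
        for dst, outs in per_dst.items()
    }


def scale_layer(factor: float) -> Layer:
    # Hypothetical toy layer: every destination node receives
    # factor * (sum of all source node features).
    def layer(x_src: List[float], x_dst: List[float]) -> List[float]:
        return [factor * sum(x_src) for _ in x_dst]
    return layer


convs = {
    ("paper", "cites", "paper"): scale_layer(1.0),
    ("author", "writes", "paper"): scale_layer(10.0),
    ("paper", "written_by", "author"): scale_layer(0.5),
}
x_dict = {"paper": [1.0, 2.0], "author": [3.0]}  # 2 papers, 1 author
out = hetero_forward(convs, x_dict)
print(sorted(out))  # ['author', 'paper']
```

Because both ('paper', 'cites', 'paper') and ('author', 'writes', 'paper') target 'paper', their outputs are summed into one entry, while 'author' only receives the output of ('paper', 'written_by', 'author').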
- Parameters:
convs (Dict[Tuple[str, str, str], MessagePassing]) – A dictionary holding a bipartite MessagePassing layer for each individual edge type.
aggr (str, optional) – The aggregation scheme to use for grouping node embeddings generated by different relations ("sum", "mean", "min", "max", "cat", None). (default: "sum")
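As an illustration of the aggr options, the following plain-Python sketch shows how the outputs of several relations for one destination node type might be combined. It is a hypothetical helper (combine is not a PyG function) that operates on nested lists instead of tensors, and it does not model aggr=None. "sum", "mean", "min", and "max" reduce element-wise across relations and keep the feature width, whereas "cat" concatenates along the feature dimension.

```python
from typing import List

Matrix = List[List[float]]  # one row of features per destination node


def combine(outs: List[Matrix], aggr: str) -> Matrix:
    """Hypothetical: merge the outputs of several relations for one node type."""
    if aggr == "cat":
        # Concatenate along the feature dimension: row widths add up.
        return [sum(rows, []) for rows in zip(*outs)]
    reducers = {
        "sum": sum,
        "mean": lambda vals: sum(vals) / len(vals),
        "min": min,
        "max": max,
    }
    if aggr not in reducers:
        raise ValueError(f"unsupported aggr in this sketch: {aggr!r}")
    reducer = reducers[aggr]
    # Element-wise reduction across relations; the feature width is unchanged.
    return [
        [reducer(vals) for vals in zip(*rows)]
        for rows in zip(*outs)
    ]


# Two relations, each producing 2 nodes x 2 features for the same node type.
a = [[1.0, 4.0], [2.0, 0.0]]
b = [[3.0, 2.0], [6.0, 8.0]]
print(combine([a, b], "sum"))  # [[4.0, 6.0], [8.0, 8.0]]
print(combine([a, b], "cat"))  # [[1.0, 4.0, 3.0, 2.0], [2.0, 0.0, 6.0, 8.0]]
```

With "cat", the output width grows with the number of relations that target a node type, so downstream layers need to account for that size; the reducing schemes keep the width fixed.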
- forward(*args_dict, **kwargs_dict) → Dict[str, Tensor]
Runs the forward pass of the module.
- Parameters:
x_dict (Dict[str, torch.Tensor]) – A dictionary holding node feature information for each individual node type.
edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]) – A dictionary holding graph connectivity information for each individual edge type, either as a torch.Tensor of shape [2, num_edges] or a torch_sparse.SparseTensor.
*args_dict (optional) – Additional forward arguments of individual torch_geometric.nn.conv.MessagePassing layers.
**kwargs_dict (optional) – Additional forward arguments of individual torch_geometric.nn.conv.MessagePassing layers. For example, if a specific GNN layer at edge type edge_type expects edge attributes edge_attr as a forward argument, then you can pass them to forward() via edge_attr_dict = { edge_type: edge_attr }.
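The way per-edge-type arguments reach each layer can be sketched as follows. This is a simplified, hypothetical model (gather_args is not part of PyG, and strings stand in for tensors). It assumes that node-level dictionaries such as x_dict are looked up by source and destination type, yielding a single entry when both types match and a (source, destination) pair otherwise, and that keyword dictionaries such as edge_attr_dict are looked up by edge type and forwarded under the name with the "_dict" suffix removed, following the edge_attr_dict convention described in the parameter list.

```python
from typing import Any, Dict, List, Tuple

EdgeType = Tuple[str, str, str]


def gather_args(
    edge_type: EdgeType,
    x_dict: Dict[str, Any],
    edge_index_dict: Dict[EdgeType, Any],
    **kwargs_dict: Dict[EdgeType, Any],
) -> Tuple[List[Any], Dict[str, Any]]:
    """Hypothetical: collect the arguments one edge type's layer would receive."""
    src, _rel, dst = edge_type
    # Assumption: same node type on both ends -> a single feature matrix;
    # different node types -> a (source, destination) pair for bipartite layers.
    x = x_dict[src] if src == dst else (x_dict[src], x_dict[dst])
    args = [x, edge_index_dict[edge_type]]
    kwargs: Dict[str, Any] = {}
    for name, value_dict in kwargs_dict.items():
        if not name.endswith("_dict"):
            raise ValueError(f"keyword {name!r} must end with '_dict' in this sketch")
        if edge_type in value_dict:
            # e.g. 'edge_attr_dict' is forwarded to the layer as 'edge_attr'.
            kwargs[name[: -len("_dict")]] = value_dict[edge_type]
    return args, kwargs


writes = ("author", "writes", "paper")
args, kwargs = gather_args(
    writes,
    x_dict={"author": "x_author", "paper": "x_paper"},
    edge_index_dict={writes: "edge_index_writes"},
    edge_attr_dict={writes: "edge_attr_writes"},
)
print(args)    # [('x_author', 'x_paper'), 'edge_index_writes']
print(kwargs)  # {'edge_attr': 'edge_attr_writes'}
```

Edge types missing from edge_attr_dict simply receive no edge_attr keyword in this sketch, so layers that do not use edge attributes are unaffected.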