torch_geometric.nn.conv.HeteroConv

class HeteroConv(convs: Dict[Tuple[str, str, str], MessagePassing], aggr: Optional[str] = 'sum')

Bases: Module

A generic wrapper for computing graph convolution on heterogeneous graphs. For each edge type, this layer passes messages from source nodes to destination nodes using the bipartite GNN layer given for that edge type. If multiple relations point to the same destination node type, their results are aggregated according to aggr. Compared to torch_geometric.nn.to_hetero(), this layer is especially useful if you want to apply different message passing modules to different edge types.
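To make the routing concrete: each edge type is a (source, relation, destination) triple, and every relation contributes one output to its destination node type. The sketch below groups a few edge types by destination in plain Python. The helper name relations_by_destination is hypothetical and not part of PyG; it only illustrates which relation outputs end up being combined by aggr.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# (source node type, relation, destination node type)
EdgeType = Tuple[str, str, str]


def relations_by_destination(edge_types: List[EdgeType]) -> Dict[str, List[EdgeType]]:
    """Group edge types by destination node type (hypothetical helper).

    Each edge type yields one output for its destination node type, so any
    destination reached by several edge types needs those outputs combined.
    """
    groups: Dict[str, List[EdgeType]] = defaultdict(list)
    for edge_type in edge_types:
        _, _, dst = edge_type
        groups[dst].append(edge_type)
    # A plain dict keeps the order in which destinations were first seen.
    return dict(groups)


edge_types = [
    ('paper', 'cites', 'paper'),
    ('author', 'writes', 'paper'),
    ('paper', 'written_by', 'author'),
]
groups = relations_by_destination(edge_types)
for dst, rels in groups.items():
    print(dst, [rel for _, rel, _ in rels])
# paper ['cites', 'writes']
# author ['written_by']
```

Under this reading, 'paper' receives two relation outputs that aggr must merge, while 'author' receives only one. A node type that never appears as a destination would not be updated at all.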

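from torch_geometric.nn import GATConv, GCNConv, HeteroConv, SAGEConv

# x_dict and edge_index_dict hold node features and edge indices per type,
# e.g. data.x_dict and data.edge_index_dict of a HeteroData object.
hetero_conv = HeteroConv({
    ('paper', 'cites', 'paper'): GCNConv(-1, 64),
    ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
    # Self-loops are meaningless on a bipartite relation, so disable them.
    ('paper', 'written_by', 'author'): GATConv((-1, -1), 64, add_self_loops=False),
}, aggr='sum')

out_dict = hetero_conv(x_dict, edge_index_dict)

print(list(out_dict.keys()))
# ['paper', 'author']

Parameters: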
  • convs (Dict[Tuple[str, str, str], MessagePassing]) – A dictionary holding a bipartite MessagePassing layer for each individual edge type.

  • aggr (str, optional) – The aggregation scheme to use for grouping node embeddings generated by different relations ("sum", "mean", "min", "max", "cat", None). (default: "sum")
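The following is a minimal sketch of what each aggr option does to the per-relation outputs of a single destination node type, written with plain PyTorch. It assumes that every relation returns a [num_dst_nodes, out_channels] tensor, and that None keeps the relations separate by stacking them along a new dimension. The helper group_outputs is hypothetical and is not PyG's internal function.

```python
from typing import List, Optional

import torch


def group_outputs(xs: List[torch.Tensor], aggr: Optional[str]) -> torch.Tensor:
    """Combine per-relation outputs for one destination node type (hypothetical).

    Each tensor in xs is assumed to have shape [num_dst_nodes, out_channels].
    """
    if aggr is None:
        # Keep relations separate: [num_dst_nodes, num_relations, out_channels].
        return torch.stack(xs, dim=1)
    if aggr == 'cat':
        # Concatenate features: [num_dst_nodes, num_relations * out_channels].
        return torch.cat(xs, dim=-1)
    # Element-wise reductions over the relation dimension.
    stacked = torch.stack(xs, dim=0)  # [num_relations, num_dst_nodes, out_channels]
    if aggr == 'sum':
        return stacked.sum(dim=0)
    if aggr == 'mean':
        return stacked.mean(dim=0)
    if aggr == 'min':
        return stacked.min(dim=0).values
    if aggr == 'max':
        return stacked.max(dim=0).values
    raise ValueError(f"unsupported aggr: {aggr!r}")


# Two relations ('cites' and 'writes') both produce 64-dim features for 3 papers.
xs = [torch.ones(3, 64), 2 * torch.ones(3, 64)]
for aggr in ['sum', 'mean', 'max', 'cat', None]:
    print(aggr, tuple(group_outputs(xs, aggr).shape))
# sum (3, 64)
# mean (3, 64)
# max (3, 64)
# cat (3, 128)
# None (3, 2, 64)
```

The reductions keep the output width fixed. "cat" and None change it, so downstream layers must expect the wider or extra-dimensional output. With "cat", node types that are reached by different numbers of relations also end up with different feature sizes.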

forward(*args_dict, **kwargs_dict) → Dict[str, Tensor]

Runs the forward pass of the module.

Parameters:
  • x_dict (Dict[str, torch.Tensor]) – A dictionary holding node feature information for each individual node type.

  • edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]) – A dictionary holding graph connectivity information for each individual edge type, either as a torch.Tensor of shape [2, num_edges] or a torch_sparse.SparseTensor.

  • *args_dict (optional) – Additional forward arguments of individual torch_geometric.nn.conv.MessagePassing layers.

  • **kwargs_dict (optional) – Additional forward arguments of individual torch_geometric.nn.conv.MessagePassing layers. For example, if the GNN layer for edge type edge_type expects edge attributes edge_attr as a forward argument, you can pass them to forward() via edge_attr_dict={edge_type: edge_attr}.
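The parameter descriptions above imply a lookup for each relation. Every dictionary argument is keyed either by edge type (edge_index_dict, edge_attr_dict) or by node type (x_dict). Below is a simplified, dependency-free sketch of that lookup. It rests on these assumptions:

- An edge-type key takes precedence.
- A relation whose source and destination types match receives the single node-type entry.
- A bipartite relation receives a (source, destination) tuple.
- Keyword names ending in _dict are forwarded without that suffix, so edge_attr_dict becomes edge_attr.

resolve_inputs and resolve_kwargs are hypothetical helpers, not PyG functions, and strings stand in for tensors.

```python
from typing import Any, Dict, Tuple

EdgeType = Tuple[str, str, str]


def resolve_inputs(edge_type: EdgeType, value_dict: Dict[Any, Any]) -> Any:
    """Pick what one relation's layer receives from a single dict argument.

    Assumed lookup order (hypothetical sketch):
    1. an entry keyed by the full edge type;
    2. if src == dst, the entry for that node type;
    3. otherwise a (source, destination) tuple for bipartite layers.
    Returns None when the dictionary has nothing for this edge type.
    """
    src, _, dst = edge_type
    if edge_type in value_dict:
        return value_dict[edge_type]
    if src == dst and src in value_dict:
        return value_dict[src]
    if src in value_dict or dst in value_dict:
        return (value_dict.get(src), value_dict.get(dst))
    return None


def resolve_kwargs(edge_type: EdgeType, **kwargs_dict: Dict[Any, Any]) -> Dict[str, Any]:
    """Strip the '_dict' suffix from each keyword and resolve its value."""
    kwargs: Dict[str, Any] = {}
    for name, value_dict in kwargs_dict.items():
        if not name.endswith('_dict'):
            raise ValueError(f"keyword '{name}' must end with '_dict'")
        value = resolve_inputs(edge_type, value_dict)
        if value is not None:
            kwargs[name[:-len('_dict')]] = value
    return kwargs


# Strings stand in for tensors to keep the sketch dependency-free.
x_dict = {'paper': 'x_paper', 'author': 'x_author'}
edge_index_dict = {('author', 'writes', 'paper'): 'ei_writes'}
edge_attr_dict = {('author', 'writes', 'paper'): 'ea_writes'}

writes = ('author', 'writes', 'paper')
cites = ('paper', 'cites', 'paper')
print(resolve_inputs(writes, x_dict))  # ('x_author', 'x_paper')
print(resolve_inputs(cites, x_dict))  # x_paper
print(resolve_kwargs(writes, edge_attr_dict=edge_attr_dict))  # {'edge_attr': 'ea_writes'}
```

Under these assumptions, the SAGEConv((-1, -1), 64) layer for ('author', 'writes', 'paper') in the example above receives (x_dict['author'], x_dict['paper']) as its node features. This is why bipartite layers are constructed with a tuple of input sizes.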

reset_parameters()

Resets all learnable parameters of the module.