Source code for torch_geometric.nn.conv.hetero_conv

import warnings
from collections import defaultdict
from typing import Dict, Optional

from torch import Tensor
from torch.nn import Module, ModuleDict

from torch_geometric.nn.conv.hgt_conv import group
from torch_geometric.typing import Adj, EdgeType, NodeType
from torch_geometric.utils.hetero import check_add_self_loops

class HeteroConv(Module):
    r"""A generic wrapper for computing graph convolution on heterogeneous
    graphs.

    This layer will pass messages from source nodes to target nodes based on
    the bipartite GNN layer given for a specific edge type.
    If multiple relations point to the same destination, their results will be
    aggregated according to :attr:`aggr`.
    In comparison to :meth:`torch_geometric.nn.to_hetero`, this layer is
    especially useful if you want to apply different message passing modules
    for different edge types.

    .. code-block:: python

        hetero_conv = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
            ('paper', 'written_by', 'author'): GATConv((-1, -1), 64),
        }, aggr='sum')

        out_dict = hetero_conv(x_dict, edge_index_dict)

        print(list(out_dict.keys()))
        >>> ['paper', 'author']

    Args:
        convs (Dict[Tuple[str, str, str], Module]): A dictionary
            holding a bipartite
            :class:`~torch_geometric.nn.conv.MessagePassing` layer for each
            individual edge type.
        aggr (string, optional): The aggregation scheme to use for grouping
            node embeddings generated by different relations.
            (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
            :obj:`None`). (default: :obj:`"sum"`)
    """
    def __init__(self, convs: Dict[EdgeType, Module],
                 aggr: Optional[str] = "sum"):
        super().__init__()

        # NOTE(review): presumably rejects layers that would add self-loops
        # on an edge type where that is not meaningful -- confirm against
        # `torch_geometric.utils.hetero.check_add_self_loops`.
        for edge_type, module in convs.items():
            check_add_self_loops(module, [edge_type])

        src_node_types = {key[0] for key in convs}
        dst_node_types = {key[-1] for key in convs}
        if src_node_types - dst_node_types:
            # `stacklevel=2` attributes the warning to the code constructing
            # the layer rather than to this file.
            warnings.warn(
                f"There exist node types ({src_node_types - dst_node_types}) "
                f"whose representations do not get updated during message "
                f"passing as they do not occur as destination type in any "
                f"edge type. This may lead to unexpected behaviour.",
                stacklevel=2)

        # `ModuleDict` only accepts string keys, so every edge type tuple is
        # flattened into a single `'src__rel__dst'` string.
        self.convs = ModuleDict({'__'.join(k): v for k, v in convs.items()})
        self.aggr = aggr

    def reset_parameters(self):
        r"""Calls :meth:`reset_parameters` on every wrapped layer."""
        for conv in self.convs.values():
            conv.reset_parameters()

    @staticmethod
    def _select_value(value_dict, edge_type):
        r"""Picks the entry of :obj:`value_dict` that applies to
        :obj:`edge_type`.

        Lookup order: an entry keyed by the edge type itself; else, if source
        and destination node type coincide, the entry keyed by that node type;
        else a :obj:`(source_value, destination_value)` tuple if at least one
        of the two node types is present (the missing side is :obj:`None`).

        Returns:
            A :obj:`(found, value)` tuple. :obj:`found` is :obj:`False` if
            nothing in :obj:`value_dict` applies; a flag is used because
            :obj:`None` may be a legitimate stored value.
        """
        src, _, dst = edge_type
        if edge_type in value_dict:
            return True, value_dict[edge_type]
        if src == dst and src in value_dict:
            return True, value_dict[src]
        if src in value_dict or dst in value_dict:
            return True, (value_dict.get(src, None),
                          value_dict.get(dst, None))
        return False, None

    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Adj],
        *args_dict,
        **kwargs_dict,
    ) -> Dict[NodeType, Tensor]:
        r"""
        Args:
            x_dict (Dict[str, Tensor]): A dictionary holding node feature
                information for each individual node type.
            edge_index_dict (Dict[Tuple[str, str, str], Tensor]): A dictionary
                holding graph connectivity information for each individual
                edge type.
            *args_dict (optional): Additional forward arguments of individual
                :class:`torch_geometric.nn.conv.MessagePassing` layers.
            **kwargs_dict (optional): Additional forward arguments of
                individual :class:`torch_geometric.nn.conv.MessagePassing`
                layers. Every keyword needs to end with :obj:`_dict`; the
                suffix is stripped to obtain the keyword passed to the layer.
                For example, if a specific GNN layer at edge type
                :obj:`edge_type` expects edge attributes :obj:`edge_attr` as a
                forward argument, then you can pass them to
                :meth:`~torch_geometric.nn.conv.HeteroConv.forward` via
                :obj:`edge_attr_dict = { edge_type: edge_attr }`.

        Returns:
            A dictionary holding, for every destination node type that
            received at least one message, the outputs of all its relations
            combined according to :attr:`aggr`. The returned mapping is the
            :class:`~collections.defaultdict` used for accumulation.

        Raises:
            ValueError: If a keyword argument does not end with :obj:`_dict`.
        """
        # Validate and strip the `_dict` suffix once, up front. Blindly
        # cutting the last five characters would otherwise turn a keyword such
        # as `edge_attr` into `edge` and hand it to the wrapped layer.
        named_value_dicts = {}
        for name, value_dict in kwargs_dict.items():
            if not name.endswith('_dict'):
                raise ValueError(
                    f"Keyword arguments in '{self.__class__.__name__}' "
                    f"need to end with '_dict' (got '{name}')")
            named_value_dicts[name[:-5]] = value_dict  # `{*}_dict`

        out_dict = defaultdict(list)
        for edge_type, edge_index in edge_index_dict.items():
            src, _, dst = edge_type

            # Edge types without a registered layer are skipped silently.
            str_edge_type = '__'.join(edge_type)
            if str_edge_type not in self.convs:
                continue

            # Positional extras: forward only those that apply here.
            args = []
            for value_dict in args_dict:
                found, value = self._select_value(value_dict, edge_type)
                if found:
                    args.append(value)

            # Keyword extras: same lookup, keyed by the stripped name.
            kwargs = {}
            for name, value_dict in named_value_dicts.items():
                found, value = self._select_value(value_dict, edge_type)
                if found:
                    kwargs[name] = value

            conv = self.convs[str_edge_type]

            if src == dst:
                out = conv(x_dict[src], edge_index, *args, **kwargs)
            else:
                # Bipartite relation: the layer receives a
                # `(source features, destination features)` tuple.
                out = conv((x_dict[src], x_dict[dst]), edge_index, *args,
                           **kwargs)

            out_dict[dst].append(out)

        # Collapse the per-relation outputs of every destination node type.
        # Re-assigning existing keys does not resize the dictionary, so doing
        # it while iterating is safe.
        for key, value in out_dict.items():
            out_dict[key] = group(value, self.aggr)

        return out_dict

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_relations={len(self.convs)})'