# Source code for torch_geometric.nn.conv.hetero_conv

import warnings
from typing import Dict, List, Optional

import torch
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.module_dict import ModuleDict
from torch_geometric.typing import EdgeType, NodeType
from torch_geometric.utils.hetero import check_add_self_loops


def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
    r"""Merge the outputs that several relations produced for one node type.

    Args:
        xs (List[torch.Tensor]): Per-relation outputs for a single
            destination node type.
        aggr (str, optional): How to merge them. :obj:`None` stacks the
            tensors along a new dimension 1, :obj:`"cat"` concatenates them
            along the last dimension, and any other value is looked up as a
            function on the :mod:`torch` namespace (e.g. :obj:`"sum"`,
            :obj:`"mean"`, :obj:`"max"`) and applied over a new leading
            dimension.

    Returns:
        The merged tensor; the single element of :obj:`xs` itself when only
        one output exists and :obj:`aggr` is not :obj:`None`; or :obj:`None`
        when :obj:`xs` is empty.
    """
    num_outputs = len(xs)
    if num_outputs == 0:
        return None
    if aggr is None:
        # Keep relations separate along a new dimension 1.
        return torch.stack(xs, dim=1)
    if num_outputs == 1:
        # Nothing to combine; hand back the tensor unchanged.
        return xs[0]
    if aggr == "cat":
        return torch.cat(xs, dim=-1)

    stacked = torch.stack(xs, dim=0)
    reduced = getattr(torch, aggr)(stacked, dim=0)
    # Reductions such as `torch.max` / `torch.min` return a
    # `(values, indices)` tuple when given `dim`; keep only the values.
    if isinstance(reduced, tuple):
        return reduced[0]
    return reduced


class HeteroConv(torch.nn.Module):
    r"""A generic wrapper for computing graph convolution on heterogeneous
    graphs.
    This layer will pass messages from source nodes to target nodes based on
    the bipartite GNN layer given for a specific edge type.
    If multiple relations point to the same destination, their results will be
    aggregated according to :attr:`aggr`.
    In comparison to :meth:`torch_geometric.nn.to_hetero`, this layer is
    especially useful if you want to apply different message passing modules
    for different edge types.

    .. code-block:: python

        hetero_conv = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
            ('paper', 'written_by', 'author'): GATConv((-1, -1), 64),
        }, aggr='sum')

        out_dict = hetero_conv(x_dict, edge_index_dict)

        print(list(out_dict.keys()))
        >>> ['paper', 'author']

    Args:
        convs (Dict[Tuple[str, str, str], MessagePassing]): A dictionary
            holding a bipartite
            :class:`~torch_geometric.nn.conv.MessagePassing` layer for each
            individual edge type.
        aggr (str, optional): The aggregation scheme to use for grouping node
            embeddings generated by different relations
            (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
            :obj:`"cat"`, :obj:`None`). With :obj:`None`, the per-relation
            outputs are stacked along a new dimension 1 instead of being
            reduced. (default: :obj:`"sum"`)
    """
    def __init__(
        self,
        convs: Dict[EdgeType, MessagePassing],
        aggr: Optional[str] = "sum",
    ):
        super().__init__()

        # Per-edge-type validation of each wrapped layer is delegated to
        # `check_add_self_loops` (presumably it guards against self-loop
        # insertion on bipartite relations -- confirm in
        # `torch_geometric.utils.hetero`).
        for edge_type, module in convs.items():
            check_add_self_loops(module, [edge_type])

        src_node_types = {key[0] for key in convs.keys()}
        dst_node_types = {key[-1] for key in convs.keys()}
        # `forward` only produces outputs for destination node types, so any
        # type that appears solely as a source is absent from its result.
        source_only = src_node_types - dst_node_types
        if len(source_only) > 0:
            warnings.warn(
                f"There exist node types ({source_only}) "
                f"whose representations do not get updated during message "
                f"passing as they do not occur as destination type in any "
                f"edge type. This may lead to unexpected behavior.")

        self.convs = ModuleDict(convs)
        self.aggr = aggr

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        for conv in self.convs.values():
            conv.reset_parameters()

    @staticmethod
    def _select_value(value_dict, edge_type: EdgeType):
        r"""Pick the entry of :obj:`value_dict` that applies to
        :obj:`edge_type`.

        Resolution order:

        1. an entry keyed by the full edge type (an edge-level argument);
        2. for relations with identical source and destination type, the
           entry keyed by that node type;
        3. otherwise a :obj:`(source_value, destination_value)` tuple, where
           a missing side is :obj:`None`, provided at least one side exists.

        Returns:
            A :obj:`(found, is_edge_level, value)` triple. :obj:`found` is
            :obj:`False` (and :obj:`value` is :obj:`None`) when the
            dictionary holds nothing for this relation.
        """
        src, _, dst = edge_type
        if edge_type in value_dict:
            return True, True, value_dict[edge_type]
        if src == dst and src in value_dict:
            return True, False, value_dict[src]
        if src in value_dict or dst in value_dict:
            return True, False, (
                value_dict.get(src, None),
                value_dict.get(dst, None),
            )
        return False, False, None

    def forward(
        self,
        *args_dict,
        **kwargs_dict,
    ) -> Dict[NodeType, Tensor]:
        r"""Runs the forward pass of the module.

        Args:
            x_dict (Dict[str, torch.Tensor]): A dictionary holding node feature
                information for each individual node type.
            edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): A
                dictionary holding graph connectivity information for each
                individual edge type, either as a :class:`torch.Tensor` of
                shape :obj:`[2, num_edges]` or a
                :class:`torch_sparse.SparseTensor`.
            *args_dict (optional): Additional forward arguments of individual
                :class:`torch_geometric.nn.conv.MessagePassing` layers.
            **kwargs_dict (optional): Additional forward arguments of
                individual :class:`torch_geometric.nn.conv.MessagePassing`
                layers.
                For example, if a specific GNN layer at edge type
                :obj:`edge_type` expects edge attributes :obj:`edge_attr` as a
                forward argument, then you can pass them to
                :meth:`~torch_geometric.nn.conv.HeteroConv.forward` via
                :obj:`edge_attr_dict = { edge_type: edge_attr }`.

        For every edge type, each argument dictionary is resolved to the entry
        keyed by the edge type itself, else (for relations whose source and
        destination types match) the entry keyed by that node type, else a
        :obj:`(source, destination)` tuple. Dictionaries holding nothing for a
        relation are omitted from that layer's call. A layer is only invoked
        if at least one dictionary has an entry keyed by its full edge type.

        Raises:
            ValueError: If a keyword argument name does not end with
                :obj:`"_dict"`.
        """
        outputs: Dict[NodeType, List[Tensor]] = {}

        for edge_type, conv in self.convs.items():
            _, _, dst = edge_type

            has_edge_level_arg = False

            args = []
            for value_dict in args_dict:
                found, edge_level, value = self._select_value(
                    value_dict, edge_type)
                if found:
                    args.append(value)
                has_edge_level_arg = has_edge_level_arg or edge_level

            kwargs = {}
            for arg, value_dict in kwargs_dict.items():
                if not arg.endswith('_dict'):
                    raise ValueError(
                        f"Keyword arguments in '{self.__class__.__name__}' "
                        f"need to end with '_dict' (got '{arg}')")

                found, edge_level, value = self._select_value(
                    value_dict, edge_type)
                if found:
                    kwargs[arg[:-5]] = value  # strip the `_dict` suffix
                has_edge_level_arg = has_edge_level_arg or edge_level

            # Relations without any edge-level input (e.g. no entry in
            # `edge_index_dict`) are skipped entirely.
            if not has_edge_level_arg:
                continue

            out = conv(*args, **kwargs)
            outputs.setdefault(dst, []).append(out)

        # Combine all relations that point to the same destination type.
        return {
            node_type: group(node_outputs, self.aggr)
            for node_type, node_outputs in outputs.items()
        }

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_relations={len(self.convs)})'