Source code for torch_geometric.nn.conv.hgt_conv

import math
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense import HeteroDictLinear, HeteroLinear
from torch_geometric.nn.inits import ones
from torch_geometric.nn.parameter_dict import ParameterDict
from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType
from torch_geometric.utils import softmax
from torch_geometric.utils.hetero import construct_bipartite_edge_index


[docs]class HGTConv(MessagePassing): r"""The Heterogeneous Graph Transformer (HGT) operator from the `"Heterogeneous Graph Transformer" <https://arxiv.org/abs/2003.01332>`_ paper. .. note:: For an example of using HGT, see `examples/hetero/hgt_dblp.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ hetero/hgt_dblp.py>`_. Args: in_channels (int or Dict[str, int]): Size of each input sample of every node type, or :obj:`-1` to derive the size from the first input(s) to the forward method. out_channels (int): Size of each output sample. metadata (Tuple[List[str], List[Tuple[str, str, str]]]): The metadata of the heterogeneous graph, *i.e.* its node and edge types given by a list of strings and a list of string triplets, respectively. See :meth:`torch_geometric.data.HeteroData.metadata` for more information. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. """ def __init__( self, in_channels: Union[int, Dict[str, int]], out_channels: int, metadata: Metadata, heads: int = 1, **kwargs, ): super().__init__(aggr='add', node_dim=0, **kwargs) if out_channels % heads != 0: raise ValueError(f"'out_channels' (got {out_channels}) must be " f"divisible by the number of heads (got {heads})") if not isinstance(in_channels, dict): in_channels = {node_type: in_channels for node_type in metadata[0]} self.in_channels = in_channels self.out_channels = out_channels self.heads = heads self.node_types = metadata[0] self.edge_types = metadata[1] self.edge_types_map = { edge_type: i for i, edge_type in enumerate(metadata[1]) } self.dst_node_types = set([key[-1] for key in self.edge_types]) self.kqv_lin = HeteroDictLinear(self.in_channels, self.out_channels * 3) self.out_lin = HeteroDictLinear(self.out_channels, self.out_channels, types=self.node_types) dim = out_channels // heads num_types = heads * len(self.edge_types) self.k_rel = HeteroLinear(dim, dim, num_types, bias=False, is_sorted=True) self.v_rel = HeteroLinear(dim, dim, num_types, bias=False, is_sorted=True) self.skip = ParameterDict({ node_type: Parameter(torch.empty(1)) for node_type in self.node_types }) self.p_rel = ParameterDict() for edge_type in self.edge_types: edge_type = '__'.join(edge_type) self.p_rel[edge_type] = Parameter(torch.empty(1, heads)) self.reset_parameters()
[docs] def reset_parameters(self): super().reset_parameters() self.kqv_lin.reset_parameters() self.out_lin.reset_parameters() self.k_rel.reset_parameters() self.v_rel.reset_parameters() ones(self.skip) ones(self.p_rel)
def _cat(self, x_dict: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, int]]: """Concatenates a dictionary of features.""" cumsum = 0 outs: List[Tensor] = [] offset: Dict[str, int] = {} for key, x in x_dict.items(): outs.append(x) offset[key] = cumsum cumsum += x.size(0) return torch.cat(outs, dim=0), offset def _construct_src_node_feat( self, k_dict: Dict[str, Tensor], v_dict: Dict[str, Tensor], edge_index_dict: Dict[EdgeType, Adj] ) -> Tuple[Tensor, Tensor, Dict[EdgeType, int]]: """Constructs the source node representations.""" cumsum = 0 num_edge_types = len(self.edge_types) H, D = self.heads, self.out_channels // self.heads # Flatten into a single tensor with shape [num_edge_types * heads, D]: ks: List[Tensor] = [] vs: List[Tensor] = [] type_list: List[Tensor] = [] offset: Dict[EdgeType] = {} for edge_type in edge_index_dict.keys(): src = edge_type[0] N = k_dict[src].size(0) offset[edge_type] = cumsum cumsum += N # construct type_vec for curr edge_type with shape [H, D] edge_type_offset = self.edge_types_map[edge_type] type_vec = torch.arange(H, dtype=torch.long).view(-1, 1).repeat( 1, N) * num_edge_types + edge_type_offset type_list.append(type_vec) ks.append(k_dict[src]) vs.append(v_dict[src]) ks = torch.cat(ks, dim=0).transpose(0, 1).reshape(-1, D) vs = torch.cat(vs, dim=0).transpose(0, 1).reshape(-1, D) type_vec = torch.cat(type_list, dim=1).flatten() k = self.k_rel(ks, type_vec).view(H, -1, D).transpose(0, 1) v = self.v_rel(vs, type_vec).view(H, -1, D).transpose(0, 1) return k, v, offset
[docs] def forward( self, x_dict: Dict[NodeType, Tensor], edge_index_dict: Dict[EdgeType, Adj] # Support both. ) -> Dict[NodeType, Optional[Tensor]]: r"""Runs the forward pass of the module. Args: x_dict (Dict[str, torch.Tensor]): A dictionary holding input node features for each individual node type. edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): A dictionary holding graph connectivity information for each individual edge type, either as a :class:`torch.Tensor` of shape :obj:`[2, num_edges]` or a :class:`torch_sparse.SparseTensor`. :rtype: :obj:`Dict[str, Optional[torch.Tensor]]` - The output node embeddings for each node type. In case a node type does not receive any message, its output will be set to :obj:`None`. """ F = self.out_channels H = self.heads D = F // H k_dict, q_dict, v_dict, out_dict = {}, {}, {}, {} # Compute K, Q, V over node types: kqv_dict = self.kqv_lin(x_dict) for key, val in kqv_dict.items(): k, q, v = torch.tensor_split(val, 3, dim=1) k_dict[key] = k.view(-1, H, D) q_dict[key] = q.view(-1, H, D) v_dict[key] = v.view(-1, H, D) q, dst_offset = self._cat(q_dict) k, v, src_offset = self._construct_src_node_feat( k_dict, v_dict, edge_index_dict) edge_index, edge_attr = construct_bipartite_edge_index( edge_index_dict, src_offset, dst_offset, edge_attr_dict=self.p_rel, num_nodes=k.size(0)) out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr) # Reconstruct output node embeddings dict: for node_type, start_offset in dst_offset.items(): end_offset = start_offset + q_dict[node_type].size(0) if node_type in self.dst_node_types: out_dict[node_type] = out[start_offset:end_offset] # Transform output node embeddings: a_dict = self.out_lin({ k: torch.nn.functional.gelu(v) if v is not None else v for k, v in out_dict.items() }) # Iterate over node types: for node_type, out in out_dict.items(): out = a_dict[node_type] if out.size(-1) == x_dict[node_type].size(-1): alpha = self.skip[node_type].sigmoid() out = alpha * out + (1 - alpha) * x_dict[node_type] out_dict[node_type] = out return out_dict
def message(self, k_j: Tensor, q_i: Tensor, v_j: Tensor, edge_attr: Tensor, index: Tensor, ptr: Optional[Tensor], size_i: Optional[int]) -> Tensor: alpha = (q_i * k_j).sum(dim=-1) * edge_attr alpha = alpha / math.sqrt(q_i.size(-1)) alpha = softmax(alpha, index, ptr, size_i) out = v_j * alpha.view(-1, self.heads, 1) return out.view(-1, self.out_channels) def __repr__(self) -> str: return (f'{self.__class__.__name__}(-1, {self.out_channels}, ' f'heads={self.heads})')