Source code for torch_geometric.nn.conv.han_conv

from typing import Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense import Linear
from torch_geometric.nn.inits import glorot, reset
from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType
from torch_geometric.utils import softmax


def group(xs: List[Tensor], q: nn.Parameter,
          k_lin: nn.Module) -> Optional[Tensor]:
    r"""Fuses a list of equally shaped tensors into a single tensor using
    softmax attention weights computed across the list.

    Args:
        xs (List[Tensor]): The tensors to combine. They are stacked along a
            new leading dimension, so all of them must share one shape
            (presumably :obj:`[num_nodes, channels]` -- confirm with callers).
        q (nn.Parameter): Query vector that is multiplied with the
            transformed, averaged inputs to score each list entry.
        k_lin (nn.Module): Module applied to the stacked inputs before the
            :obj:`tanh` non-linearity.

    :rtype: :obj:`Optional[Tensor]` - The attention-weighted sum over the
        list entries, or :obj:`None` if :obj:`xs` is empty.
    """
    if not xs:
        return None

    stacked = torch.stack(xs)
    # One summary vector per list entry: average over dimension 1 of the
    # stacked, transformed inputs.
    summary = torch.tanh(k_lin(stacked)).mean(1)
    scores = (q * summary).sum(-1)
    # Normalize across list entries so the weights sum to one.
    weights = F.softmax(scores, dim=0)
    weighted = weights.view(len(xs), 1, -1) * stacked
    return torch.sum(weighted, dim=0)


class HANConv(MessagePassing):
    r"""The Heterogeneous Graph Attention Operator from the
    `"Heterogeneous Graph Attention Network"
    <https://arxiv.org/pdf/1903.07293.pdf>`_ paper.

    .. note::

        For an example of using HANConv, see `examples/hetero/han_imdb.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        hetero/han_imdb.py>`_.

    Args:
        in_channels (int or Dict[str, int]): Size of each input sample of
            every node type, or :obj:`-1` to derive the size from the first
            input(s) to the forward method.
        out_channels (int): Size of each output sample. Must be divisible by
            :obj:`heads`.
        metadata (Tuple[List[str], List[Tuple[str, str, str]]]): The metadata
            of the heterogeneous graph, *i.e.* its node and edge types given
            by a list of strings and a list of string triplets, respectively.
            See :meth:`torch_geometric.data.HeteroData.metadata` for more
            information.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Raises:
        ValueError: If :obj:`out_channels` is not divisible by :obj:`heads`.
    """
    def __init__(
        self,
        in_channels: Union[int, Dict[str, int]],
        out_channels: int,
        metadata: Metadata,
        heads: int = 1,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
        **kwargs,
    ):
        super().__init__(aggr='add', node_dim=0, **kwargs)

        # The projected features are reshaped to [-1, heads, out_channels //
        # heads] in `forward`; reject sizes for which that split is inexact
        # instead of failing later with an opaque shape error.
        if out_channels % heads != 0:
            raise ValueError(
                f"'out_channels' (got {out_channels}) must be divisible by "
                f"'heads' (got {heads})")

        if not isinstance(in_channels, dict):
            in_channels = {node_type: in_channels for node_type in metadata[0]}

        self.heads = heads
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.negative_slope = negative_slope
        self.metadata = metadata
        self.dropout = dropout

        # Parameters of the attention over edge types (see `group`).
        self.k_lin = nn.Linear(out_channels, out_channels)
        self.q = nn.Parameter(torch.empty(1, out_channels))

        # One input projection per node type.
        self.proj = nn.ModuleDict()
        for node_type, channels in self.in_channels.items():
            self.proj[node_type] = Linear(channels, out_channels)

        # One pair of attention vectors (source / destination) per edge type,
        # keyed by the edge type triplet joined with '__'.
        self.lin_src = nn.ParameterDict()
        self.lin_dst = nn.ParameterDict()
        dim = out_channels // heads
        for edge_type in metadata[1]:
            edge_type = '__'.join(edge_type)
            self.lin_src[edge_type] = nn.Parameter(torch.empty(1, heads, dim))
            self.lin_dst[edge_type] = nn.Parameter(torch.empty(1, heads, dim))

        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-initializes all learnable parameters of the module."""
        reset(self.proj)
        glorot(self.lin_src)
        glorot(self.lin_dst)
        self.k_lin.reset_parameters()
        glorot(self.q)

    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Adj],
    ) -> Dict[NodeType, Optional[Tensor]]:
        r"""
        Args:
            x_dict (Dict[str, Tensor]): A dictionary holding input node
                features for each individual node type.
            edge_index_dict (Dict[str, Union[Tensor, SparseTensor]]): A
                dictionary holding graph connectivity information for each
                individual edge type, either as a :obj:`torch.LongTensor` of
                shape :obj:`[2, num_edges]` or a
                :obj:`torch_sparse.SparseTensor`.

        :rtype: :obj:`Dict[str, Optional[Tensor]]` - The output node
            embeddings for each node type.
            In case a node type does not receive any message, its output will
            be set to :obj:`None`.
        """
        H, D = self.heads, self.out_channels // self.heads
        x_node_dict, out_dict = {}, {}

        # Iterate over node types: project the inputs and split into heads.
        for node_type, x_node in x_dict.items():
            x_node_dict[node_type] = self.proj[node_type](x_node).view(
                -1, H, D)
            out_dict[node_type] = []

        # Iterate over edge types: attention over the neighbors of each
        # destination node.
        for edge_type, edge_index in edge_index_dict.items():
            src_type, _, dst_type = edge_type
            edge_type = '__'.join(edge_type)
            lin_src = self.lin_src[edge_type]
            lin_dst = self.lin_dst[edge_type]
            x_src = x_node_dict[src_type]
            x_dst = x_node_dict[dst_type]
            alpha_src = (x_src * lin_src).sum(dim=-1)
            alpha_dst = (x_dst * lin_dst).sum(dim=-1)
            # Pass the features as a (source, destination) pair so that
            # `message` receives the *source* features as `x_j`.
            # propagate_type: (x: PairTensor, alpha: PairTensor)
            out = self.propagate(edge_index, x=(x_src, x_dst),
                                 alpha=(alpha_src, alpha_dst), size=None)
            out = F.relu(out)
            out_dict[dst_type].append(out)

        # Iterate over node types: fuse the per-edge-type results. `group`
        # returns `None` for node types that received no message.
        return {
            node_type: group(outs, self.q, self.k_lin)
            for node_type, outs in out_dict.items()
        }

    def message(self, x_j: Tensor, alpha_i: Tensor, alpha_j: Tensor,
                index: Tensor, ptr: Optional[Tensor],
                size_i: Optional[int]) -> Tensor:
        r"""Computes the attention-weighted message sent along each edge.

        The attention logits of source and destination are added, passed
        through a LeakyReLU, normalized with a softmax over the edges that
        share a destination node (:obj:`index`), and subjected to dropout.
        """
        alpha = alpha_j + alpha_i
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        # Weight the source features. Weighting the destination features
        # would make the sum over incoming edges collapse to the destination
        # feature itself, since the softmax weights sum to one per node.
        out = x_j * alpha.view(-1, self.heads, 1)
        return out.view(-1, self.out_channels)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.out_channels}, '
                f'heads={self.heads})')