from typing import Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense import Linear
from torch_geometric.nn.inits import glorot, reset
from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType
from torch_geometric.utils import softmax
def group(xs: List[Tensor], q: nn.Parameter,
          k_lin: nn.Module) -> Optional[Tensor]:
    """Fuse several same-shaped tensors into one via learned attention.

    The tensors in :obj:`xs` are stacked along a new leading dimension. Each
    stacked entry receives a scalar score, computed by passing it through
    :obj:`k_lin` and :obj:`tanh`, averaging over dimension 1, multiplying by
    :obj:`q` and summing over the last dimension. The scores are normalized
    with a softmax across the entries and used to form a weighted sum.

    Args:
        xs (List[Tensor]): Tensors to combine. They must share one shape,
            as required by :func:`torch.stack`.
        q (nn.Parameter): Query vector that is broadcast against the
            averaged, transformed entries.
        k_lin (nn.Module): Transformation applied to the stacked tensor
            before the :obj:`tanh` non-linearity.

    Returns:
        Optional[Tensor]: The attention-weighted sum over all entries of
        :obj:`xs`, or :obj:`None` if :obj:`xs` is empty.
    """
    # Nothing to combine.
    if not xs:
        return None

    stacked = torch.stack(xs)
    keys = torch.tanh(k_lin(stacked))
    # One scalar score per stacked entry.
    scores = (q * keys.mean(1)).sum(-1)
    weights = F.softmax(scores, dim=0)
    # Reshape so that every weight broadcasts over its whole entry.
    weights = weights.view(len(xs), 1, -1)
    return (weights * stacked).sum(dim=0)
class HANConv(MessagePassing):
    r"""
    The Heterogeneous Graph Attention Operator from the
    `"Heterogeneous Graph Attention Network"
    <https://arxiv.org/pdf/1903.07293.pdf>`_ paper.

    .. note::

        For an example of using HANConv, see `examples/hetero/han_imdb.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        hetero/han_imdb.py>`_.

    Args:
        in_channels (int or Dict[str, int]): Size of each input sample of every
            node type, or :obj:`-1` to derive the size from the first input(s)
            to the forward method.
        out_channels (int): Size of each output sample. Must be divisible by
            :obj:`heads`.
        metadata (Tuple[List[str], List[Tuple[str, str, str]]]): The metadata
            of the heterogeneous graph, *i.e.* its node and edge types given
            by a list of strings and a list of string triplets, respectively.
            See :meth:`torch_geometric.data.HeteroData.metadata` for more
            information.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Raises:
        ValueError: If :obj:`out_channels` is not divisible by :obj:`heads`.
    """
    def __init__(
        self,
        in_channels: Union[int, Dict[str, int]],
        out_channels: int,
        metadata: Metadata,
        heads: int = 1,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
        **kwargs,
    ):
        super().__init__(aggr='add', node_dim=0, **kwargs)

        # The projected features are reshaped to [-1, heads, out // heads] in
        # `forward`, which only makes sense if the split is exact.
        if out_channels % heads != 0:
            raise ValueError(
                f"'out_channels' (got {out_channels}) must be divisible by "
                f"'heads' (got {heads})")

        # Broadcast a single input size to every node type.
        if not isinstance(in_channels, dict):
            in_channels = {node_type: in_channels for node_type in metadata[0]}

        self.heads = heads
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.negative_slope = negative_slope
        self.metadata = metadata
        self.dropout = dropout

        # Parameters of the attention in `group`, which fuses the per-edge-type
        # outputs of each node type.
        self.k_lin = nn.Linear(out_channels, out_channels)
        self.q = nn.Parameter(torch.empty(1, out_channels))

        # One input projection per node type.
        self.proj = nn.ModuleDict()
        for node_type, channels in self.in_channels.items():
            self.proj[node_type] = Linear(channels, out_channels)

        # One pair of attention vectors (source / destination side) per edge
        # type. Edge type triplets are joined to strings to serve as keys.
        self.lin_src = nn.ParameterDict()
        self.lin_dst = nn.ParameterDict()
        dim = out_channels // heads
        for edge_type in metadata[1]:
            edge_type = '__'.join(edge_type)
            self.lin_src[edge_type] = nn.Parameter(torch.empty(1, heads, dim))
            self.lin_dst[edge_type] = nn.Parameter(torch.empty(1, heads, dim))

        # Parameters above are allocated uninitialized; fill them here.
        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-initializes all learnable parameters of the module."""
        reset(self.proj)
        glorot(self.lin_src)
        glorot(self.lin_dst)
        self.k_lin.reset_parameters()
        glorot(self.q)

    def forward(
        self,
        x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Adj],
    ) -> Dict[NodeType, Optional[Tensor]]:
        r"""
        Args:
            x_dict (Dict[str, Tensor]): A dictionary holding input node
                features for each individual node type.
            edge_index_dict: (Dict[str, Union[Tensor, SparseTensor]]): A
                dictionary holding graph connectivity information for each
                individual edge type, either as a :obj:`torch.LongTensor` of
                shape :obj:`[2, num_edges]` or a
                :obj:`torch_sparse.SparseTensor`.

        :rtype: :obj:`Dict[str, Optional[Tensor]]` - The output node embeddings
            for each node type.
            In case a node type does not receive any message, its output will
            be set to :obj:`None`.
        """
        H, D = self.heads, self.out_channels // self.heads
        x_node_dict, out_dict = {}, {}

        # Iterate over node types: project the features and split them into
        # heads; prepare a list collecting one output per incoming edge type.
        for node_type, x_node in x_dict.items():
            x_node_dict[node_type] = self.proj[node_type](x_node).view(
                -1, H, D)
            out_dict[node_type] = []

        # Iterate over edge types:
        for edge_type, edge_index in edge_index_dict.items():
            src_type, _, dst_type = edge_type
            edge_type = '__'.join(edge_type)
            lin_src = self.lin_src[edge_type]
            lin_dst = self.lin_dst[edge_type]
            x_src = x_node_dict[src_type]
            x_dst = x_node_dict[dst_type]
            # Per-node, per-head attention logits for both edge endpoints.
            alpha_src = (x_src * lin_src).sum(dim=-1)
            alpha_dst = (x_dst * lin_dst).sum(dim=-1)
            # Pass features as a (source, destination) pair so that `message`
            # receives the *source* node features of every edge as `x_j`.
            # propagate_type: (x: PairTensor, alpha: PairTensor)
            out = self.propagate(edge_index, x=(x_src, x_dst),
                                 alpha=(alpha_src, alpha_dst), size=None)
            out = F.relu(out)
            out_dict[dst_type].append(out)

        # Iterate over node types: fuse the outputs of all incoming edge
        # types. `group` yields `None` for node types without any message.
        for node_type, outs in out_dict.items():
            out_dict[node_type] = group(outs, self.q, self.k_lin)

        return out_dict

    def message(self, x_j: Tensor, alpha_i: Tensor, alpha_j: Tensor,
                index: Tensor, ptr: Optional[Tensor],
                size_i: Optional[int]) -> Tensor:
        r"""Computes the attention-weighted message of every edge.

        Following the :class:`MessagePassing` naming convention, arguments
        suffixed with :obj:`_j` hold values of the source node of each edge
        and arguments suffixed with :obj:`_i` those of the destination node.
        """
        alpha = alpha_j + alpha_i
        alpha = F.leaky_relu(alpha, self.negative_slope)
        # Normalize the logits over all edges grouped by `index`.
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        # Weight the source features head-wise, then merge the heads again.
        out = x_j * alpha.view(-1, self.heads, 1)
        return out.view(-1, self.out_channels)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.out_channels}, '
                f'heads={self.heads})')