import warnings
from typing import Dict, List, Optional
import torch
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.module_dict import ModuleDict
from torch_geometric.typing import EdgeType, NodeType
from torch_geometric.utils.hetero import check_add_self_loops


def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
    if len(xs) == 0:
        return None
    elif aggr is None:
        return torch.stack(xs, dim=1)
    elif len(xs) == 1:
        return xs[0]
    elif aggr == "cat":
        return torch.cat(xs, dim=-1)
    else:
        out = torch.stack(xs, dim=0)
        out = getattr(torch, aggr)(out, dim=0)
        out = out[0] if isinstance(out, tuple) else out
        return out
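
# Illustrative semantics of `group` (shapes and values are examples only):
#   xs = [torch.randn(4, 8), torch.randn(4, 8)]
#   group(xs, "sum")  -> shape [4, 8]    (element-wise reduction over relations)
#   group(xs, "cat")  -> shape [4, 16]   (concatenation along the feature dim)
#   group(xs, None)   -> shape [4, 2, 8] (outputs stacked per relation)
#   group([], "sum")  -> None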


class HeteroConv(torch.nn.Module):
    r"""A generic wrapper for computing graph convolution on heterogeneous
    graphs.

    This layer will pass messages from source nodes to target nodes based on
    the bipartite GNN layer given for a specific edge type.
    If multiple relations point to the same destination, their results will be
    aggregated according to :attr:`aggr`.
    In comparison to :meth:`torch_geometric.nn.to_hetero`, this layer is
    especially useful if you want to apply different message passing modules
    for different edge types.

    .. code-block:: python

        hetero_conv = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
            ('paper', 'written_by', 'author'): GATConv((-1, -1), 64,
                                                       add_self_loops=False),
        }, aggr='sum')

        out_dict = hetero_conv(x_dict, edge_index_dict)
        print(list(out_dict.keys()))
        >>> ['paper', 'author']

    Args:
        convs (Dict[Tuple[str, str, str], MessagePassing]): A dictionary
            holding a bipartite
            :class:`~torch_geometric.nn.conv.MessagePassing` layer for each
            individual edge type.
        aggr (str, optional): The aggregation scheme to use for grouping node
            embeddings generated by different relations
            (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
            :obj:`"cat"`, :obj:`None`). (default: :obj:`"sum"`)
    """

    def __init__(
        self,
        convs: Dict[EdgeType, MessagePassing],
        aggr: Optional[str] = "sum",
    ):
        super().__init__()
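
        # Modules with `add_self_loops=True` are rejected for bipartite edge
        # types, since self-loops are not well-defined when the source and
        # destination node types differ: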
        for edge_type, module in convs.items():
            check_add_self_loops(module, [edge_type])

        src_node_types = {key[0] for key in convs.keys()}
        dst_node_types = {key[-1] for key in convs.keys()}
        if len(src_node_types - dst_node_types) > 0:
            warnings.warn(
                f"There exist node types ({src_node_types - dst_node_types}) "
                f"whose representations do not get updated during message "
                f"passing as they do not occur as destination type in any "
                f"edge type. This may lead to unexpected behavior.")

        self.convs = ModuleDict(convs)
        self.aggr = aggr

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        for conv in self.convs.values():
            conv.reset_parameters()

    def forward(
        self,
        *args_dict,
        **kwargs_dict,
    ) -> Dict[NodeType, Tensor]:
        r"""Runs the forward pass of the module.

        Args:
            x_dict (Dict[str, torch.Tensor]): A dictionary holding node feature
                information for each individual node type.
            edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): A
                dictionary holding graph connectivity information for each
                individual edge type, either as a :class:`torch.Tensor` of
                shape :obj:`[2, num_edges]` or a
                :class:`torch_sparse.SparseTensor`.
            *args_dict (optional): Additional forward arguments of individual
                :class:`torch_geometric.nn.conv.MessagePassing` layers.
            **kwargs_dict (optional): Additional forward arguments of
                individual :class:`torch_geometric.nn.conv.MessagePassing`
                layers.
                For example, if a specific GNN layer at edge type
                :obj:`edge_type` expects edge attributes :obj:`edge_attr` as a
                forward argument, then you can pass them to
                :meth:`~torch_geometric.nn.conv.HeteroConv.forward` via
                :obj:`edge_attr_dict = { edge_type: edge_attr }`.
        """
        out_dict: Dict[str, List[Tensor]] = {}

        for edge_type, conv in self.convs.items():
            src, rel, dst = edge_type

            has_edge_level_arg = False

            args = []
            for value_dict in args_dict:
                if edge_type in value_dict:
                    has_edge_level_arg = True
                    args.append(value_dict[edge_type])
                elif src == dst and src in value_dict:
                    args.append(value_dict[src])
                elif src in value_dict or dst in value_dict:
                    args.append((
                        value_dict.get(src, None),
                        value_dict.get(dst, None),
                    ))
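
            # Keyword arguments must follow the `*_dict` naming convention and
            # are matched per edge type in the same way as positional ones: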
            kwargs = {}
            for arg, value_dict in kwargs_dict.items():
                if not arg.endswith('_dict'):
                    raise ValueError(
                        f"Keyword arguments in '{self.__class__.__name__}' "
                        f"need to end with '_dict' (got '{arg}')")
                arg = arg[:-5]  # `{*}_dict`
                if edge_type in value_dict:
                    has_edge_level_arg = True
                    kwargs[arg] = value_dict[edge_type]
                elif src == dst and src in value_dict:
                    kwargs[arg] = value_dict[src]
                elif src in value_dict or dst in value_dict:
                    kwargs[arg] = (
                        value_dict.get(src, None),
                        value_dict.get(dst, None),
                    )
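
            # Skip edge types for which no edge-level argument (such as an
            # entry in `edge_index_dict`) was provided: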
            if not has_edge_level_arg:
                continue

            out = conv(*args, **kwargs)

            if dst not in out_dict:
                out_dict[dst] = [out]
            else:
                out_dict[dst].append(out)
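
        # Aggregate the outputs of all relations that share the same
        # destination node type according to `self.aggr`: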
        for key, value in out_dict.items():
            out_dict[key] = group(value, self.aggr)

        return out_dict

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_relations={len(self.convs)})'
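

# A minimal end-to-end sketch (commented out; assumes `GCNConv` and `SAGEConv`
# are available and uses illustrative node/edge types and sizes):
#
#     from torch_geometric.nn import GCNConv, SAGEConv
#
#     x_dict = {
#         'paper': torch.randn(6, 16),
#         'author': torch.randn(3, 8),
#     }
#     edge_index_dict = {
#         ('paper', 'cites', 'paper'): torch.tensor([[0, 1, 2], [1, 2, 3]]),
#         ('author', 'writes', 'paper'): torch.tensor([[0, 1, 2], [0, 2, 4]]),
#         ('paper', 'written_by', 'author'): torch.tensor([[0, 2, 4], [0, 1, 2]]),
#     }
#     conv = HeteroConv({
#         ('paper', 'cites', 'paper'): GCNConv(-1, 32),
#         ('author', 'writes', 'paper'): SAGEConv((-1, -1), 32),
#         ('paper', 'written_by', 'author'): SAGEConv((-1, -1), 32),
#     }, aggr='sum')
#     out_dict = conv(x_dict, edge_index_dict)
#     # -> {'paper': Tensor of shape [6, 32], 'author': Tensor of shape [3, 32]}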