Source code for torch_geometric.nn.models.meta
from typing import Optional, Tuple
import torch
from torch import Tensor
class MetaLayer(torch.nn.Module):
    r"""A meta layer for building any kind of graph network, inspired by the
    `"Relational Inductive Biases, Deep Learning, and Graph Networks"
    <https://arxiv.org/abs/1806.01261>`_ paper.

    A graph network consumes a graph and produces a graph with the same
    connectivity but updated features. The input graph carries node features
    :obj:`x`, edge features :obj:`edge_attr` and graph-level (global) features
    :obj:`u`. These are refreshed in sequence by the optional sub-modules
    :obj:`edge_model`, :obj:`node_model` and :obj:`global_model`. Any
    sub-module left as :obj:`None` is skipped, and the matching features pass
    through unchanged.

    To support mini-batches of several graphs, every sub-module also receives
    a :obj:`batch` vector that maps each edge (for :obj:`edge_model`) or each
    node (for :obj:`node_model` and :obj:`global_model`) to its graph.

    Args:
        edge_model (torch.nn.Module, optional): A callable that updates the
            edge features from the source and target node features, the
            current edge features and the global features.
            (default: :obj:`None`)
        node_model (torch.nn.Module, optional): A callable that updates the
            node features from the current node features, the graph
            connectivity, the edge features and the global features.
            (default: :obj:`None`)
        global_model (torch.nn.Module, optional): A callable that updates the
            global features from the node features, the graph connectivity,
            the edge features and the current global features.
            (default: :obj:`None`)

    .. code-block:: python

        from torch.nn import Sequential as Seq, Linear as Lin, ReLU
        from torch_geometric.utils import scatter
        from torch_geometric.nn import MetaLayer

        class EdgeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.edge_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

            def forward(self, src, dst, edge_attr, u, batch):
                # src, dst: [E, F_x], where E is the number of edges.
                # edge_attr: [E, F_e]
                # u: [B, F_u], where B is the number of graphs.
                # batch: [E] with max entry B - 1.
                out = torch.cat([src, dst, edge_attr, u[batch]], 1)
                return self.edge_mlp(out)

        class NodeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.node_mlp_1 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))
                self.node_mlp_2 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

            def forward(self, x, edge_index, edge_attr, u, batch):
                # x: [N, F_x], where N is the number of nodes.
                # edge_index: [2, E] with max entry N - 1.
                # edge_attr: [E, F_e]
                # u: [B, F_u]
                # batch: [N] with max entry B - 1.
                row, col = edge_index
                out = torch.cat([x[row], edge_attr], dim=1)
                out = self.node_mlp_1(out)
                out = scatter(out, col, dim=0, dim_size=x.size(0),
                              reduce='mean')
                out = torch.cat([x, out, u[batch]], dim=1)
                return self.node_mlp_2(out)

        class GlobalModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.global_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

            def forward(self, x, edge_index, edge_attr, u, batch):
                # x: [N, F_x], where N is the number of nodes.
                # edge_index: [2, E] with max entry N - 1.
                # edge_attr: [E, F_e]
                # u: [B, F_u]
                # batch: [N] with max entry B - 1.
                out = torch.cat([
                    u,
                    scatter(x, batch, dim=0, reduce='mean'),
                ], dim=1)
                return self.global_mlp(out)

        op = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())
        x, edge_attr, u = op(x, edge_index, edge_attr, u, batch)
    """
    def __init__(
        self,
        edge_model: Optional[torch.nn.Module] = None,
        node_model: Optional[torch.nn.Module] = None,
        global_model: Optional[torch.nn.Module] = None,
    ):
        super().__init__()
        # Assignment order fixes the order in which torch.nn.Module registers
        # these children (visible in children()/state_dict()); keep it.
        self.edge_model = edge_model
        self.node_model = node_model
        self.global_model = global_model
        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        # Visit node -> edge -> global. With random initializers, the visiting
        # order determines which random draws each sub-module receives.
        candidates = (self.node_model, self.edge_model, self.global_model)
        for sub in candidates:
            if not hasattr(sub, 'reset_parameters'):
                continue  # e.g. None, or a module without learnable state
            sub.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Optional[Tensor] = None,
        u: Optional[Tensor] = None,
        batch: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        r"""Forward pass.

        Updates edges first, then nodes (which see the new edge features),
        then global features (which see the new node and edge features).

        Args:
            x (torch.Tensor): The node features.
            edge_index (torch.Tensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)
            u (torch.Tensor, optional): The global graph features.
                (default: :obj:`None`)
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each node to a specific graph. (default: :obj:`None`)

        Returns:
            A tuple ``(x, edge_attr, u)``. An entry is returned unchanged when
            its corresponding sub-module is :obj:`None`.
        """
        src_idx = edge_index[0]
        dst_idx = edge_index[1]

        if self.edge_model is not None:
            # The edge model needs a per-edge graph assignment. Each edge
            # inherits the graph id of its source node.
            if batch is None:
                edge_batch = batch
            else:
                edge_batch = batch[src_idx]
            edge_attr = self.edge_model(x[src_idx], x[dst_idx], edge_attr,
                                        u, edge_batch)

        if self.node_model is not None:
            x = self.node_model(x, edge_index, edge_attr, u, batch)

        if self.global_model is not None:
            u = self.global_model(x, edge_index, edge_attr, u, batch)

        return x, edge_attr, u

    def __repr__(self) -> str:
        parts = (
            f' edge_model={self.edge_model}',
            f' node_model={self.node_model}',
            f' global_model={self.global_model}',
        )
        return type(self).__name__ + '(\n' + ',\n'.join(parts) + '\n)'