torch_geometric.nn.models.MetaLayer
class MetaLayer(edge_model: Optional[Module] = None, node_model: Optional[Module] = None, global_model: Optional[Module] = None)
Bases: torch.nn.Module
A meta layer for building any kind of graph network, inspired by the “Relational Inductive Biases, Deep Learning, and Graph Networks” paper.
A graph network takes a graph as input and returns an updated graph as output (with the same connectivity). The input graph has node features x, edge features edge_attr, and graph-level features u. The output graph has the same structure, but updated features.

Edge features, node features, and global features are updated by calling the modules edge_model, node_model, and global_model, respectively.

To allow for batch-wise graph processing, all callables take an additional argument batch, which assigns each edge or node to the graph it belongs to.

Parameters:
edge_model (torch.nn.Module, optional) – A callable which updates a graph’s edge features based on its source and target node features, its current edge features and its global features. It is called as edge_model(src, dst, edge_attr, u, batch). (default: None)
node_model (torch.nn.Module, optional) – A callable which updates a graph’s node features based on its current node features, its graph connectivity, its edge features and its global features. It is called as node_model(x, edge_index, edge_attr, u, batch). (default: None)
global_model (torch.nn.Module, optional) – A callable which updates a graph’s global features based on its node features, its graph connectivity, its edge features and its current global features. It is called as global_model(x, edge_index, edge_attr, u, batch). (default: None)
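The parameter descriptions above imply how the three callables are wired together. The following sketch re-creates that wiring in plain PyTorch so the data flow can be inspected without torch_geometric. It assumes the update order edge → node → global (each step sees the previous step's outputs) and that an edge belongs to the graph of its source node; meta_forward, edge_fn and global_fn are hypothetical names, not part of the PyG API.

```python
import torch


def meta_forward(x, edge_index, edge_attr=None, u=None, batch=None,
                 edge_model=None, node_model=None, global_model=None):
    """Hypothetical re-creation of the graph-network update order.

    Assumption: edge -> node -> global, each step seeing the previous
    step's outputs; a model left as None passes its features through.
    """
    row, col = edge_index  # source and target node of every edge
    if edge_model is not None:
        # Assumption: each edge is assigned to the graph of its source node.
        edge_batch = None if batch is None else batch[row]
        edge_attr = edge_model(x[row], x[col], edge_attr, u, edge_batch)
    if node_model is not None:
        x = node_model(x, edge_index, edge_attr, u, batch)
    if global_model is not None:
        u = global_model(x, edge_index, edge_attr, u, batch)
    return x, edge_attr, u


# Toy usage: a 3-node path graph 0 -> 1 -> 2 forming a single graph.
x = torch.ones(3, 2)
edge_index = torch.tensor([[0, 1], [1, 2]])
edge_attr = torch.zeros(2, 1)
u = torch.zeros(1, 1)
batch = torch.zeros(3, dtype=torch.long)


def edge_fn(src, dst, e, u, b):
    # Hypothetical edge update: sum of endpoint features, shape [E, 1].
    return (src + dst).sum(dim=1, keepdim=True)


def global_fn(x, ei, e, u, b):
    # Hypothetical global update: add the total of all edge features, shape [B, 1].
    return u + e.sum().view(1, 1)


x_out, e_out, u_out = meta_forward(x, edge_index, edge_attr, u, batch,
                                   edge_model=edge_fn, global_model=global_fn)
```

Because the global update runs after the edge update, u_out reflects the new edge features (8.0 rather than 0.0), while x_out is returned unchanged since no node model is given.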
import torch
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.utils import scatter
from torch_geometric.nn import MetaLayer


class EdgeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Fill in the `...` input/output sizes to match your data.
        self.edge_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, src, dst, edge_attr, u, batch):
        # src, dst: [E, F_x], where E is the number of edges.
        # edge_attr: [E, F_e]
        # u: [B, F_u], where B is the number of graphs.
        # batch: [E] with max entry B - 1.
        out = torch.cat([src, dst, edge_attr, u[batch]], 1)
        return self.edge_mlp(out)


class NodeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.node_mlp_1 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))
        self.node_mlp_2 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, col = edge_index
        # Build one message per edge from the source node and edge features.
        out = torch.cat([x[row], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        # Average the incoming messages at each target node.
        out = scatter(out, col, dim=0, dim_size=x.size(0), reduce='mean')
        out = torch.cat([x, out, u[batch]], dim=1)
        return self.node_mlp_2(out)


class GlobalModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.global_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        # Combine the current global features with the mean node features per graph.
        out = torch.cat([
            u,
            scatter(x, batch, dim=0, reduce='mean'),
        ], dim=1)
        return self.global_mlp(out)


op = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())
# x, edge_index, edge_attr, u and batch are assumed to be defined already.
x, edge_attr, u = op(x, edge_index, edge_attr, u, batch)
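The node model in the example averages the incoming edge messages per target node with scatter(..., reduce='mean'). To see what that step computes without torch_geometric, here is a plain-PyTorch sketch; scatter_mean is a hypothetical helper, and it assumes a 2-D src and that groups receiving no rows should come out as zeros (which we believe matches PyG's mean scatter, but the scatter documentation is authoritative).

```python
import torch


def scatter_mean(src, index, dim_size):
    """Hypothetical helper: average the rows of `src` that share an entry in `index`.

    Plain-PyTorch sketch of a mean scatter over dim 0; groups that
    receive no rows come out as zeros.
    """
    out = torch.zeros(dim_size, src.size(1), dtype=src.dtype, device=src.device)
    out.index_add_(0, index, src)  # per-group sums
    # Clamp counts to 1 so empty groups divide 0 by 1 instead of 0 by 0.
    count = torch.bincount(index, minlength=dim_size).clamp(min=1)
    return out / count.unsqueeze(1).to(src.dtype)


# Messages on 3 edges whose target nodes (col) are 1, 1 and 2;
# node 0 receives no messages.
messages = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
col = torch.tensor([1, 1, 2])
agg = scatter_mean(messages, col, dim_size=3)
```

Passing dim_size explicitly (as the example does with x.size(0)) keeps one output row per node, including nodes such as node 0 here that receive no messages.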
forward(x: Tensor, edge_index: Tensor, edge_attr: Optional[Tensor] = None, u: Optional[Tensor] = None, batch: Optional[Tensor] = None) → Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
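Forward pass. Applies edge_model, node_model and global_model (if given) in that order, so the node model sees the updated edge features and the global model sees the updated node and edge features. Returns the tuple (x, edge_attr, u); features whose model is None are returned unchanged.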
Parameters:
x (torch.Tensor) – The node features.
edge_index (torch.Tensor) – The edge indices.
edge_attr (torch.Tensor, optional) – The edge features. (default: None)
u (torch.Tensor, optional) – The global graph features. (default: None)
batch (torch.Tensor, optional) – The batch vector \(\mathbf{b} \in \{0, \ldots, B-1\}^N\), which assigns each node to a specific graph. (default: None)
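To make the batch vector concrete, here is a small sketch with two graphs batched into one disjoint graph. It shows how u[batch], as used in the example models, broadcasts graph-level features to nodes, and how a per-edge batch can be derived, assuming an edge is assigned to the graph of its source node. The sizes and values are illustrative only.

```python
import torch

# Two graphs concatenated into one batch: graph 0 has nodes 0-2,
# graph 1 has nodes 3-4 (node indices are offset per graph).
batch = torch.tensor([0, 0, 0, 1, 1])        # b in {0, ..., B-1}^N with N = 5, B = 2
edge_index = torch.tensor([[0, 1, 3],        # source nodes
                           [1, 2, 4]])       # target nodes
u = torch.tensor([[10.0], [20.0]])           # one global feature per graph, [B, F_u]

node_u = u[batch]                   # [N, F_u]: each node sees its own graph's u
edge_batch = batch[edge_index[0]]   # [E]: graph of each edge's source node
edge_u = u[edge_batch]              # [E, F_u]: each edge sees its graph's u
num_graphs = int(batch.max()) + 1   # B
```

Indexing with the batch vector is what lets a single tensor u of shape [B, F_u] serve every node and edge in the batch without an explicit loop over graphs.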