torch_geometric.nn.conv.MessagePassing

class MessagePassing(aggr: Optional[Union[str, List[str], Aggregation]] = 'sum', *, aggr_kwargs: Optional[Dict[str, Any]] = None, flow: str = 'source_to_target', node_dim: int = -2, decomposed_layers: int = 1)[source]

Bases: Module

Base class for creating message passing layers.

Message passing layers follow the form

\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \bigoplus_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),\]

where \(\bigoplus\) denotes a differentiable, permutation-invariant function, e.g., sum, mean, min, max or mul, and \(\gamma_{\mathbf{\Theta}}\) and \(\phi_{\mathbf{\Theta}}\) denote differentiable functions such as MLPs. See the accompanying tutorial on creating message passing networks for more details; a minimal example is given below the parameter list.

Parameters:
  • aggr (str or [str] or Aggregation, optional) – The aggregation scheme to use, e.g., "sum", "mean", "min", "max" or "mul". In addition, can be any Aggregation module (or any string that automatically resolves to it). If given as a list, will make use of multiple aggregations in which different outputs get concatenated in the last dimension. If set to None, the MessagePassing instantiation is expected to implement its own aggregation logic via aggregate(). (default: "sum")

  • aggr_kwargs (Dict[str, Any], optional) – Arguments passed to the respective aggregation function in case it gets automatically resolved. (default: None)

  • flow (str, optional) – The flow direction of message passing ("source_to_target" or "target_to_source"). (default: "source_to_target")

  • node_dim (int, optional) – The axis along which to propagate. (default: -2)

  • decomposed_layers (int, optional) – The number of feature decomposition layers, as introduced in the “Optimizing Memory Efficiency of Graph Neural Networks on Edge Computing Platforms” paper. Feature decomposition reduces peak memory usage by slicing the feature dimensions into separate feature decomposition layers during GNN aggregation. This method can accelerate GNN execution on CPU-based platforms (e.g., 2-3x speedup on the Reddit dataset) for common GNN models such as GCN, GraphSAGE and GIN. However, it is not applicable to all GNN operators, in particular to operators in which message computation cannot easily be decomposed, e.g., attention-based GNNs. The optimal value of decomposed_layers depends both on the specific graph dataset and the available hardware resources. A value of 2 is suitable in most cases. While peak memory usage is directly associated with the granularity of feature decomposition, the same is not necessarily true for execution speedups. (default: 1)
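As a minimal illustration of the overall scheme, a simplified, GCN-like layer could look as follows (the class name SimpleConv and the linear transformation are illustrative, not part of the API):

    import torch
    from torch import Tensor
    from torch_geometric.nn import MessagePassing

    class SimpleConv(MessagePassing):
        def __init__(self, in_channels: int, out_channels: int):
            super().__init__(aggr='sum')  # "sum" aggregation (the default).
            self.lin = torch.nn.Linear(in_channels, out_channels)

        def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
            # x: node features of shape [num_nodes, in_channels].
            # edge_index: connectivity of shape [2, num_edges].
            x = self.lin(x)
            return self.propagate(edge_index, x=x)

        def message(self, x_j: Tensor) -> Tensor:
            # x_j: features of the source node of each edge.
            return x_j

A layer defined this way is used like any other module, e.g., out = SimpleConv(16, 32)(x, edge_index).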

reset_parameters() None[source]

Resets all learnable parameters of the module.

forward(*args: Any, **kwargs: Any) Any[source]

Runs the forward pass of the module.

propagate(edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None, **kwargs: Any) Tensor[source]

The initial call to start propagating messages.

Parameters:
  • edge_index (torch.Tensor or SparseTensor) – A torch.Tensor, a torch_sparse.SparseTensor or a torch.sparse.Tensor that defines the underlying graph connectivity/message passing flow. edge_index holds the indices of a general (sparse) assignment matrix of shape [N, M]. If edge_index is a torch.Tensor, its dtype should be torch.long and its shape needs to be defined as [2, num_messages], where messages from nodes in edge_index[0] are sent to nodes in edge_index[1] (in case flow="source_to_target"). If edge_index is a torch_sparse.SparseTensor or a torch.sparse.Tensor, its sparse indices (row, col) should relate to row = edge_index[1] and col = edge_index[0]. The major difference between the two formats is that the transposed sparse adjacency matrix needs to be passed into propagate() (see the example after this parameter list).

  • size ((int, int), optional) – The size (N, M) of the assignment matrix in case edge_index is a torch.Tensor. If set to None, the size will be automatically inferred and assumed to be quadratic. This argument is ignored in case edge_index is a torch_sparse.SparseTensor or a torch.sparse.Tensor. (default: None)

  • **kwargs – Any additional data which is needed to construct and aggregate messages, and to update node embeddings.
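As a sketch (reusing the hypothetical SimpleConv from above and assuming the torch_sparse package is available), the same propagation can be triggered with a dense edge_index or with its transposed sparse equivalent; propagate() is called directly here only for illustration and is typically invoked inside forward():

    import torch
    from torch_sparse import SparseTensor

    conv = SimpleConv(16, 16)
    edge_index = torch.tensor([[0, 1, 1, 2],   # source nodes j
                               [1, 0, 2, 1]])  # target nodes i
    x = torch.randn(3, 16)

    out = conv.propagate(edge_index, x=x, size=(3, 3))

    # Equivalent call with the *transposed* sparse adjacency matrix:
    adj_t = SparseTensor(row=edge_index[1], col=edge_index[0],
                         sparse_sizes=(3, 3))
    out = conv.propagate(adj_t, x=x)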

message(x_j: Tensor) Tensor[source]

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g., x_i and x_j.
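For example, a message function that weights source node features by a per-edge attribute could look as follows (a sketch, assuming x and edge_weight were passed to propagate()):

    from torch import Tensor

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        # x_j is derived from the "x" argument to propagate() by
        # gathering its rows at the source node index of each edge
        # (the "_j" suffix); shape [num_edges, channels].
        # edge_weight is passed through unchanged since it is already
        # defined per edge; shape [num_edges].
        return edge_weight.view(-1, 1) * x_j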

aggregate(inputs: Tensor, index: Tensor, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None) Tensor[source]

Aggregates messages from neighbors as \(\bigoplus_{j \in \mathcal{N}(i)}\).

Takes the output of message computation as its first argument, together with any argument which was initially passed to propagate().

By default, this function will delegate its call to the underlying Aggregation module to reduce messages as specified in __init__() by the aggr argument.
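If aggr=None is passed to __init__(), a subclass can instead take over the reduction itself, e.g., via a scatter-based summation (a sketch using torch_geometric.utils.scatter):

    from typing import Optional
    from torch import Tensor
    from torch_geometric.utils import scatter

    def aggregate(self, inputs: Tensor, index: Tensor,
                  ptr: Optional[Tensor] = None,
                  dim_size: Optional[int] = None) -> Tensor:
        # inputs holds one message per edge; index maps each message
        # to its target node. Reduce all messages node-wise.
        return scatter(inputs, index, dim=self.node_dim,
                       dim_size=dim_size, reduce='sum')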

abstract message_and_aggregate(edge_index: Union[Tensor, SparseTensor]) Tensor[source]

Fuses computations of message() and aggregate() into a single function. If applicable, this saves both time and memory since messages do not need to be explicitly materialized. This function will only get called if it is implemented and propagation takes place based on a torch_sparse.SparseTensor or a torch.sparse.Tensor.
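A typical fused implementation performs a single sparse-dense matrix multiplication on the transposed adjacency matrix instead of materializing per-edge messages (a sketch, assuming torch_geometric.utils.spmm; the additional x argument is taken from the data passed to propagate()):

    from torch import Tensor
    from torch_sparse import SparseTensor
    from torch_geometric.utils import spmm

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        # One sparse-dense matmul replaces the gather in message()
        # followed by the scatter-based reduction in aggregate().
        return spmm(adj_t, x, reduce=self.aggr)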

update(inputs: Tensor) Tensor[source]

Updates node embeddings in analogy to \(\gamma_{\mathbf{\Theta}}\) for each node \(i \in \mathcal{V}\). Takes the output of aggregation as its first argument, together with any argument which was initially passed to propagate().
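For example, an update step that combines the aggregated messages with the original (root) node features as a residual could look as follows (a sketch, assuming x was passed to propagate()):

    from torch import Tensor

    def update(self, inputs: Tensor, x: Tensor) -> Tensor:
        # inputs holds the aggregated messages, shape [num_nodes, channels].
        return inputs + x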

edge_updater(edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None, **kwargs: Any) Tensor[source]

The initial call to compute or update features for each edge in the graph.

Parameters:
  • edge_index (torch.Tensor or SparseTensor) – A torch.Tensor, a torch_sparse.SparseTensor or a torch.sparse.Tensor that defines the underlying graph connectivity/message passing flow. See propagate() for more information.

  • size ((int, int), optional) – The size (N, M) of the assignment matrix in case edge_index is a torch.Tensor. If set to None, the size will be automatically inferred and assumed to be quadratic. This argument is ignored in case edge_index is a torch_sparse.SparseTensor or a torch.sparse.Tensor. (default: None)

  • **kwargs – Any additional data which is needed to compute or update features for each edge in the graph.

abstract edge_update() Tensor[source]

Computes or updates features for each edge in the graph. This function can take any argument as input which was initially passed to edge_updater(). Furthermore, tensors passed to edge_updater() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g., x_i and x_j.
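For example, a GAT-style layer could compute unnormalized attention logits per edge inside edge_update() (a sketch; the per-node scores alpha are illustrative and assumed to be passed to edge_updater() as alpha=(alpha_src, alpha_dst)):

    import torch.nn.functional as F
    from torch import Tensor

    def edge_update(self, alpha_j: Tensor, alpha_i: Tensor) -> Tensor:
        # alpha_j/alpha_i are the per-node scores mapped to the source
        # and target node of each edge via the "_j"/"_i" suffixes.
        return F.leaky_relu(alpha_j + alpha_i, negative_slope=0.2)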

register_propagate_forward_pre_hook(hook: Callable) RemovableHandle[source]

Registers a forward pre-hook on the module.

The hook will be called every time before propagate() is invoked. It should have the following signature:

hook(module, inputs) -> None or modified input

The hook can modify the input. Input keyword arguments are passed to the hook as a dictionary in inputs[-1].

Returns a torch.utils.hooks.RemovableHandle that can be used to remove the added hook by calling handle.remove().

register_propagate_forward_hook(hook: Callable) RemovableHandle[source]

Registers a forward hook on the module.

The hook will be called every time after propagate() has computed an output. It should have the following signature:

hook(module, inputs, output) -> None or modified output

The hook can modify the output. Input keyword arguments are passed to the hook as a dictionary in inputs[-1].

Returns a torch.utils.hooks.RemovableHandle that can be used to remove the added hook by calling handle.remove().
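For instance, a pre-hook that inspects the keyword arguments passed to propagate(), together with a post-hook on its output (both hooks are illustrative, reusing the hypothetical conv instance from above):

    def pre_hook(module, inputs):
        # Keyword arguments arrive as a dictionary in inputs[-1].
        print('propagate() called with:', list(inputs[-1].keys()))
        return None  # Returning None leaves the inputs unchanged.

    def post_hook(module, inputs, output):
        print('propagate() produced an output of shape', output.shape)
        return None  # Returning None leaves the output unchanged.

    pre_handle = conv.register_propagate_forward_pre_hook(pre_hook)
    post_handle = conv.register_propagate_forward_hook(post_hook)
    out = conv(x, edge_index)  # Both hooks fire during propagation.
    pre_handle.remove()
    post_handle.remove()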

register_message_forward_pre_hook(hook: Callable) RemovableHandle[source]

Registers a forward pre-hook on the module. The hook will be called every time before message() is invoked. See register_propagate_forward_pre_hook() for more information.

register_message_forward_hook(hook: Callable) RemovableHandle[source]

Registers a forward hook on the module. The hook will be called every time after message() has computed an output. See register_propagate_forward_hook() for more information.

register_aggregate_forward_pre_hook(hook: Callable) RemovableHandle[source]

Registers a forward pre-hook on the module. The hook will be called every time before aggregate() is invoked. See register_propagate_forward_pre_hook() for more information.

register_aggregate_forward_hook(hook: Callable) RemovableHandle[source]

Registers a forward hook on the module. The hook will be called every time after aggregate() has computed an output. See register_propagate_forward_hook() for more information.

register_message_and_aggregate_forward_pre_hook(hook: Callable) RemovableHandle[source]

Registers a forward pre-hook on the module. The hook will be called every time before message_and_aggregate() is invoked. See register_propagate_forward_pre_hook() for more information.

register_message_and_aggregate_forward_hook(hook: Callable) RemovableHandle[source]

Registers a forward hook on the module. The hook will be called every time after message_and_aggregate() has computed an output. See register_propagate_forward_hook() for more information.

register_edge_update_forward_pre_hook(hook: Callable) RemovableHandle[source]

Registers a forward pre-hook on the module. The hook will be called every time before edge_update() is invoked. See register_propagate_forward_pre_hook() for more information.

register_edge_update_forward_hook(hook: Callable) RemovableHandle[source]

Registers a forward hook on the module. The hook will be called every time after edge_update() has computed an output. See register_propagate_forward_hook() for more information.

jittable(typing: Optional[str] = None) MessagePassing[source]

Analyzes the MessagePassing instance and produces a new jittable module that can be used in combination with torch.jit.script().

Note

jittable() is deprecated and a no-op from torch_geometric 2.5 onwards.