torch_geometric.nn.conv.MessagePassing
- class MessagePassing(aggr: Optional[Union[str, List[str], Aggregation]] = 'add', *, aggr_kwargs: Optional[Dict[str, Any]] = None, flow: str = 'source_to_target', node_dim: int = -2, decomposed_layers: int = 1, **kwargs)
Bases: torch.nn.Module
Base class for creating message passing layers.
Message passing layers follow the form
\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \bigoplus_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),\]
where \(\bigoplus\) denotes a differentiable, permutation-invariant function, e.g., sum, mean, min, max or mul, \(\mathbf{e}_{j,i}\) denotes (optional) features of the edge from node \(j\) to node \(i\), and \(\gamma_{\mathbf{\Theta}}\) and \(\phi_{\mathbf{\Theta}}\) denote differentiable functions such as MLPs. An accompanying tutorial is available in the PyG documentation.
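To make the formula concrete, here is a plain-PyTorch sketch (no PyG required) of a single message passing step with \(\bigoplus\) = sum. The helper name `message_passing` and the example choices of \(\phi\) and \(\gamma\) are hypothetical, and the edge features \(\mathbf{e}_{j,i}\) are omitted for brevity; it illustrates the math, not how MessagePassing is implemented internally.

```python
import torch


def message_passing(x, edge_index, phi, gamma):
    """Hypothetical helper: x_i' = gamma(x_i, sum_{j in N(i)} phi(x_i, x_j)).

    x: [num_nodes, F] node features.
    edge_index: [2, num_edges]; row 0 holds source nodes j, row 1 target nodes i.
    """
    src, dst = edge_index
    # One message per edge, computed from target features x_i and source features x_j.
    msgs = phi(x[dst], x[src])
    # Sum-aggregate messages at their target nodes (nodes without neighbors get zeros).
    agg = msgs.new_zeros(x.size(0), msgs.size(-1)).index_add_(0, dst, msgs)
    return gamma(x, agg)


x = torch.tensor([[1.0], [2.0], [4.0]])
edge_index = torch.tensor([[0, 2, 1], [1, 1, 2]])  # edges 0->1, 2->1, 1->2
out = message_passing(
    x,
    edge_index,
    phi=lambda x_i, x_j: x_j - x_i,  # example phi: feature difference
    gamma=lambda x, agg: x + agg,    # example gamma: residual update
)
print(out.squeeze(-1).tolist())  # [1.0, 3.0, 2.0]
```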
- Parameters:
aggr (str or [str] or Aggregation, optional) – The aggregation scheme to use, e.g., "add", "sum", "mean", "min", "max" or "mul". In addition, can be any Aggregation module (or any string that automatically resolves to it). If given as a list, multiple aggregations are applied and their outputs are concatenated along the last dimension. If set to None, the MessagePassing subclass is expected to implement its own aggregation logic via aggregate(). (default: "add")
aggr_kwargs (Dict[str, Any], optional) – Arguments passed to the respective aggregation function in case it gets automatically resolved. (default: None)
flow (str, optional) – The flow direction of message passing ("source_to_target" or "target_to_source"). (default: "source_to_target")
node_dim (int, optional) – The axis along which to propagate, i.e., the dimension of the input tensors that indexes nodes. (default: -2)
decomposed_layers (int, optional) – The number of feature decomposition layers, as introduced in the “Optimizing Memory Efficiency of Graph Neural Networks on Edge Computing Platforms” paper. Feature decomposition reduces peak memory usage by slicing the feature dimensions into separate feature decomposition layers during GNN aggregation. This method can accelerate GNN execution on CPU-based platforms (e.g., a 2-3x speedup on the Reddit dataset) for common GNN models such as GCN, GraphSAGE, GIN, etc. However, it is not applicable to all available GNN operators, in particular to operators whose message computation cannot easily be decomposed, e.g., attention-based GNNs. The optimal value of decomposed_layers depends on both the graph dataset and the available hardware; a value of 2 is suitable in most cases. Although peak memory usage is directly tied to the granularity of feature decomposition, the same does not necessarily hold for execution speedups. (default: 1)
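The list form of aggr can be pictured with the plain-PyTorch sketch below: each aggregation reduces the same messages independently and the results are concatenated along the last dimension, so F-dimensional messages combined with k aggregations yield k·F output features. The helper names and the mapping onto `torch.Tensor.scatter_reduce` are assumptions made for illustration; PyG itself resolves these strings to Aggregation modules.

```python
import torch

# Assumed mapping from the aggregation names above to scatter_reduce() reductions.
REDUCE = {"add": "sum", "sum": "sum", "mean": "mean", "min": "amin", "max": "amax", "mul": "prod"}


def aggregate(msgs, index, num_nodes, aggr):
    """Reduce messages [E, F] into [num_nodes, F] according to their target index [E]."""
    idx = index.unsqueeze(-1).expand_as(msgs)
    out = msgs.new_zeros(num_nodes, msgs.size(-1))
    # include_self=False: the zero fill does not take part in the reduction,
    # so nodes that receive no messages simply keep the value 0.
    return out.scatter_reduce(0, idx, msgs, reduce=REDUCE[aggr], include_self=False)


def multi_aggregate(msgs, index, num_nodes, aggrs):
    """Apply several aggregations and concatenate their outputs in the last dimension."""
    return torch.cat([aggregate(msgs, index, num_nodes, a) for a in aggrs], dim=-1)


msgs = torch.tensor([[1.0], [4.0], [2.0]])  # one 1-dimensional message per edge
index = torch.tensor([1, 1, 2])             # target node of each message
out = multi_aggregate(msgs, index, num_nodes=3, aggrs=["sum", "mean", "max"])
print(out.tolist())  # [[0.0, 0.0, 0.0], [5.0, 2.5, 4.0], [2.0, 2.0, 2.0]]
```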
- propagate(edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None, **kwargs)
The initial call to start propagating messages.
- Parameters:
edge_index (torch.Tensor or SparseTensor) – A torch.Tensor, a torch_sparse.SparseTensor or a torch.sparse.Tensor that defines the underlying graph connectivity/message passing flow. edge_index holds the indices of a general (sparse) assignment matrix of shape [N, M]. If edge_index is a torch.Tensor, its dtype should be torch.long and its shape must be [2, num_messages], where messages from nodes in edge_index[0] are sent to nodes in edge_index[1] (in case flow="source_to_target"). If edge_index is a torch_sparse.SparseTensor or a torch.sparse.Tensor, its sparse indices (row, col) should relate to row = edge_index[1] and col = edge_index[0]. The major difference between the two formats is that the transposed sparse adjacency matrix must be passed to propagate().
size ((int, int), optional) – The size (N, M) of the assignment matrix in case edge_index is a torch.Tensor. If set to None, the size will be automatically inferred and assumed to be square. This argument is ignored in case edge_index is a torch_sparse.SparseTensor or a torch.sparse.Tensor. (default: None)
**kwargs – Any additional data which is needed to construct and aggregate messages, and to update node embeddings.
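The edge_index and size conventions can be checked with a plain-PyTorch sketch of sum-based propagation over a bipartite assignment matrix of shape [N, M], assuming flow="source_to_target": source indices address the N rows of x, target indices the M rows of the output. The helper `propagate_sum` is hypothetical and only mirrors the documented convention, not PyG's implementation.

```python
import torch


def propagate_sum(x, edge_index, size):
    """Hypothetical helper: sum messages x_j sent from edge_index[0] to edge_index[1].

    x: [N, F] source-node features; edge_index: [2, num_messages] of dtype torch.long;
    size: (N, M), so sources lie in [0, N) and targets in [0, M).
    """
    N, M = size
    src, dst = edge_index
    if int(src.max()) >= N or int(dst.max()) >= M:
        raise ValueError("edge_index does not fit the given size (N, M)")
    msgs = x[src]  # one message x_j per edge: [num_messages, F]
    return msgs.new_zeros(M, x.size(-1)).index_add_(0, dst, msgs)


x = torch.tensor([[1.0], [2.0], [3.0]])         # N = 3 source nodes
edge_index = torch.tensor([[0, 1, 2, 2],         # sources j
                           [0, 0, 1, 0]])        # targets i
out = propagate_sum(x, edge_index, size=(3, 2))  # M = 2 target nodes
print(out.squeeze(-1).tolist())  # [6.0, 3.0]
```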
- edge_updater(edge_index: Union[Tensor, SparseTensor], **kwargs)
The initial call to compute or update features for each edge in the graph.
- Parameters:
edge_index (torch.Tensor or SparseTensor) – A torch.Tensor, a torch_sparse.SparseTensor or a torch.sparse.Tensor that defines the underlying graph connectivity/message passing flow. See propagate() for more information.
**kwargs – Any additional data which is needed to compute or update features for each edge in the graph.
- message(x_j: Tensor) -> Tensor
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g., x_i and x_j.
- aggregate(inputs: Tensor, index: Tensor, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None) -> Tensor
Aggregates messages from neighbors as \(\bigoplus_{j \in \mathcal{N}(i)}\).
Takes in the output of message computation as first argument and any argument which was initially passed to propagate().
By default, this function will delegate its call to the underlying Aggregation module to reduce messages as specified in __init__() by the aggr argument.
- message_and_aggregate(adj_t: Union[SparseTensor, Tensor]) -> Tensor
Fuses computations of message() and aggregate() into a single function. If applicable, this saves both time and memory since messages do not explicitly need to be materialized. This function will only be called in case it is implemented and propagation takes place based on a torch_sparse.SparseTensor or a torch.sparse.Tensor.
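For sum aggregation, fusion boils down to a sparse-dense matrix product with the transposed adjacency matrix. The plain-PyTorch sketch below is not PyG's own message_and_aggregate() code; it compares the explicit message-then-aggregate route with a single `torch.sparse.mm` call, where the second route never allocates the [num_edges, F] message tensor.

```python
import torch

x = torch.tensor([[1.0], [2.0], [4.0]])
edge_index = torch.tensor([[0, 2, 1], [1, 1, 2]])  # edges 0->1, 2->1, 1->2
num_nodes = x.size(0)

# Unfused: materialize one message x_j per edge ([num_edges, F]), then sum at the targets.
messages = x[edge_index[0]]
unfused = torch.zeros_like(x).index_add_(0, edge_index[1], messages)

# Fused: build the transposed adjacency (row = edge_index[1], col = edge_index[0])
# and obtain the same sums with one sparse-dense product, without a per-edge buffer.
adj_t = torch.sparse_coo_tensor(
    edge_index.flip(0), torch.ones(edge_index.size(1)), (num_nodes, num_nodes)
)
fused = torch.sparse.mm(adj_t, x)

print(unfused.squeeze(-1).tolist(), fused.squeeze(-1).tolist())  # [0.0, 5.0, 2.0] [0.0, 5.0, 2.0]
```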
- update(inputs: Tensor) -> Tensor
Updates node embeddings in analogy to \(\gamma_{\mathbf{\Theta}}\) for each node \(i \in \mathcal{V}\). Takes in the output of aggregation as first argument and any argument which was initially passed to propagate().
- edge_update() -> Tensor
Computes or updates features for each edge in the graph. This function can take any argument as input which was initially passed to edge_updater(). Furthermore, tensors passed to edge_updater() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g., x_i and x_j.
- register_propagate_forward_pre_hook(hook: Callable) -> RemovableHandle
Registers a forward pre-hook on the module.
The hook will be called every time before propagate() is invoked. It should have the following signature:
hook(module, inputs) -> None or modified input
The hook can modify the input. Input keyword arguments are passed to the hook as a dictionary in inputs[-1].
Returns a torch.utils.hooks.RemovableHandle that can be used to remove the added hook by calling handle.remove().
- register_propagate_forward_hook(hook: Callable) -> RemovableHandle
Registers a forward hook on the module.
The hook will be called every time after propagate() has computed an output. It should have the following signature:
hook(module, inputs, output) -> None or modified output
The hook can modify the output. Input keyword arguments are passed to the hook as a dictionary in inputs[-1].
Returns a torch.utils.hooks.RemovableHandle that can be used to remove the added hook by calling handle.remove().
- register_message_forward_pre_hook(hook: Callable) -> RemovableHandle
Registers a forward pre-hook on the module. The hook will be called every time before message() is invoked. See register_propagate_forward_pre_hook() for more information.
- register_message_forward_hook(hook: Callable) -> RemovableHandle
Registers a forward hook on the module. The hook will be called every time after message() has computed an output. See register_propagate_forward_hook() for more information.
- register_aggregate_forward_pre_hook(hook: Callable) -> RemovableHandle
Registers a forward pre-hook on the module. The hook will be called every time before aggregate() is invoked. See register_propagate_forward_pre_hook() for more information.
- register_aggregate_forward_hook(hook: Callable) -> RemovableHandle
Registers a forward hook on the module. The hook will be called every time after aggregate() has computed an output. See register_propagate_forward_hook() for more information.
- register_message_and_aggregate_forward_pre_hook(hook: Callable) -> RemovableHandle
Registers a forward pre-hook on the module. The hook will be called every time before message_and_aggregate() is invoked. See register_propagate_forward_pre_hook() for more information.
- register_message_and_aggregate_forward_hook(hook: Callable) -> RemovableHandle
Registers a forward hook on the module. The hook will be called every time after message_and_aggregate() has computed an output. See register_propagate_forward_hook() for more information.
- register_edge_update_forward_pre_hook(hook: Callable) -> RemovableHandle
Registers a forward pre-hook on the module. The hook will be called every time before edge_update() is invoked. See register_propagate_forward_pre_hook() for more information.
- register_edge_update_forward_hook(hook: Callable) -> RemovableHandle
Registers a forward hook on the module. The hook will be called every time after edge_update() has computed an output. See register_propagate_forward_hook() for more information.
- jittable(typing: Optional[str] = None) -> MessagePassing
Analyzes the MessagePassing instance and produces a new jittable module that can be used in combination with torch.jit.script().