torch_geometric.nn.conv.GPSConv

class GPSConv(channels: int, conv: Optional[MessagePassing], heads: int = 1, dropout: float = 0.0, act: str = 'relu', act_kwargs: Optional[Dict[str, Any]] = None, norm: Optional[str] = 'batch_norm', norm_kwargs: Optional[Dict[str, Any]] = None, attn_type: str = 'multihead', attn_kwargs: Optional[Dict[str, Any]] = None)

Bases: Module

The general, powerful, scalable (GPS) graph transformer layer from the “Recipe for a General, Powerful, Scalable Graph Transformer” paper.

The GPS layer is based on a 3-part recipe:

  1. Addition of positional encodings (PE) and structural encodings (SE) to the input features (done in a pre-processing step via torch_geometric.transforms).

  2. A local message passing layer (MPNN) that operates on the input graph.

  3. A global attention layer that operates on the entire graph.

Note

For an example of using GPSConv, see examples/graph_gps.py.

Parameters:
  • channels (int) – Size of each input sample. The layer's output has the same size.

  • conv (MessagePassing, optional) – The local message passing layer. If set to None, no local message passing is applied and only the global attention branch is used.

  • heads (int, optional) – Number of attention heads in the global attention layer. For attn_type="multihead", channels must be divisible by heads. (default: 1)

  • dropout (float, optional) – Dropout probability of intermediate embeddings. (default: 0.)

  • act (str or Callable, optional) – The non-linear activation function to use. (default: "relu")

  • act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)

  • norm (str or Callable, optional) – The normalization function to use. (default: "batch_norm")

  • norm_kwargs (Dict[str, Any], optional) – Arguments passed to the respective normalization function defined by norm. (default: None)

  • attn_type (str, optional) – Global attention type, either "multihead" or "performer". (default: "multihead")

  • attn_kwargs (Dict[str, Any], optional) – Arguments passed to the attention layer. (default: None)

forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], batch: Optional[Tensor] = None, **kwargs) → Tensor

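Runs the forward pass of the module. batch assigns each node to its graph so that global attention stays within each graph; if batch is None, all nodes are treated as a single graph. Any additional **kwargs are passed on to conv. The output has shape [num_nodes, channels].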

reset_parameters()

Resets all learnable parameters of the module.