torch_geometric.nn.conv.GPSConv
- class GPSConv(channels: int, conv: Optional[MessagePassing], heads: int = 1, dropout: float = 0.0, act: str = 'relu', act_kwargs: Optional[Dict[str, Any]] = None, norm: Optional[str] = 'batch_norm', norm_kwargs: Optional[Dict[str, Any]] = None, attn_type: str = 'multihead', attn_kwargs: Optional[Dict[str, Any]] = None)
Bases: Module
The general, powerful, scalable (GPS) graph transformer layer from the “Recipe for a General, Powerful, Scalable Graph Transformer” paper.
The GPS layer is based on a 3-part recipe:
1. Inclusion of positional (PE) and structural (SE) encodings into the input features, done in a pre-processing step via torch_geometric.transforms (see the sketch after this list).
2. A local message passing layer (MPNN) that operates on the input graph.
3. A global attention layer that operates on the entire graph.
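A minimal sketch of the pre-processing step from part 1, assuming a random-walk structural encoding attached with torch_geometric.transforms.AddRandomWalkPE; the toy graph, walk length, and attribute name are illustrative assumptions rather than requirements of this layer:

```python
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import AddRandomWalkPE

# Toy graph: 4 nodes with 16-dimensional features, undirected chain 0-1-2-3.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
data = Data(x=torch.randn(4, 16), edge_index=edge_index)

# Part 1 of the recipe: attach a random-walk structural encoding (RWSE).
data = AddRandomWalkPE(walk_length=8, attr_name='pe')(data)

# Concatenate the encoding to the node features before applying GPS layers.
x = torch.cat([data.x, data.pe], dim=-1)  # shape: [4, 16 + 8]
```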
Note
For an example of using GPSConv, see examples/graph_gps.py.
- Parameters:
channels (int) – Size of each input sample.
conv (MessagePassing, optional) – The local message passing layer.
heads (int, optional) – Number of attention heads. (default: 1)
dropout (float, optional) – Dropout probability of intermediate embeddings. (default: 0.0)
act (str or Callable, optional) – The non-linear activation function to use. (default: "relu")
act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)
norm (str or Callable, optional) – The normalization function to use. (default: "batch_norm")
norm_kwargs (Dict[str, Any], optional) – Arguments passed to the respective normalization function defined by norm. (default: None)
attn_type (str) – Global attention type, "multihead" or "performer". (default: "multihead")
attn_kwargs (Dict[str, Any], optional) – Arguments passed to the attention layer. (default: None)
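A minimal usage sketch, assuming a GINEConv as the local message passing layer (as in examples/graph_gps.py); the tensor shapes and hyper-parameters below are illustrative only:

```python
import torch
from torch_geometric.nn import GINEConv, GPSConv

channels = 64

# Local MPNN: a GINEConv parameterized by a small MLP.
mlp = torch.nn.Sequential(
    torch.nn.Linear(channels, channels),
    torch.nn.ReLU(),
    torch.nn.Linear(channels, channels),
)

# GPS layer: local GINEConv + global multi-head attention.
conv = GPSConv(channels, conv=GINEConv(mlp), heads=4, dropout=0.1,
               attn_type='multihead')

x = torch.randn(10, channels)                # node features (already PE/SE-enriched)
edge_index = torch.randint(0, 10, (2, 40))   # random connectivity with 40 edges
edge_attr = torch.randn(40, channels)        # forwarded to the inner GINEConv

out = conv(x, edge_index, edge_attr=edge_attr)  # shape: [10, channels]
```

Extra keyword arguments such as edge_attr are passed through to the inner message passing layer, and an optional batch vector can be supplied so that the global attention only attends over nodes belonging to the same graph.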