# Source code for torch_geometric.nn.conv.gps_conv

import inspect
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, Linear, Sequential

from torch_geometric.nn.attention import PerformerAttention
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.typing import Adj
from torch_geometric.utils import to_dense_batch


class GPSConv(torch.nn.Module):
    r"""The general, powerful, scalable (GPS) graph transformer layer from the
    `"Recipe for a General, Powerful, Scalable Graph Transformer"
    <https://arxiv.org/abs/2205.12454>`_ paper.

    The GPS layer is based on a 3-part recipe:

    1. Inclusion of positional (PE) and structural encodings (SE) to the input
       features (done in a pre-processing step via
       :class:`torch_geometric.transforms`).
    2. A local message passing layer (MPNN) that operates on the input graph.
    3. A global attention layer that operates on the entire graph.

    .. note::

        For an example of using :class:`GPSConv`, see
        `examples/graph_gps.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        graph_gps.py>`_.

    Args:
        channels (int): Size of each input sample.
        conv (MessagePassing, optional): The local message passing layer.
            If set to :obj:`None`, only the global attention branch is used.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        dropout (float, optional): Dropout probability of intermediate
            embeddings. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use inside the feed-forward MLP. (default: :obj:`"relu"`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`"batch_norm"`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
        attn_type (str): Global attention type, :obj:`multihead` or
            :obj:`performer`. (default: :obj:`multihead`)
        attn_kwargs (Dict[str, Any], optional): Arguments passed to the
            attention layer. (default: :obj:`None`)

    Raises:
        ValueError: If :obj:`attn_type` is neither :obj:`multihead` nor
            :obj:`performer`.
    """
    def __init__(
        self,
        channels: int,
        conv: Optional[MessagePassing],
        heads: int = 1,
        dropout: float = 0.0,
        act: str = 'relu',
        act_kwargs: Optional[Dict[str, Any]] = None,
        norm: Optional[str] = 'batch_norm',
        norm_kwargs: Optional[Dict[str, Any]] = None,
        attn_type: str = 'multihead',
        attn_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()

        self.channels = channels
        self.conv = conv
        self.heads = heads
        self.dropout = dropout
        self.attn_type = attn_type

        attn_kwargs = attn_kwargs or {}
        if attn_type == 'multihead':
            self.attn = torch.nn.MultiheadAttention(
                channels,
                heads,
                batch_first=True,
                **attn_kwargs,
            )
        elif attn_type == 'performer':
            self.attn = PerformerAttention(
                channels=channels,
                heads=heads,
                **attn_kwargs,
            )
        else:
            # TODO: Support BigBird
            raise ValueError(f'{attn_type} is not supported')

        # Position-wise feed-forward block (expand to 2x width, then project
        # back), applied after local and global outputs are summed.
        self.mlp = Sequential(
            Linear(channels, channels * 2),
            activation_resolver(act, **(act_kwargs or {})),
            Dropout(dropout),
            Linear(channels * 2, channels),
            Dropout(dropout),
        )

        norm_kwargs = norm_kwargs or {}
        self.norm1 = normalization_resolver(norm, channels, **norm_kwargs)
        self.norm2 = normalization_resolver(norm, channels, **norm_kwargs)
        self.norm3 = normalization_resolver(norm, channels, **norm_kwargs)

        # All three norms are built with identical resolver arguments, so
        # inspecting ``norm1`` decides for all of them whether the ``batch``
        # vector must be forwarded (e.g. graph-level normalization layers).
        self.norm_with_batch = False
        if self.norm1 is not None:
            signature = inspect.signature(self.norm1.forward)
            self.norm_with_batch = 'batch' in signature.parameters

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        if self.conv is not None:
            self.conv.reset_parameters()
        # Both supported attention modules expose ``_reset_parameters``.
        self.attn._reset_parameters()
        reset(self.mlp)
        for norm in (self.norm1, self.norm2, self.norm3):
            if norm is not None:
                norm.reset_parameters()

    def _apply_norm(
        self,
        norm: Optional[torch.nn.Module],
        x: Tensor,
        batch: Optional[Tensor],
    ) -> Tensor:
        r"""Applies :obj:`norm` to :obj:`x`, passing :obj:`batch` only when
        the normalization layer accepts it. Returns :obj:`x` unchanged if
        :obj:`norm` is :obj:`None`."""
        if norm is None:
            return x
        if self.norm_with_batch:
            return norm(x, batch=batch)
        return norm(x)

    def forward(
        self,
        x: Tensor,
        edge_index: Adj,
        batch: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor): The node features. Both the local and the global
                branch consume the same input.
            edge_index (torch.Tensor or SparseTensor): The graph connectivity,
                passed to the local MPNN only.
            batch (torch.Tensor, optional): The batch vector assigning each
                node to a graph. It is used to build the dense attention batch
                and is forwarded to the normalization layers if they accept
                it. (default: :obj:`None`)
            **kwargs (optional): Additional arguments forwarded to
                :obj:`conv`.
        """
        hs = []
        if self.conv is not None:  # Local MPNN.
            h = self.conv(x, edge_index, **kwargs)
            h = F.dropout(h, p=self.dropout, training=self.training)
            h = h + x  # Residual connection.
            h = self._apply_norm(self.norm1, h, batch)
            hs.append(h)

        # Global attention transformer-style model. ``to_dense_batch`` pads
        # nodes into a per-graph dense tensor and returns a mask of real
        # entries; padded positions are excluded from attention so nodes only
        # attend within their own graph. ``h[mask]`` maps back to node-level.
        h, mask = to_dense_batch(x, batch)

        if isinstance(self.attn, torch.nn.MultiheadAttention):
            h, _ = self.attn(h, h, h, key_padding_mask=~mask,
                             need_weights=False)
        elif isinstance(self.attn, PerformerAttention):
            h = self.attn(h, mask=mask)

        h = h[mask]
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = h + x  # Residual connection.
        h = self._apply_norm(self.norm2, h, batch)
        hs.append(h)

        out = sum(hs)  # Combine local and global outputs.

        out = out + self.mlp(out)
        out = self._apply_norm(self.norm3, out, batch)

        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.channels}, '
                f'conv={self.conv}, heads={self.heads}, '
                f'attn_type={self.attn_type})')