# Source code for torch_geometric.nn.conv.rgat_conv

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter, ReLU

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import glorot, ones, zeros
from torch_geometric.typing import Adj, OptTensor, Size, SparseTensor
from torch_geometric.utils import is_torch_sparse_tensor, scatter, softmax
from torch_geometric.utils.sparse import set_sparse_value


class RGATConv(MessagePassing):
    r"""The relational graph attentional operator from the `"Relational Graph
    Attention Networks" <https://arxiv.org/abs/1904.05811>`_ paper.

    Here, attention logits :math:`\mathbf{a}^{(r)}_{i,j}` are computed for each
    relation type :math:`r` with the help of both query and key kernels,
    *i.e.*

    .. math::
        \mathbf{q}^{(r)}_i = \mathbf{W}_1^{(r)}\mathbf{x}_{i} \cdot
        \mathbf{Q}^{(r)}
        \quad \textrm{and} \quad
        \mathbf{k}^{(r)}_i = \mathbf{W}_1^{(r)}\mathbf{x}_{i} \cdot
        \mathbf{K}^{(r)}.

    Two schemes have been proposed to compute attention logits
    :math:`\mathbf{a}^{(r)}_{i,j}` for each relation type :math:`r`:

    **Additive attention**

    .. math::
        \mathbf{a}^{(r)}_{i,j} = \mathrm{LeakyReLU}(\mathbf{q}^{(r)}_i +
        \mathbf{k}^{(r)}_j)

    or **multiplicative attention**

    .. math::
        \mathbf{a}^{(r)}_{i,j} = \mathbf{q}^{(r)}_i \cdot \mathbf{k}^{(r)}_j.

    If the graph has multi-dimensional edge features
    :math:`\mathbf{e}^{(r)}_{i,j}`, the attention logits
    :math:`\mathbf{a}^{(r)}_{i,j}` for each relation type :math:`r` are
    computed as

    .. math::
        \mathbf{a}^{(r)}_{i,j} = \mathrm{LeakyReLU}(\mathbf{q}^{(r)}_i +
        \mathbf{k}^{(r)}_j + \mathbf{W}_2^{(r)}\mathbf{e}^{(r)}_{i,j})

    or

    .. math::
        \mathbf{a}^{(r)}_{i,j} = \mathbf{q}^{(r)}_i \cdot \mathbf{k}^{(r)}_j
        \cdot \mathbf{W}_2^{(r)} \mathbf{e}^{(r)}_{i,j},

    respectively.
    The attention coefficients :math:`\alpha^{(r)}_{i,j}` for each relation
    type :math:`r` are then obtained via two different attention mechanisms:
    The **within-relation** attention mechanism

    .. math::
        \alpha^{(r)}_{i,j} =
        \frac{\exp(\mathbf{a}^{(r)}_{i,j})}
        {\sum_{k \in \mathcal{N}_r(i)} \exp(\mathbf{a}^{(r)}_{i,k})}

    or the **across-relation** attention mechanism

    .. math::
        \alpha^{(r)}_{i,j} =
        \frac{\exp(\mathbf{a}^{(r)}_{i,j})}
        {\sum_{r^{\prime} \in \mathcal{R}}
        \sum_{k \in \mathcal{N}_{r^{\prime}}(i)}
        \exp(\mathbf{a}^{(r^{\prime})}_{i,k})}

    where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types.
    Edge type needs to be a one-dimensional :obj:`torch.long` tensor which
    stores a relation identifier :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}`
    for each edge.

    To enhance the discriminative power of attention-based GNNs, this layer
    further implements four different cardinality preservation options as
    proposed in the `"Improving Attention Mechanism in Graph Neural Networks
    via Cardinality Preservation" <https://arxiv.org/abs/1907.02204>`_ paper:

    .. math::
        \text{additive:}~~~\mathbf{x}^{{\prime}(r)}_i &=
        \sum_{j \in \mathcal{N}_r(i)}
        \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j + \mathcal{W} \odot
        \sum_{j \in \mathcal{N}_r(i)} \mathbf{x}^{(r)}_{j}

        \text{scaled:}~~~\mathbf{x}^{{\prime}(r)}_i &=
        \psi(|\mathcal{N}_r(i)|) \odot
        \sum_{j \in \mathcal{N}_r(i)} \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j

        \text{f-additive:}~~~\mathbf{x}^{{\prime}(r)}_i &=
        \sum_{j \in \mathcal{N}_r(i)}
        (\alpha^{(r)}_{i,j} + 1) \cdot \mathbf{x}^{(r)}_j

        \text{f-scaled:}~~~\mathbf{x}^{{\prime}(r)}_i &=
        |\mathcal{N}_r(i)| \odot \sum_{j \in \mathcal{N}_r(i)}
        \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j

    * If :obj:`attention_mode="additive-self-attention"` and
      :obj:`concat=True`, the layer outputs :obj:`heads * out_channels`
      features for each node.

    * If :obj:`attention_mode="multiplicative-self-attention"` and
      :obj:`concat=True`, the layer outputs :obj:`heads * dim * out_channels`
      features for each node.

    * If :obj:`attention_mode="additive-self-attention"` and
      :obj:`concat=False`, the layer outputs :obj:`out_channels` features for
      each node.

    * If :obj:`attention_mode="multiplicative-self-attention"` and
      :obj:`concat=False`, the layer outputs :obj:`dim * out_channels`
      features for each node.

    Please make sure to set the :obj:`in_channels` argument of the next
    layer accordingly if more than one instance of this layer is used.

    .. note::

        For an example of using :class:`RGATConv`, see
        `examples/rgat.py <https://github.com/pyg-team/pytorch_geometric/blob
        /master/examples/rgat.py>`_.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        num_relations (int): Number of relations.
        num_bases (int, optional): If set, this layer will use the
            basis-decomposition regularization scheme where :obj:`num_bases`
            denotes the number of bases to use. (default: :obj:`None`)
        num_blocks (int, optional): If set, this layer will use the
            block-diagonal-decomposition regularization scheme where
            :obj:`num_blocks` denotes the number of blocks to use.
            (default: :obj:`None`)
        mod (str, optional): The cardinality preservation option to use.
            (:obj:`"additive"`, :obj:`"scaled"`, :obj:`"f-additive"`,
            :obj:`"f-scaled"`, :obj:`None`). (default: :obj:`None`)
        attention_mechanism (str, optional): The attention mechanism to use
            (:obj:`"within-relation"`, :obj:`"across-relation"`).
            (default: :obj:`"across-relation"`)
        attention_mode (str, optional): The mode to calculate attention logits.
            (:obj:`"additive-self-attention"`,
            :obj:`"multiplicative-self-attention"`).
            (default: :obj:`"additive-self-attention"`)
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        dim (int): Number of dimensions for query and key kernels.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        edge_dim (int, optional): Edge feature dimensionality (in case there
            are any). (default: :obj:`None`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    # Attention coefficients stashed by `message` and read (then cleared)
    # by `forward`.
    _alpha: OptTensor

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_relations: int,
        num_bases: Optional[int] = None,
        num_blocks: Optional[int] = None,
        mod: Optional[str] = None,
        attention_mechanism: str = "across-relation",
        attention_mode: str = "additive-self-attention",
        heads: int = 1,
        dim: int = 1,
        concat: bool = True,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
        edge_dim: Optional[int] = None,
        bias: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=0, **kwargs)

        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.mod = mod
        self.activation = ReLU()
        self.concat = concat
        self.attention_mode = attention_mode
        self.attention_mechanism = attention_mechanism
        self.dim = dim
        self.edge_dim = edge_dim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases
        self.num_blocks = num_blocks

        mod_types = ['additive', 'scaled', 'f-additive', 'f-scaled']

        if (self.attention_mechanism != "within-relation"
                and self.attention_mechanism != "across-relation"):
            raise ValueError('attention mechanism must either be '
                             '"within-relation" or "across-relation"')

        if (self.attention_mode != "additive-self-attention"
                and self.attention_mode != "multiplicative-self-attention"):
            raise ValueError('attention mode must either be '
                             '"additive-self-attention" or '
                             '"multiplicative-self-attention"')

        if self.attention_mode == "additive-self-attention" and self.dim > 1:
            raise ValueError('"additive-self-attention" mode cannot be '
                             'applied when value of d is greater than 1. '
                             'Use "multiplicative-self-attention" instead.')

        # Dropout is only applied on the `mod is None` path in `message`.
        if self.dropout > 0.0 and self.mod in mod_types:
            raise ValueError('mod must be None with dropout value greater '
                             'than 0 in order to sample attention '
                             'coefficients stochastically')

        if num_bases is not None and num_blocks is not None:
            raise ValueError('Can not apply both basis-decomposition and '
                             'block-diagonal-decomposition at the same time.')

        # The learnable parameters to compute both attention logits and
        # attention coefficients:
        self.q = Parameter(
            torch.empty(self.heads * self.out_channels, self.heads * self.dim))
        self.k = Parameter(
            torch.empty(self.heads * self.out_channels, self.heads * self.dim))

        if bias and concat:
            self.bias = Parameter(
                torch.empty(self.heads * self.dim * self.out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.empty(self.dim * self.out_channels))
        else:
            self.register_parameter('bias', None)

        if edge_dim is not None:
            self.lin_edge = Linear(self.edge_dim,
                                   self.heads * self.out_channels, bias=False,
                                   weight_initializer='glorot')
            self.e = Parameter(
                torch.empty(self.heads * self.out_channels,
                            self.heads * self.dim))
        else:
            self.lin_edge = None
            self.register_parameter('e', None)

        if num_bases is not None:
            self.att = Parameter(
                torch.empty(self.num_relations, self.num_bases))
            self.basis = Parameter(
                torch.empty(self.num_bases, self.in_channels,
                            self.heads * self.out_channels))
        elif num_blocks is not None:
            assert (
                self.in_channels % self.num_blocks == 0
                and (self.heads * self.out_channels) % self.num_blocks == 0), (
                    "both 'in_channels' and 'heads * out_channels' must be "
                    "multiple of 'num_blocks' used")
            self.weight = Parameter(
                torch.empty(self.num_relations, self.num_blocks,
                            self.in_channels // self.num_blocks,
                            (self.heads * self.out_channels) //
                            self.num_blocks))
        else:
            self.weight = Parameter(
                torch.empty(self.num_relations, self.in_channels,
                            self.heads * self.out_channels))

        # Parameters of the cardinality-preservation options:
        # `w` is the element-wise weight of the "additive" mod, while
        # `l1`/`b1`/`l2`/`b2` form the small MLP applied to node degrees in
        # the "scaled" mod.
        self.w = Parameter(torch.ones(self.out_channels))
        self.l1 = Parameter(torch.empty(1, self.out_channels))
        self.b1 = Parameter(torch.empty(1, self.out_channels))
        self.l2 = Parameter(torch.empty(self.out_channels, self.out_channels))
        self.b2 = Parameter(torch.empty(1, self.out_channels))

        self._alpha = None

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        super().reset_parameters()
        if self.num_bases is not None:
            glorot(self.basis)
            glorot(self.att)
        else:
            glorot(self.weight)
        glorot(self.q)
        glorot(self.k)
        zeros(self.bias)
        # `w` starts as all ones in `__init__`; restore that on reset too.
        ones(self.w)
        ones(self.l1)
        zeros(self.b1)
        # `torch.full(...)` would allocate a fresh tensor and discard it,
        # leaving `l2` with the uninitialized memory from `torch.empty`.
        # Fill the parameter in place instead.
        with torch.no_grad():
            self.l2.fill_(1 / self.out_channels)
        zeros(self.b2)
        if self.lin_edge is not None:
            glorot(self.lin_edge)
            glorot(self.e)

    def forward(
        self,
        x: Tensor,
        edge_index: Adj,
        edge_type: OptTensor = None,
        edge_attr: OptTensor = None,
        size: Size = None,
        return_attention_weights=None,
    ):
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor): The input node features.
                Can be either a :obj:`[num_nodes, in_channels]` node feature
                matrix, or an optional one-dimensional node index tensor (in
                which case input features are treated as trainable node
                embeddings).
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_type (torch.Tensor, optional): The one-dimensional relation
                type/index for each edge in :obj:`edge_index`.
                Should be only :obj:`None` in case :obj:`edge_index` is of type
                :class:`torch_sparse.SparseTensor` or
                :class:`torch.sparse.Tensor`. (default: :obj:`None`)
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)
            size ((int, int), optional): The shape of the adjacency matrix.
                (default: :obj:`None`)
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
        """
        # propagate_type: (x: Tensor, edge_type: OptTensor,
        #                  edge_attr: OptTensor)
        out = self.propagate(edge_index=edge_index, edge_type=edge_type, x=x,
                             size=size, edge_attr=edge_attr)

        # `message` stores the normalized attention coefficients here.
        alpha = self._alpha
        assert alpha is not None
        self._alpha = None

        if isinstance(return_attention_weights, bool):
            if isinstance(edge_index, Tensor):
                if is_torch_sparse_tensor(edge_index):
                    # TODO TorchScript requires to return a tuple
                    adj = set_sparse_value(edge_index, alpha)
                    return out, (adj, alpha)
                else:
                    return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        r"""Computes per-edge messages.

        Projects target (:obj:`x_i`) and source (:obj:`x_j`) features with the
        relation-specific weight selected by :obj:`edge_type`, computes the
        attention logits and normalizes them with :func:`softmax`. The
        normalized coefficients are stored in :obj:`self._alpha` before
        the optional cardinality-preservation option (:obj:`self.mod`) is
        applied.
        """
        if self.num_bases is not None:  # Basis-decomposition =================
            w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))
            w = w.view(self.num_relations, self.in_channels,
                       self.heads * self.out_channels)
        if self.num_blocks is not None:  # Block-diagonal-decomposition =======
            if (x_i.dtype == torch.long and x_j.dtype == torch.long
                    and self.num_blocks is not None):
                raise ValueError('Block-diagonal decomposition not supported '
                                 'for non-continuous input features.')

            w = self.weight
            x_i = x_i.view(-1, 1, w.size(1), w.size(2))
            x_j = x_j.view(-1, 1, w.size(1), w.size(2))
            w = torch.index_select(w, 0, edge_type)
            outi = torch.einsum('abcd,acde->ace', x_i, w)
            outi = outi.contiguous().view(-1, self.heads * self.out_channels)
            outj = torch.einsum('abcd,acde->ace', x_j, w)
            outj = outj.contiguous().view(-1, self.heads * self.out_channels)
        else:  # No regularization/Basis-decomposition ========================
            if self.num_bases is None:
                w = self.weight
            w = torch.index_select(w, 0, edge_type)
            outi = torch.bmm(x_i.unsqueeze(1), w).squeeze(-2)
            outj = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)

        qi = torch.matmul(outi, self.q)
        kj = torch.matmul(outj, self.k)

        alpha_edge, alpha = 0, torch.tensor([0])
        if edge_attr is not None:
            if edge_attr.dim() == 1:
                edge_attr = edge_attr.view(-1, 1)
            assert self.lin_edge is not None, (
                "Please set 'edge_dim = edge_attr.size(-1)' while calling the "
                "RGATConv layer")
            edge_attributes = self.lin_edge(edge_attr).view(
                -1, self.heads * self.out_channels)
            if edge_attributes.size(0) != edge_attr.size(0):
                edge_attributes = torch.index_select(edge_attributes, 0,
                                                     edge_type)
            alpha_edge = torch.matmul(edge_attributes, self.e)

        if self.attention_mode == "additive-self-attention":
            if edge_attr is not None:
                alpha = torch.add(qi, kj) + alpha_edge
            else:
                alpha = torch.add(qi, kj)
            alpha = F.leaky_relu(alpha, self.negative_slope)
        elif self.attention_mode == "multiplicative-self-attention":
            if edge_attr is not None:
                alpha = (qi * kj) * alpha_edge
            else:
                alpha = qi * kj

        if self.attention_mechanism == "within-relation":
            # Normalize separately over the edges of each relation type.
            across_out = torch.zeros_like(alpha)
            for r in range(self.num_relations):
                mask = edge_type == r
                across_out[mask] = softmax(alpha[mask], index[mask])
            alpha = across_out
        elif self.attention_mechanism == "across-relation":
            alpha = softmax(alpha, index, ptr, size_i)

        self._alpha = alpha

        if self.mod == "additive":
            if self.attention_mode == "additive-self-attention":
                ones = torch.ones_like(alpha)
                h = (outj.view(-1, self.heads, self.out_channels) *
                     ones.view(-1, self.heads, 1))
                h = torch.mul(self.w, h)

                return (outj.view(-1, self.heads, self.out_channels) *
                        alpha.view(-1, self.heads, 1) + h)
            elif self.attention_mode == "multiplicative-self-attention":
                ones = torch.ones_like(alpha)
                h = (outj.view(-1, self.heads, 1, self.out_channels) *
                     ones.view(-1, self.heads, self.dim, 1))
                h = torch.mul(self.w, h)

                return (outj.view(-1, self.heads, 1, self.out_channels) *
                        alpha.view(-1, self.heads, self.dim, 1) + h)

        elif self.mod == "scaled":
            if self.attention_mode == "additive-self-attention":
                ones = alpha.new_ones(index.size())
                # In-degree of each edge's target node, broadcast per edge.
                degree = scatter(ones, index, dim_size=size_i,
                                 reduce='sum')[index].unsqueeze(-1)
                degree = torch.matmul(degree, self.l1) + self.b1
                degree = self.activation(degree)
                degree = torch.matmul(degree, self.l2) + self.b2

                return torch.mul(
                    outj.view(-1, self.heads, self.out_channels) *
                    alpha.view(-1, self.heads, 1),
                    degree.view(-1, 1, self.out_channels))
            elif self.attention_mode == "multiplicative-self-attention":
                ones = alpha.new_ones(index.size())
                degree = scatter(ones, index, dim_size=size_i,
                                 reduce='sum')[index].unsqueeze(-1)
                degree = torch.matmul(degree, self.l1) + self.b1
                degree = self.activation(degree)
                degree = torch.matmul(degree, self.l2) + self.b2

                return torch.mul(
                    outj.view(-1, self.heads, 1, self.out_channels) *
                    alpha.view(-1, self.heads, self.dim, 1),
                    degree.view(-1, 1, 1, self.out_channels))

        elif self.mod == "f-additive":
            alpha = torch.where(alpha > 0, alpha + 1, alpha)

        elif self.mod == "f-scaled":
            ones = alpha.new_ones(index.size())
            degree = scatter(ones, index, dim_size=size_i,
                             reduce='sum')[index].unsqueeze(-1)
            alpha = alpha * degree

        elif self.training and self.dropout > 0:
            alpha = F.dropout(alpha, p=self.dropout, training=True)

        else:
            alpha = alpha  # original

        if self.attention_mode == "additive-self-attention":
            return alpha.view(-1, self.heads, 1) * outj.view(
                -1, self.heads, self.out_channels)
        else:
            return (alpha.view(-1, self.heads, self.dim, 1) *
                    outj.view(-1, self.heads, 1, self.out_channels))

    def update(self, aggr_out: Tensor) -> Tensor:
        r"""Concatenates (:obj:`concat=True`) or averages the per-head
        aggregated outputs and adds the bias if one is learned."""
        if self.attention_mode == "additive-self-attention":
            if self.concat is True:
                aggr_out = aggr_out.view(-1, self.heads * self.out_channels)
            else:
                aggr_out = aggr_out.mean(dim=1)

            if self.bias is not None:
                aggr_out = aggr_out + self.bias

            return aggr_out
        else:
            if self.concat is True:
                aggr_out = aggr_out.view(
                    -1, self.heads * self.dim * self.out_channels)
            else:
                aggr_out = aggr_out.mean(dim=1)
                aggr_out = aggr_out.view(-1, self.dim * self.out_channels)

            if self.bias is not None:
                aggr_out = aggr_out + self.bias

            return aggr_out

    def __repr__(self) -> str:
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)