import math
from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.inits import kaiming_uniform, uniform
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, SparseTensor
class Linear(torch.nn.Module):
def __init__(self, in_channels, out_channels, groups=1, bias=True):
super().__init__()
assert in_channels % groups == 0 and out_channels % groups == 0
self.in_channels = in_channels
self.out_channels = out_channels
self.groups = groups
self.weight = Parameter(
torch.empty(groups, in_channels // groups, out_channels // groups))
if bias:
self.bias = Parameter(torch.empty(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
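        # Follows the default initialization of torch.nn.Linear (Kaiming
        # uniform with a=sqrt(5) for the weight, uniform bias), but with the
        # fan-in of a single group, i.e. in_channels // groups: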
kaiming_uniform(self.weight, fan=self.weight.size(1), a=math.sqrt(5))
uniform(self.weight.size(1), self.bias)
def forward(self, src):
# Input: [*, in_channels]
# Output: [*, out_channels]
if self.groups > 1:
size = src.size()[:-1]
src = src.view(-1, self.groups, self.in_channels // self.groups)
src = src.transpose(0, 1).contiguous()
out = torch.matmul(src, self.weight)
out = out.transpose(1, 0).contiguous()
out = out.view(size + (self.out_channels, ))
else:
out = torch.matmul(src, self.weight.squeeze(0))
if self.bias is not None:
out = out + self.bias
return out
def __repr__(self) -> str: # pragma: no cover
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, groups={self.groups})')
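

# Illustrative usage of the grouped `Linear` layer above (not part of the
# module; names and sizes are arbitrary): with `groups=4`, the input channels
# are split into four blocks that are projected independently and then
# concatenated along the channel dimension.
#
#   lin = Linear(16, 32, groups=4)
#   out = lin(torch.randn(10, 16))  # -> [10, 32]
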
def restricted_softmax(src, dim: int = -1, margin: float = 0.):
src_max = torch.clamp(src.max(dim=dim, keepdim=True)[0], min=0.)
out = (src - src_max).exp()
out = out / (out.sum(dim=dim, keepdim=True) + (margin - src_max).exp())
return out
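

# Note on `restricted_softmax` (illustrative, not part of the module): the
# extra `(margin - src_max).exp()` term in the denominator acts like an
# implicit additional logit fixed at `margin` (0 by default), so each row sums
# to at most one and an attention query can assign probability mass to "none
# of the keys".  For example:
#
#   scores = torch.tensor([[0.5, 1.0, -2.0]])
#   restricted_softmax(scores)  # entries sum to < 1 along the last dimension
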
class Attention(torch.nn.Module):
def __init__(self, dropout=0):
super().__init__()
self.dropout = dropout
def forward(self, query, key, value):
return self.compute_attention(query, key, value)
def compute_attention(self, query, key, value):
# query: [*, query_entries, dim_k]
# key: [*, key_entries, dim_k]
# value: [*, key_entries, dim_v]
# Output: [*, query_entries, dim_v]
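        # Standard scaled dot-product attention, except that the restricted
        # softmax above lets a query keep part of its attention unassigned
        # (scores may sum to less than one).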
assert query.dim() == key.dim() == value.dim() >= 2
assert query.size(-1) == key.size(-1)
assert key.size(-2) == value.size(-2)
# Score: [*, query_entries, key_entries]
score = torch.matmul(query, key.transpose(-2, -1))
score = score / math.sqrt(key.size(-1))
score = restricted_softmax(score, dim=-1)
score = F.dropout(score, p=self.dropout, training=self.training)
return torch.matmul(score, value)
def __repr__(self) -> str: # pragma: no cover
return f'{self.__class__.__name__}(dropout={self.dropout})'
class MultiHead(Attention):
def __init__(self, in_channels, out_channels, heads=1, groups=1, dropout=0,
bias=True):
super().__init__(dropout)
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.groups = groups
self.bias = bias
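        # Channels must divide evenly across heads and groups, and one of
        # heads/groups must be a multiple of the other so that the per-head
        # chunks stay aligned with the grouped projections: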
assert in_channels % heads == 0 and out_channels % heads == 0
assert in_channels % groups == 0 and out_channels % groups == 0
assert max(groups, self.heads) % min(groups, self.heads) == 0
self.lin_q = Linear(in_channels, out_channels, groups, bias)
self.lin_k = Linear(in_channels, out_channels, groups, bias)
self.lin_v = Linear(in_channels, out_channels, groups, bias)
self.reset_parameters()
def reset_parameters(self):
self.lin_q.reset_parameters()
self.lin_k.reset_parameters()
self.lin_v.reset_parameters()
def forward(self, query, key, value):
# query: [*, query_entries, in_channels]
# key: [*, key_entries, in_channels]
# value: [*, key_entries, in_channels]
# Output: [*, query_entries, out_channels]
assert query.dim() == key.dim() == value.dim() >= 2
assert query.size(-1) == key.size(-1) == value.size(-1)
assert key.size(-2) == value.size(-2)
query = self.lin_q(query)
key = self.lin_k(key)
value = self.lin_v(value)
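        # Split the projected features into `heads` chunks of size
        # out_channels // heads and move the head dimension in front of the
        # entry dimension, so attention is computed independently per head.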
# query: [*, heads, query_entries, out_channels // heads]
# key: [*, heads, key_entries, out_channels // heads]
# value: [*, heads, key_entries, out_channels // heads]
size = query.size()[:-2]
out_channels_per_head = self.out_channels // self.heads
query_size = size + (query.size(-2), self.heads, out_channels_per_head)
query = query.view(query_size).transpose(-2, -3)
key_size = size + (key.size(-2), self.heads, out_channels_per_head)
key = key.view(key_size).transpose(-2, -3)
value_size = size + (value.size(-2), self.heads, out_channels_per_head)
value = value.view(value_size).transpose(-2, -3)
# Output: [*, heads, query_entries, out_channels // heads]
out = self.compute_attention(query, key, value)
# Output: [*, query_entries, heads, out_channels // heads]
out = out.transpose(-3, -2).contiguous()
# Output: [*, query_entries, out_channels]
out = out.view(size + (query.size(-2), self.out_channels))
return out
def __repr__(self) -> str: # pragma: no cover
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, heads={self.heads}, '
                f'groups={self.groups}, dropout={self.dropout}, '
f'bias={self.bias})')


class DNAConv(MessagePassing):
r"""The dynamic neighborhood aggregation operator from the `"Just Jump:
Towards Dynamic Neighborhood Aggregation in Graph Neural Networks"
<https://arxiv.org/abs/1904.04849>`_ paper.
.. math::
\mathbf{x}_v^{(t)} = h_{\mathbf{\Theta}}^{(t)} \left( \mathbf{x}_{v
\leftarrow v}^{(t)}, \left\{ \mathbf{x}_{v \leftarrow w}^{(t)} : w \in
\mathcal{N}(v) \right\} \right)
based on (multi-head) dot-product attention
.. math::
\mathbf{x}_{v \leftarrow w}^{(t)} = \textrm{Attention} \left(
\mathbf{x}^{(t-1)}_v \, \mathbf{\Theta}_Q^{(t)}, [\mathbf{x}_w^{(1)},
\ldots, \mathbf{x}_w^{(t-1)}] \, \mathbf{\Theta}_K^{(t)}, \,
[\mathbf{x}_w^{(1)}, \ldots, \mathbf{x}_w^{(t-1)}] \,
\mathbf{\Theta}_V^{(t)} \right)
with :math:`\mathbf{\Theta}_Q^{(t)}, \mathbf{\Theta}_K^{(t)},
\mathbf{\Theta}_V^{(t)}` denoting (grouped) projection matrices for query,
key and value information, respectively.
:math:`h^{(t)}_{\mathbf{\Theta}}` is implemented as a non-trainable
version of :class:`torch_geometric.nn.conv.GCNConv`.
.. note::
In contrast to other layers, this operator expects node features as
shape :obj:`[num_nodes, num_layers, channels]`.
Args:
channels (int): Size of each input/output sample.
heads (int, optional): Number of multi-head-attentions.
(default: :obj:`1`)
groups (int, optional): Number of groups to use for all linear
projections. (default: :obj:`1`)
dropout (float, optional): Dropout probability of attention
coefficients. (default: :obj:`0.`)
cached (bool, optional): If set to :obj:`True`, the layer will cache
the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
\mathbf{\hat{D}}^{-1/2}` on first execution, and will use the
cached version for further executions.
This parameter should only be set to :obj:`True` in transductive
learning scenarios. (default: :obj:`False`)
normalize (bool, optional): Whether to add self-loops and apply
symmetric normalization. (default: :obj:`True`)
add_self_loops (bool, optional): If set to :obj:`False`, will not add
self-loops to the input graph. (default: :obj:`True`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
Shapes:
- **input:**
node features :math:`(|\mathcal{V}|, L, F)` where :math:`L` is the
number of layers,
edge indices :math:`(2, |\mathcal{E}|)`
- **output:** node features :math:`(|\mathcal{V}|, F)`
"""
_cached_edge_index: Optional[OptPairTensor]
_cached_adj_t: Optional[SparseTensor]
def __init__(self, channels: int, heads: int = 1, groups: int = 1,
dropout: float = 0., cached: bool = False,
normalize: bool = True, add_self_loops: bool = True,
bias: bool = True, **kwargs):
kwargs.setdefault('aggr', 'add')
super().__init__(node_dim=0, **kwargs)
self.bias = bias
self.cached = cached
self.normalize = normalize
self.add_self_loops = add_self_loops
self._cached_edge_index = None
self._cached_adj_t = None
self.multi_head = MultiHead(channels, channels, heads, groups, dropout,
bias)
self.reset_parameters()

    def reset_parameters(self):
super().reset_parameters()
self.multi_head.reset_parameters()
self._cached_edge_index = None
self._cached_adj_t = None

    def forward(
self,
x: Tensor,
edge_index: Adj,
edge_weight: OptTensor = None,
) -> Tensor:
r"""Runs the forward pass of the module.
Args:
x (torch.Tensor): The input node features of shape
:obj:`[num_nodes, num_layers, channels]`.
edge_index (torch.Tensor or SparseTensor): The edge indices.
edge_weight (torch.Tensor, optional): The edge weights.
(default: :obj:`None`)
"""
if x.dim() != 3:
raise ValueError('Feature shape must be [num_nodes, num_layers, '
'channels].')
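        # Optionally compute (and cache) the symmetric GCN normalization
        # \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} of the adjacency, matching the
        # non-trainable GCNConv-style aggregation described in the docstring.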
if self.normalize:
if isinstance(edge_index, Tensor):
cache = self._cached_edge_index
if cache is None:
edge_index, edge_weight = gcn_norm( # yapf: disable
edge_index, edge_weight, x.size(self.node_dim), False,
self.add_self_loops, self.flow, dtype=x.dtype)
if self.cached:
self._cached_edge_index = (edge_index, edge_weight)
else:
edge_index, edge_weight = cache[0], cache[1]
elif isinstance(edge_index, SparseTensor):
cache = self._cached_adj_t
if cache is None:
edge_index = gcn_norm( # yapf: disable
edge_index, edge_weight, x.size(self.node_dim), False,
self.add_self_loops, self.flow, dtype=x.dtype)
if self.cached:
self._cached_adj_t = edge_index
else:
edge_index = cache
# propagate_type: (x: Tensor, edge_weight: OptTensor)
return self.propagate(edge_index, x=x, edge_weight=edge_weight)
def message(self, x_i: Tensor, x_j: Tensor, edge_weight: Tensor) -> Tensor:
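        # Use only the most recent layer-wise representation of the central
        # node as the attention query; all stored representations of the
        # neighbor serve as keys and values (cf. the class docstring).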
x_i = x_i[:, -1:] # [num_edges, 1, channels]
out = self.multi_head(x_i, x_j, x_j) # [num_edges, 1, channels]
return edge_weight.view(-1, 1) * out.squeeze(1)
def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.multi_head.in_channels}, '
f'heads={self.multi_head.heads}, '
f'groups={self.multi_head.groups})')
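

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library): the
# module names above are real, but the sizes below (100 nodes, 3 stacked
# layer-wise representations, 32 channels, 500 random edges) are arbitrary.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    conv = DNAConv(32, heads=4, groups=8, dropout=0.2)
    x = torch.randn(100, 3, 32)  # [num_nodes, num_layers, channels]
    edge_index = torch.randint(0, 100, (2, 500))  # [2, num_edges]
    out = conv(x, edge_index)
    assert out.size() == (100, 32)  # [num_nodes, channels]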