from typing import List, Optional
import torch
from torch import Tensor
from torch.nn import LSTM, Linear
[docs]class JumpingKnowledge(torch.nn.Module):
r"""The Jumping Knowledge layer aggregation module from the
`"Representation Learning on Graphs with Jumping Knowledge Networks"
<https://arxiv.org/abs/1806.03536>`_ paper based on either
**concatenation** (:obj:`"cat"`)
.. math::
\mathbf{x}_v^{(1)} \, \Vert \, \ldots \, \Vert \, \mathbf{x}_v^{(T)}
**max pooling** (:obj:`"max"`)
.. math::
\max \left( \mathbf{x}_v^{(1)}, \ldots, \mathbf{x}_v^{(T)} \right)
or **weighted summation**
.. math::
\sum_{t=1}^T \alpha_v^{(t)} \mathbf{x}_v^{(t)}
with attention scores :math:`\alpha_v^{(t)}` obtained from a bi-directional
LSTM (:obj:`"lstm"`).
Args:
mode (string): The aggregation scheme to use
(:obj:`"cat"`, :obj:`"max"` or :obj:`"lstm"`).
channels (int, optional): The number of channels per representation.
Needs to be only set for LSTM-style aggregation.
(default: :obj:`None`)
num_layers (int, optional): The number of layers to aggregate. Needs to
be only set for LSTM-style aggregation. (default: :obj:`None`)
"""
def __init__(self, mode: str, channels: Optional[int] = None,
num_layers: Optional[int] = None):
super().__init__()
self.mode = mode.lower()
assert self.mode in ['cat', 'max', 'lstm']
if mode == 'lstm':
assert channels is not None, 'channels cannot be None for lstm'
assert num_layers is not None, 'num_layers cannot be None for lstm'
self.lstm = LSTM(channels, (num_layers * channels) // 2,
bidirectional=True, batch_first=True)
self.att = Linear(2 * ((num_layers * channels) // 2), 1)
self.channels = channels
self.num_layers = num_layers
else:
self.lstm = None
self.att = None
self.channels = None
self.num_layers = None
self.reset_parameters()
[docs] def reset_parameters(self):
if self.lstm is not None:
self.lstm.reset_parameters()
if self.att is not None:
self.att.reset_parameters()
[docs] def forward(self, xs: List[Tensor]) -> Tensor:
r"""Aggregates representations across different layers.
Args:
xs (List[Tensor]): List containing layer-wise representations.
"""
if self.mode == 'cat':
return torch.cat(xs, dim=-1)
elif self.mode == 'max':
return torch.stack(xs, dim=-1).max(dim=-1)[0]
else: # self.mode == 'lstm'
assert self.lstm is not None and self.att is not None
x = torch.stack(xs, dim=1) # [num_nodes, num_layers, num_channels]
alpha, _ = self.lstm(x)
alpha = self.att(alpha).squeeze(-1) # [num_nodes, num_layers]
alpha = torch.softmax(alpha, dim=-1)
return (x * alpha.unsqueeze(-1)).sum(dim=1)
def __repr__(self) -> str:
if self.mode == 'lstm':
return (f'{self.__class__.__name__}({self.mode}, '
f'channels={self.channels}, layers={self.num_layers})')
return f'{self.__class__.__name__}({self.mode})'