Source code for torch_geometric.nn.encoding

import torch
from torch import Tensor


class PositionalEncoding(torch.nn.Module):
    r"""The positional encoding scheme from `"Attention Is All You Need"
    <https://arxiv.org/pdf/1706.03762.pdf>`_ paper

    .. math::

        PE(x)_{2 \cdot i} &= \sin(x / 10000^{2 \cdot i / d})

        PE(x)_{2 \cdot i + 1} &= \cos(x / 10000^{2 \cdot i / d})

    where :math:`x` is the position and :math:`i` is the dimension.

    .. note::

        In this implementation the channels are *not* interleaved: the first
        :obj:`out_channels // 2` output channels hold the sine terms and the
        remaining :obj:`out_channels // 2` channels hold the cosine terms.
        The :obj:`out_channels // 2` frequencies are spaced evenly on a
        logarithmic scale from :obj:`1` to :obj:`base_freq` (both inclusive).

    Args:
        out_channels (int): Size :math:`d` of each output sample. Must be
            even.
        base_freq (float, optional): The base frequency of sinusoidal
            functions. (default: :obj:`1e-4`)
        granularity (float, optional): The granularity of the positions. If
            set to smaller value, the encoder will capture more fine-grained
            changes in positions. (default: :obj:`1.0`)

    Raises:
        ValueError: If :obj:`out_channels` is odd.
    """
    def __init__(
        self,
        out_channels: int,
        base_freq: float = 1e-4,
        granularity: float = 1.0,
    ):
        super().__init__()

        # Sine and cosine terms each take exactly half of the channels.
        if out_channels % 2 != 0:
            raise ValueError(f"Cannot use sinusoidal positional encoding with "
                             f"odd 'out_channels' (got {out_channels}).")

        self.out_channels = out_channels
        self.base_freq = base_freq
        self.granularity = granularity

        # base_freq ** linspace(0, 1, out_channels // 2): runs from 1 down to
        # `base_freq` (for the default base_freq < 1).
        frequency = torch.logspace(0, 1, out_channels // 2, base=base_freq)
        # Registered as a buffer so it follows the module across devices and
        # is stored in the state dict without being a trainable parameter.
        self.register_buffer('frequency', frequency)

    def forward(self, x: Tensor) -> Tensor:
        r"""Encodes the positions :obj:`x`.

        Args:
            x (torch.Tensor): The positions. Tensors of any shape are
                accepted; they are flattened before encoding.

        Returns:
            torch.Tensor: A tensor of shape :obj:`[x.numel(), out_channels]`
            holding the sine terms followed by the cosine terms.
        """
        x = x / self.granularity if self.granularity != 1.0 else x
        # `reshape` instead of `view`: identical for contiguous inputs, but
        # also works for non-contiguous ones (e.g., transposed or strided
        # tensors), for which `view` raises a `RuntimeError`.
        out = x.reshape(-1, 1) * self.frequency.view(1, -1)
        return torch.cat([torch.sin(out), torch.cos(out)], dim=-1)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.out_channels})'