import copy
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter

from torch_geometric.nn import inits


class Linear(torch.nn.Module):
r"""Applies a linear tranformation to the incoming data
.. math::
\mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b}
similar to :class:`torch.nn.Linear`.
It supports lazy initialization and customizable weight and bias
initialization.
Args:
in_channels (int): Size of each input sample.
Will be initialized lazily in case :obj:`-1`.
out_channels (int): Size of each output sample.
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
weight_initializer (str, optional): The initializer for the weight
matrix (:obj:`"glorot"`, :obj:`"uniform"`, :obj:`"kaiming_uniform"`
or :obj:`None`).
If set to :obj:`None`, will match default weight initialization of
:class:`torch.nn.Linear`. (default: :obj:`None`)
bias_initializer (str, optional): The initializer for the bias
vector (:obj:`"zeros"` or :obj:`None`).
If set to :obj:`None`, will match default bias initialization of
:class:`torch.nn.Linear`. (default: :obj:`None`)
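
    Example (an illustrative sketch; the shapes below are arbitrary):

        >>> # Lazy initialization: in_channels is inferred on the first call.
        >>> lin = Linear(-1, 64, weight_initializer='glorot')
        >>> lin(torch.randn(32, 16)).size()
        torch.Size([32, 64])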
"""
    def __init__(self, in_channels: int, out_channels: int, bias: bool = True,
                 weight_initializer: Optional[str] = None,
                 bias_initializer: Optional[str] = None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight_initializer = weight_initializer
        self.bias_initializer = bias_initializer

        if in_channels > 0:
            self.weight = Parameter(torch.Tensor(out_channels, in_channels))
        else:
            # Lazy initialization: the weight is materialized on the first
            # forward pass via a forward pre-hook.
            self.weight = torch.nn.parameter.UninitializedParameter()
            self._hook = self.register_forward_pre_hook(
                self.initialize_parameters)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()
    def __deepcopy__(self, memo):
        # Preserve the lazy state: only copy the weight once it exists.
        out = Linear(self.in_channels, self.out_channels,
                     self.bias is not None, self.weight_initializer,
                     self.bias_initializer)
        if self.in_channels > 0:
            out.weight = copy.deepcopy(self.weight, memo)
        if self.bias is not None:
            out.bias = copy.deepcopy(self.bias, memo)
        return out
    def reset_parameters(self):
        if self.in_channels > 0:
            if self.weight_initializer == 'glorot':
                inits.glorot(self.weight)
            elif self.weight_initializer == 'uniform':
                bound = 1.0 / math.sqrt(self.weight.size(-1))
                torch.nn.init.uniform_(self.weight.data, -bound, bound)
            elif self.weight_initializer == 'kaiming_uniform':
                inits.kaiming_uniform(self.weight, fan=self.in_channels,
                                      a=math.sqrt(5))
            elif self.weight_initializer is None:
                # Default: match the weight initialization of
                # `torch.nn.Linear`.
                inits.kaiming_uniform(self.weight, fan=self.in_channels,
                                      a=math.sqrt(5))
            else:
                raise RuntimeError(
                    f"Linear layer weight initializer "
                    f"'{self.weight_initializer}' is not supported")

        if self.in_channels > 0 and self.bias is not None:
            if self.bias_initializer == 'zeros':
                inits.zeros(self.bias)
            elif self.bias_initializer is None:
                inits.uniform(self.in_channels, self.bias)
            else:
                raise RuntimeError(
                    f"Linear layer bias initializer "
                    f"'{self.bias_initializer}' is not supported")
    def forward(self, x: Tensor) -> Tensor:
        return F.linear(x, self.weight, self.bias)
    @torch.no_grad()
    def initialize_parameters(self, module, input):
        if isinstance(self.weight, torch.nn.parameter.UninitializedParameter):
            # Infer `in_channels` from the last dimension of the input,
            # materialize the weight, and initialize it.
            self.in_channels = input[0].size(-1)
            self.weight.materialize((self.out_channels, self.in_channels))
            self.reset_parameters()
        # The hook is only needed once; remove it after the first call.
        module._hook.remove()
        delattr(module, '_hook')
    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, bias={self.bias is not None})')
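

# Example usage (a minimal sketch, not part of the module above; the shapes
# and initializer choices are illustrative):
if __name__ == '__main__':
    x = torch.randn(32, 16)  # 32 samples with 16 features each

    # Eager initialization: the weight is allocated up front.
    lin = Linear(16, 64, weight_initializer='glorot',
                 bias_initializer='zeros')
    assert lin(x).size() == (32, 64)

    # Lazy initialization: `in_channels` is inferred from `x` on the first
    # forward pass, which also removes the pre-hook.
    lazy_lin = Linear(-1, 64)
    assert lazy_lin(x).size() == (32, 64)
    assert lazy_lin.in_channels == 16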