# Source code for torch_geometric.nn.dense.linear

import copy
import math
import sys
import time
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter

import torch_geometric.backend
import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric.index import index2ptr
from torch_geometric.nn import inits
from torch_geometric.typing import pyg_lib
from torch_geometric.utils import index_sort


def is_uninitialized_parameter(x: Any) -> bool:
    """Return whether ``x`` is a lazily-created ``UninitializedParameter``.

    Older PyTorch builds do not ship ``UninitializedParameter`` at all; in
    that case nothing can be uninitialized and ``False`` is returned.
    """
    uninitialized_cls = getattr(torch.nn.parameter, 'UninitializedParameter',
                                None)
    return uninitialized_cls is not None and isinstance(x, uninitialized_cls)


def reset_weight_(weight: Tensor, in_channels: int,
                  initializer: Optional[str] = None) -> Tensor:
    """Re-initialize ``weight`` in place and return it.

    Args:
        weight: The weight tensor to initialize.
        in_channels: Fan-in of the layer. Non-positive values mean the layer
            is still lazy, in which case ``weight`` is returned untouched
            (and ``initializer`` is not validated).
        initializer: One of ``'glorot'``, ``'uniform'``,
            ``'kaiming_uniform'`` or ``None``. ``None`` behaves exactly like
            ``'kaiming_uniform'`` (with ``a=sqrt(5)``).

    Raises:
        RuntimeError: If ``initializer`` is not one of the supported names.
    """
    if in_channels > 0:
        if initializer in (None, 'kaiming_uniform'):
            inits.kaiming_uniform(weight, fan=in_channels, a=math.sqrt(5))
        elif initializer == 'glorot':
            inits.glorot(weight)
        elif initializer == 'uniform':
            limit = 1.0 / math.sqrt(in_channels)
            torch.nn.init.uniform_(weight.data, -limit, limit)
        else:
            raise RuntimeError(f"Weight initializer '{initializer}' "
                               "not supported")
    return weight


def reset_bias_(bias: Optional[Tensor], in_channels: int,
                initializer: Optional[str] = None) -> Optional[Tensor]:
    """Re-initialize ``bias`` in place (if present) and return it.

    Args:
        bias: The bias vector, or ``None`` if the layer has no bias.
        in_channels: Fan-in of the layer. Non-positive values mean the layer
            is still lazy, in which case ``bias`` is returned untouched (and
            ``initializer`` is not validated).
        initializer: ``'zeros'`` or ``None`` (the latter delegates to
            ``inits.uniform`` with ``in_channels`` as its size argument).

    Raises:
        RuntimeError: If ``initializer`` is not one of the supported names.
    """
    if bias is not None and in_channels > 0:
        if initializer is None:
            inits.uniform(in_channels, bias)
        elif initializer == 'zeros':
            inits.zeros(bias)
        else:
            raise RuntimeError(f"Bias initializer '{initializer}' "
                               "not supported")
    return bias


class Linear(torch.nn.Module):
    r"""Applies a linear transformation to the incoming data.

    .. math::
        \mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b}

    In contrast to :class:`torch.nn.Linear`, it supports lazy initialization
    and customizable weight and bias initialization.

    Args:
        in_channels (int): Size of each input sample. Will be initialized
            lazily in case it is given as :obj:`-1` (any non-positive value
            triggers lazy initialization).
        out_channels (int): Size of each output sample.
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        weight_initializer (str, optional): The initializer for the weight
            matrix (:obj:`"glorot"`, :obj:`"uniform"`, :obj:`"kaiming_uniform"`
            or :obj:`None`).
            If set to :obj:`None`, will match default weight initialization of
            :class:`torch.nn.Linear`. (default: :obj:`None`)
        bias_initializer (str, optional): The initializer for the bias vector
            (:obj:`"zeros"` or :obj:`None`).
            If set to :obj:`None`, will match default bias initialization of
            :class:`torch.nn.Linear`. (default: :obj:`None`)

    Shapes:
        - **input:** features :math:`(*, F_{in})`
        - **output:** features :math:`(*, F_{out})`
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        bias: bool = True,
        weight_initializer: Optional[str] = None,
        bias_initializer: Optional[str] = None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight_initializer = weight_initializer
        self.bias_initializer = bias_initializer

        if in_channels > 0:
            self.weight = Parameter(torch.empty(out_channels, in_channels))
        else:
            # Lazy mode: the weight shape is fixed by the first forward call,
            # which triggers `initialize_parameters` via this pre-hook.
            self.weight = torch.nn.parameter.UninitializedParameter()
            self._hook = self.register_forward_pre_hook(
                self.initialize_parameters)

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def __deepcopy__(self, memo):
        # PyTorch<1.13 cannot handle deep copies of uninitialized parameters :(
        # TODO Drop this code once PyTorch 1.12 is no longer supported.
        out = Linear(
            self.in_channels,
            self.out_channels,
            self.bias is not None,
            self.weight_initializer,
            self.bias_initializer,
        ).to(self.weight.device)

        # Only copy the weight when it is materialized; a lazy copy keeps the
        # fresh `UninitializedParameter` (and hook) created by the constructor.
        if self.in_channels > 0:
            out.weight = copy.deepcopy(self.weight, memo)

        if self.bias is not None:
            out.bias = copy.deepcopy(self.bias, memo)

        # A freshly constructed module is always in training mode; carry over
        # the train/eval state so the copy behaves like the original.
        out.train(self.training)

        return out

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        reset_weight_(self.weight, self.in_channels, self.weight_initializer)
        reset_bias_(self.bias, self.in_channels, self.bias_initializer)

    def forward(self, x: Tensor) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The input features.
        """
        return F.linear(x, self.weight, self.bias)

    @torch.no_grad()
    def initialize_parameters(self, module, input):
        # Forward pre-hook: `input` is the tuple of positional forward
        # arguments, so `input[0]` is the feature tensor `x`.
        if is_uninitialized_parameter(self.weight):
            self.in_channels = input[0].size(-1)
            self.weight.materialize((self.out_channels, self.in_channels))
            self.reset_parameters()
        # The hook is one-shot: remove it once the shape is known.
        self._hook.remove()
        delattr(self, '_hook')

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Uninitialized parameters are stored as-is so that loading them back
        # restores lazy mode (see `_load_from_state_dict`).
        if (is_uninitialized_parameter(self.weight)
                or torch.onnx.is_in_onnx_export() or keep_vars):
            destination[prefix + 'weight'] = self.weight
        else:
            destination[prefix + 'weight'] = self.weight.detach()
        if self.bias is not None:
            if torch.onnx.is_in_onnx_export() or keep_vars:
                destination[prefix + 'bias'] = self.bias
            else:
                destination[prefix + 'bias'] = self.bias.detach()

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        weight = state_dict.get(prefix + 'weight', None)

        if weight is not None and is_uninitialized_parameter(weight):
            # Loading a lazy checkpoint: switch this module back to lazy mode.
            self.in_channels = -1
            self.weight = torch.nn.parameter.UninitializedParameter()
            if not hasattr(self, '_hook'):
                self._hook = self.register_forward_pre_hook(
                    self.initialize_parameters)

        elif weight is not None and is_uninitialized_parameter(self.weight):
            # Loading a materialized checkpoint into a lazy module: infer the
            # shape from the stored weight so the copy below succeeds.
            self.in_channels = weight.size(-1)
            self.weight.materialize((self.out_channels, self.in_channels))
            if hasattr(self, '_hook'):
                self._hook.remove()
                delattr(self, '_hook')

        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, bias={self.bias is not None})')
class HeteroLinear(torch.nn.Module):
    r"""Applies separate linear transformations to the incoming data according
    to types.

    For type :math:`\kappa`, it computes

    .. math::
        \mathbf{x}^{\prime}_{\kappa} = \mathbf{x}_{\kappa}
        \mathbf{W}^{\top}_{\kappa} + \mathbf{b}_{\kappa}.

    It supports lazy initialization and customizable weight and bias
    initialization.

    Args:
        in_channels (int): Size of each input sample. Will be initialized
            lazily in case it is given as :obj:`-1`.
        out_channels (int): Size of each output sample.
        num_types (int): The number of types.
        is_sorted (bool, optional): If set to :obj:`True`, assumes that
            :obj:`type_vec` is sorted. This avoids internal re-sorting of the
            data and can improve runtime and memory efficiency.
            (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.Linear`.

    Shapes:
        - **input:**
          features :math:`(*, F_{in})`,
          type vector :math:`(*)`
        - **output:** features :math:`(*, F_{out})`
    """
    _timing_cache: Dict[int, Tuple[float, float]]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_types: int,
        is_sorted: bool = False,
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_types = num_types
        self.is_sorted = is_sorted
        # Only 'bias', 'weight_initializer' and 'bias_initializer' are read
        # from `kwargs` by this class.
        self.kwargs = kwargs

        if self.in_channels == -1:
            self.weight = torch.nn.parameter.UninitializedParameter()
            self._hook = self.register_forward_pre_hook(
                self.initialize_parameters)
        else:
            # Stored as (num_types, in, out), i.e. NOT transposed like
            # `torch.nn.Linear`, so `x @ weight[i]` works directly.
            self.weight = torch.nn.Parameter(
                torch.empty(num_types, in_channels, out_channels))

        if kwargs.get('bias', True):
            self.bias = Parameter(torch.empty(num_types, out_channels))
        else:
            self.register_parameter('bias', None)

        # Timing cache for benchmarking naive vs. segment matmul usage:
        self._timing_cache: Dict[int, Tuple[float, float]] = {}

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        reset_weight_(self.weight, self.in_channels,
                      self.kwargs.get('weight_initializer', None))
        reset_bias_(self.bias, self.in_channels,
                    self.kwargs.get('bias_initializer', None))

    def forward_naive(self, x: Tensor, type_ptr: Tensor) -> Tensor:
        # One dense matmul per contiguous type segment
        # `[type_ptr[i], type_ptr[i + 1])`; rows must already be type-sorted.
        out = x.new_empty(x.size(0), self.out_channels)
        for i, (start, end) in enumerate(zip(type_ptr[:-1], type_ptr[1:])):
            out[start:end] = x[start:end] @ self.weight[i]
        return out

    def forward_segmm(self, x: Tensor, type_ptr: Tensor) -> Tensor:
        return pyg_lib.ops.segment_matmul(x, type_ptr, self.weight)

    @torch.no_grad()
    def _update_timing_cache(
        self,
        x: Tensor,
        type_ptr: Tensor,
        key: int,
    ) -> None:
        # Benchmark both implementations on the actual input and remember the
        # wall-clock times under `key`.
        MEASURE_ITER = 1 if 'pytest' in sys.modules else 3

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t = time.perf_counter()
        for _ in range(MEASURE_ITER):
            _ = self.forward_segmm(x, type_ptr)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        time_segmm = time.perf_counter() - t

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t = time.perf_counter()
        for _ in range(MEASURE_ITER):
            _ = self.forward_naive(x, type_ptr)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        time_naive = time.perf_counter() - t

        self._timing_cache[key] = (time_segmm, time_naive)

    def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
        r"""The forward pass.

        Args:
            x (torch.Tensor): The input features.
            type_vec (torch.Tensor): A vector that maps each entry to a type.
        """
        perm: Optional[Tensor] = None
        if not self.is_sorted and (type_vec[1:] < type_vec[:-1]).any():
            type_vec, perm = index_sort(type_vec, self.num_types)
            x = x[perm]

        type_ptr = index2ptr(type_vec, self.num_types)

        if torch_geometric.backend.use_segment_matmul is None:
            use_segment_matmul = False
            # Empty inputs skip the benchmark: `log10(0)` is undefined and the
            # naive path handles zero rows trivially.
            if (torch_geometric.typing.WITH_SEGMM and not is_compiling()
                    and not torch.jit.is_scripting() and x.size(0) > 0):
                # Use "magnitude" of number of rows as timing key:
                key = math.floor(math.log10(x.size(0)))
                if key not in self._timing_cache:
                    self._update_timing_cache(x, type_ptr, key)
                time_segmm, time_naive = self._timing_cache[key]
                use_segment_matmul = time_segmm < time_naive
        else:
            use_segment_matmul = torch_geometric.backend.use_segment_matmul

        if (torch_geometric.typing.WITH_SEGMM and not is_compiling()
                and use_segment_matmul):
            out = self.forward_segmm(x, type_ptr)
        else:
            out = self.forward_naive(x, type_ptr)

        if self.bias is not None:
            out += self.bias[type_vec]

        if perm is not None:  # Restore original order (if necessary).
            out_unsorted = torch.empty_like(out)
            out_unsorted[perm] = out
            out = out_unsorted

        return out

    @torch.no_grad()
    def initialize_parameters(self, module, input):
        # Forward pre-hook: `input[0]` is the feature tensor `x`.
        if is_uninitialized_parameter(self.weight):
            self.in_channels = input[0].size(-1)
            self.weight.materialize(
                (self.num_types, self.in_channels, self.out_channels))
            self.reset_parameters()
        self._hook.remove()
        delattr(self, '_hook')

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, num_types={self.num_types}, '
                f'bias={self.kwargs.get("bias", True)})')
class HeteroDictLinear(torch.nn.Module):
    r"""Applies separate linear transformations to the incoming data
    dictionary.

    For key :math:`\kappa`, it computes

    .. math::
        \mathbf{x}^{\prime}_{\kappa} = \mathbf{x}_{\kappa}
        \mathbf{W}^{\top}_{\kappa} + \mathbf{b}_{\kappa}.

    It supports lazy initialization and customizable weight and bias
    initialization.

    Args:
        in_channels (int or Dict[Any, int]): Size of each input sample. If
            passed an integer, :obj:`types` will be a mandatory argument.
            Will be initialized lazily in case it is given as :obj:`-1`.
        out_channels (int): Size of each output sample.
        types (List[Any], optional): The keys of the input dictionary.
            (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.Linear`.
    """
    def __init__(
        self,
        in_channels: Union[int, Dict[Any, int]],
        out_channels: int,
        types: Optional[Any] = None,
        **kwargs,
    ):
        super().__init__()

        if isinstance(in_channels, dict):
            self.types = list(in_channels.keys())

            if any(i == -1 for i in in_channels.values()):
                self._hook = self.register_forward_pre_hook(
                    self.initialize_parameters)

            if types is not None and set(self.types) != set(types):
                raise ValueError("The provided 'types' do not match with the "
                                 "keys in the 'in_channels' dictionary")

        else:
            if types is None:
                raise ValueError("Please provide a list of 'types' if passing "
                                 "'in_channels' as an integer")

            if in_channels == -1:
                self._hook = self.register_forward_pre_hook(
                    self.initialize_parameters)

            self.types = types
            in_channels = {node_type: in_channels for node_type in types}

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kwargs = kwargs

        self.lins = torch.nn.ModuleDict({
            key: Linear(channels, self.out_channels, **kwargs)
            for key, channels in self.in_channels.items()
        })

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        for lin in self.lins.values():
            lin.reset_parameters()

    def forward(
        self,
        x_dict: Dict[str, Tensor],
    ) -> Dict[str, Tensor]:
        r"""Forward pass.

        Args:
            x_dict (Dict[Any, torch.Tensor]): A dictionary holding input
                features for each individual type.

        Keys of :obj:`x_dict` without a matching linear layer are ignored.
        """
        out_dict = {}

        # Only apply fused kernel for more than 10 types, otherwise use
        # sequential computation (which is generally faster for these cases).
        use_segment_matmul = torch_geometric.backend.use_segment_matmul
        if use_segment_matmul is None:
            use_segment_matmul = len(x_dict) >= 10

        if (use_segment_matmul and torch_geometric.typing.WITH_GMM
                and not is_compiling() and not torch.jit.is_scripting()):
            # Record the keys in the SAME order the inputs are gathered, so the
            # grouped outputs are mapped back to the correct types even when
            # `x_dict` is ordered differently or contains unknown keys.
            keys, xs, weights, biases = [], [], [], []
            for key, lin in self.lins.items():
                if key in x_dict:
                    keys.append(key)
                    xs.append(x_dict[key])
                    weights.append(lin.weight.t())
                    biases.append(lin.bias)
            if len(xs) == 0:  # No known type present in the input.
                return out_dict
            # All lins share `**kwargs`, so either all or none have a bias.
            biases = None if biases[0] is None else biases
            outs = pyg_lib.ops.grouped_matmul(xs, weights, biases)
            for key, out in zip(keys, outs):
                out_dict[key] = out
        else:
            for key, lin in self.lins.items():
                if key in x_dict:
                    out_dict[key] = lin(x_dict[key])

        return out_dict

    @torch.no_grad()
    def initialize_parameters(self, module, input):
        # Forward pre-hook: `input[0]` is the input dictionary `x_dict`.
        in_channels = dict(self.in_channels)
        for key, x in input[0].items():
            if key not in self.lins:  # Ignored by `forward` as well.
                continue
            lin = self.lins[key]
            if is_uninitialized_parameter(lin.weight):
                # `Linear.initialize_parameters` expects the hook-style tuple
                # of positional arguments, i.e. `(x, )`, not `x` itself.
                lin.initialize_parameters(None, (x, ))
                # NOTE: kept for backward compatibility (RNG stream), although
                # `initialize_parameters` has already reset the parameters.
                lin.reset_parameters()
            in_channels[key] = lin.in_channels
        # Types absent from this first batch keep their previous entry.
        self.in_channels = in_channels
        self._hook.remove()
        delattr(self, '_hook')

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, bias={self.kwargs.get("bias", True)})')