Source code for torch_geometric.nn.aggr.patch_transformer

import math
from typing import List, Optional, Union

import torch
from torch import Tensor

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock
from torch_geometric.nn.encoding import PositionalEncoding
from torch_geometric.utils import scatter


class PatchTransformerAggregation(Aggregation):
    r"""Performs patch transformer aggregation in which the elements to
    aggregate are processed by multi-head attention blocks across patches, as
    described in the `"Simplifying Temporal Heterogeneous Network for
    Continuous-Time Link Prediction"
    <https://dl.acm.org/doi/pdf/10.1145/3583780.3615059>`_ paper.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        patch_size (int): Number of elements in a patch.
        hidden_channels (int): Intermediate size of each sample.
        num_transformer_blocks (int, optional): Number of transformer blocks.
            (default: :obj:`1`)
        heads (int, optional): Number of attention heads.
            (default: :obj:`1`)
        dropout (float, optional): Dropout probability of attention weights.
            (default: :obj:`0.0`)
        aggr (str or list[str], optional): The aggregation scheme(s) to use,
            *e.g.*, :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
            :obj:`"var"`, :obj:`"std"`. (default: :obj:`"mean"`)
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        hidden_channels: int,
        num_transformer_blocks: int = 1,
        heads: int = 1,
        dropout: float = 0.0,
        aggr: Union[str, List[str]] = 'mean',
    ) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_size = patch_size

        self.aggrs = [aggr] if isinstance(aggr, str) else aggr
        assert len(self.aggrs) > 0
        for aggr in self.aggrs:
            assert aggr in ['sum', 'mean', 'min', 'max', 'var', 'std']

        self.lin = torch.nn.Linear(in_channels, hidden_channels)
        self.pad_projector = torch.nn.Linear(
            patch_size * hidden_channels,
            hidden_channels,
        )
        self.pe = PositionalEncoding(hidden_channels)

        self.blocks = torch.nn.ModuleList([
            MultiheadAttentionBlock(
                channels=hidden_channels,
                heads=heads,
                layer_norm=True,
                dropout=dropout,
            ) for _ in range(num_transformer_blocks)
        ])

        self.fc = torch.nn.Linear(
            hidden_channels * len(self.aggrs),
            out_channels,
        )
    def reset_parameters(self) -> None:
        self.lin.reset_parameters()
        self.pad_projector.reset_parameters()
        self.pe.reset_parameters()
        for block in self.blocks:
            block.reset_parameters()
        self.fc.reset_parameters()
    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
    def forward(
        self,
        x: Tensor,
        index: Tensor,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        max_num_elements: Optional[int] = None,
    ) -> Tensor:

        if max_num_elements is None:
            if ptr is not None:
                count = ptr.diff()
            else:
                count = scatter(torch.ones_like(index), index, dim=0,
                                dim_size=dim_size, reduce='sum')
            max_num_elements = int(count.max()) + 1

        # Round `max_num_elements` up to a multiple of `patch_size` so that no
        # elements are dropped when reshaping into patches:
        max_num_elements = (math.ceil(max_num_elements / self.patch_size) *
                            self.patch_size)

        x = self.lin(x)

        # TODO If groups are heavily unbalanced, this will create a lot of
        # "empty" patches. Try to figure out a way to fix this.

        # [batch_size, num_patches * patch_size, hidden_channels]
        x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,
                                   max_num_elements=max_num_elements)

        # [batch_size, num_patches, patch_size * hidden_channels]
        x = x.view(x.size(0), max_num_elements // self.patch_size,
                   self.patch_size * x.size(-1))

        # [batch_size, num_patches, hidden_channels]
        x = self.pad_projector(x)

        x = x + self.pe(torch.arange(x.size(1), device=x.device))

        # [batch_size, num_patches, hidden_channels]
        for block in self.blocks:
            x = block(x, x)

        # [batch_size, hidden_channels]
        outs: List[Tensor] = []
        for aggr in self.aggrs:
            out = getattr(torch, aggr)(x, dim=1)
            outs.append(out[0] if isinstance(out, tuple) else out)
        out = torch.cat(outs, dim=1) if len(outs) > 1 else outs[0]

        # [batch_size, out_channels]
        return self.fc(out)
    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, patch_size={self.patch_size})')
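

A minimal usage sketch follows. It is not part of the module source; the feature
sizes, patch size, group assignment, and aggregation choices are illustrative
values only.

import torch

from torch_geometric.nn.aggr import PatchTransformerAggregation

# Aggregate 6 input elements into 2 groups (illustrative sizes):
aggr = PatchTransformerAggregation(
    in_channels=16,
    out_channels=32,
    patch_size=2,
    hidden_channels=8,
    heads=2,
    aggr=['mean', 'max'],
)
aggr.reset_parameters()

x = torch.randn(6, 16)                    # 6 elements, 16 features each
index = torch.tensor([0, 0, 0, 0, 1, 1])  # sorted element-to-group assignment
out = aggr(x, index, dim_size=2)          # output shape: [2, 32]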
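

For reference, a short sketch of the patching arithmetic used in
:meth:`forward`, with assumed example values: the maximum group size is rounded
up to a multiple of :obj:`patch_size`, and the padded sequence is then viewed as
:obj:`num_patches` patches.

import math

# Assumed example values: the largest group holds 5 elements, patch_size is 4.
patch_size = 4
max_num_elements = 5 + 1  # int(count.max()) + 1
max_num_elements = math.ceil(max_num_elements / patch_size) * patch_size
assert max_num_elements == 8                 # padded up, nothing is dropped
assert max_num_elements // patch_size == 2   # -> 2 patches per group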