# Source code for torch_geometric.nn.aggr.base

from typing import Final, Optional, Tuple

import torch
from torch import Tensor

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.utils import scatter, segment, to_dense_batch


class Aggregation(torch.nn.Module):
    r"""An abstract base class for implementing custom aggregations.

    Aggregation can be either performed via an :obj:`index` vector, which
    defines the mapping from input elements to their location in the output:

    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/add.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Notably, :obj:`index` does not have to be sorted (for most aggregation
    operators):

    .. code-block:: python

       # Feature matrix holding 10 elements with 64 features each:
       x = torch.randn(10, 64)

       # Assign each element to one of three sets:
       index = torch.tensor([0, 0, 1, 0, 2, 0, 2, 1, 0, 2])

       output = aggr(x, index)  #  Output shape: [3, 64]

    Alternatively, aggregation can be achieved via a "compressed" index vector
    called :obj:`ptr`. Here, elements within the same set need to be grouped
    together in the input, and :obj:`ptr` defines their boundaries:

    .. code-block:: python

       # Feature matrix holding 10 elements with 64 features each:
       x = torch.randn(10, 64)

       # Define the boundary indices for three sets:
       ptr = torch.tensor([0, 4, 7, 10])

       output = aggr(x, ptr=ptr)  #  Output shape: [3, 64]

    If neither :obj:`index` nor :obj:`ptr` is given, all elements are
    aggregated into a single set.

    Shapes:
        - **input:**
          node features :math:`(*, |\mathcal{V}|, F_{in})` or edge features
          :math:`(*, |\mathcal{E}|, F_{in})`,
          index vector :math:`(|\mathcal{V}|)` or :math:`(|\mathcal{E}|)`,
        - **output:** graph features :math:`(*, |\mathcal{G}|, F_{out})` or
          node features :math:`(*, |\mathcal{V}|, F_{out})`
    """
    def __init__(self) -> None:
        super().__init__()

        # Snapshot whether PyTorch's deterministic mode (or its warn-only
        # variant) is active at construction time. `reduce` uses this flag to
        # prefer the `segment`-based path over `scatter` whenever `ptr` is
        # available.
        self._deterministic: Final[bool] = (
            torch.are_deterministic_algorithms_enabled()
            or torch.is_deterministic_algorithms_warn_only_enabled())

    def forward(
        self,
        x: Tensor,
        index: Optional[Tensor] = None,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        max_num_elements: Optional[int] = None,
    ) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The source tensor.
            index (torch.Tensor, optional): The indices of elements for
                applying the aggregation.
                One of :obj:`index` or :obj:`ptr` must be defined.
                (default: :obj:`None`)
            ptr (torch.Tensor, optional): If given, computes the aggregation
                based on sorted inputs in CSR representation.
                One of :obj:`index` or :obj:`ptr` must be defined.
                (default: :obj:`None`)
            dim_size (int, optional): The size of the output tensor at
                dimension :obj:`dim` after aggregation. (default: :obj:`None`)
            dim (int, optional): The dimension in which to aggregate.
                (default: :obj:`-2`)
            max_num_elements (int, optional): The maximum number of elements
                within a single aggregation group. (default: :obj:`None`)
        """
        # Implemented by subclasses.
        pass

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        pass

    @disable_dynamic_shapes(required_args=['dim_size'])
    def __call__(
        self,
        x: Tensor,
        index: Optional[Tensor] = None,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        **kwargs,
    ) -> Tensor:
        r"""Validates and normalizes the arguments, then runs :meth:`forward`.

        * :obj:`dim` must be a valid (possibly negative) dimension of
          :obj:`x`.
        * If neither :obj:`index` nor :obj:`ptr` is given, every element is
          assigned to group :obj:`0`.
        * If :obj:`dim_size` is not given, it is inferred from :obj:`ptr`
          (``ptr.numel() - 1``) or from :obj:`index` (``index.max() + 1``, or
          :obj:`0` for an empty :obj:`index`). If it is given together with
          :obj:`ptr`, it must equal ``ptr.numel() - 1``.

        Raises:
            ValueError: On an invalid :obj:`dim`, or on a :obj:`dim_size`
                that is inconsistent with :obj:`ptr` or too small for
                :obj:`index`.
        """
        if dim >= x.dim() or dim < -x.dim():
            raise ValueError(f"Encountered invalid dimension '{dim}' of "
                             f"source tensor with {x.dim()} dimensions")

        if index is None and ptr is None:
            index = x.new_zeros(x.size(dim), dtype=torch.long)

        if ptr is not None:
            if dim_size is None:
                dim_size = ptr.numel() - 1
            elif dim_size != ptr.numel() - 1:
                raise ValueError(f"Encountered invalid 'dim_size' (got "
                                 f"'{dim_size}' but expected "
                                 f"'{ptr.numel() - 1}')")

        if index is not None and dim_size is None:
            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0

        try:
            return super().__call__(x, index=index, ptr=ptr,
                                    dim_size=dim_size, dim=dim, **kwargs)
        except (IndexError, RuntimeError) as e:
            # Translate low-level indexing failures caused by a too small
            # user-provided `dim_size` into a descriptive error, keeping the
            # original exception as the explicit cause.
            if index is not None and index.numel() > 0:
                max_index = int(index.max())
                if dim_size <= max_index:
                    raise ValueError(f"Encountered invalid 'dim_size' (got "
                                     f"'{dim_size}' but expected "
                                     f">= '{max_index + 1}')") from e
            raise

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'

    # Assertions ##############################################################

    def assert_index_present(self, index: Optional[Tensor]):
        r"""Raises :class:`NotImplementedError` if :obj:`index` is missing."""
        # TODO Currently, not all aggregators support `ptr`. This assert helps
        # to ensure that we require `index` to be passed to the computation:
        if index is None:
            raise NotImplementedError(
                "Aggregation requires 'index' to be specified")

    def assert_sorted_index(self, index: Optional[Tensor]):
        r"""Raises :class:`ValueError` if :obj:`index` is not sorted in
        non-decreasing order. A missing :obj:`index` passes the check."""
        if index is not None and not torch.all(index[:-1] <= index[1:]):
            raise ValueError("Can not perform aggregation since the 'index' "
                             "tensor is not sorted. Specifically, if you use "
                             "this aggregation as part of 'MessagePassing', "
                             "ensure that 'edge_index' is sorted by "
                             "destination nodes, e.g., by calling "
                             "`data.sort(sort_by_row=False)`")

    def assert_two_dimensional_input(self, x: Tensor, dim: int):
        r"""Raises :class:`ValueError` unless :obj:`x` is two-dimensional and
        the aggregation runs over its first dimension."""
        if x.dim() != 2:
            raise ValueError(f"Aggregation requires two-dimensional inputs "
                             f"(got '{x.dim()}')")

        if dim not in [-2, 0]:
            raise ValueError(f"Aggregation needs to perform aggregation in "
                             f"first dimension (got '{dim}')")

    # Helper methods ##########################################################

    def reduce(self, x: Tensor, index: Optional[Tensor] = None,
               ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
               dim: int = -2, reduce: str = 'sum') -> Tensor:
        r"""Applies the reduction :obj:`reduce` to :obj:`x` along :obj:`dim`.

        Uses :func:`segment` on :obj:`ptr` when :obj:`ptr` is given and either
        no :obj:`index` is available or deterministic algorithms were enabled
        at construction time. Otherwise it uses :func:`scatter` on
        :obj:`index`.

        Raises:
            RuntimeError: If neither path is applicable because
                :obj:`index` is missing.
        """
        if ptr is not None:
            if index is None or self._deterministic:
                # `segment` expects `ptr` to be broadcastable against the
                # dimensions preceding `dim`:
                ptr = expand_left(ptr, dim, dims=x.dim())
                return segment(x, ptr, reduce=reduce)

        if index is None:
            raise RuntimeError("Aggregation requires 'index' to be specified")

        return scatter(x, index, dim, dim_size, reduce)

    def to_dense_batch(
        self,
        x: Tensor,
        index: Optional[Tensor] = None,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        fill_value: float = 0.0,
        max_num_elements: Optional[int] = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""Converts :obj:`x` into a dense batch via
        :func:`torch_geometric.utils.to_dense_batch`.

        Requires a sorted :obj:`index` and a two-dimensional :obj:`x` that is
        aggregated over its first dimension. :obj:`ptr` is currently ignored.
        """
        # TODO Currently, `to_dense_batch` can only operate on `index`:
        self.assert_index_present(index)
        self.assert_sorted_index(index)
        self.assert_two_dimensional_input(x, dim)

        return to_dense_batch(
            x,
            index,
            batch_size=dim_size,
            fill_value=fill_value,
            max_num_nodes=max_num_elements,
        )
###############################################################################


def expand_left(ptr: Tensor, dim: int, dims: int) -> Tensor:
    r"""Prepends singleton dimensions to :obj:`ptr`.

    The number of prepended dimensions equals the number of dimensions that
    precede :obj:`dim` in a tensor with :obj:`dims` dimensions. A negative
    :obj:`dim` is resolved relative to :obj:`dims`, and a result below zero
    adds nothing.
    """
    num_leading = dim if dim >= 0 else dims + dim
    for _ in range(num_leading):
        ptr = ptr.unsqueeze(0)
    return ptr