torch_geometric.nn.aggr.DegreeScalerAggregation

class DegreeScalerAggregation(aggr: Union[str, List[str], Aggregation], scaler: Union[str, List[str]], deg: Tensor, train_norm: bool = False, aggr_kwargs: Optional[List[Dict[str, Any]]] = None)

Bases: Aggregation

Combines one or more aggregators and transforms their output with one or more scalers, as introduced in the “Principal Neighbourhood Aggregation for Graph Nets” paper. The scalers are normalised by in-degree statistics of the training set, so the degree histogram deg must be provided at construction time. See torch_geometric.nn.conv.PNAConv for more information.
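
For intuition, here is a plain-Python sketch of the five scalers applied to the already-aggregated features of a node with degree d. The amplification and attenuation formulas follow the PNA paper's log-degree scaler S(d, α) = (log(d + 1) / δ)^α with α = 1 and α = -1, where δ is the training-set average of log(d + 1). The two linear variants use the plain average degree instead; this is our reading of the library, not something this page states. scale and apply_scalers are hypothetical helpers, not PyG API, and clamping d to at least 1 is an assumption made here to avoid division by zero.

```python
import math
from typing import List


# Hypothetical helper (not PyG API). avg_deg_lin / avg_deg_log stand in for
# the training-set statistics that the real module derives from `deg`.
def scale(value: float, degree: int, scaler: str,
          avg_deg_lin: float, avg_deg_log: float) -> float:
    if scaler == "identity":
        return value
    if scaler == "amplification":
        # PNA scaler with alpha = +1: log(d + 1) / delta
        return value * math.log(degree + 1) / avg_deg_log
    if scaler == "attenuation":
        # PNA scaler with alpha = -1; degree clamped to >= 1 (assumption)
        # so that we never divide by log(1) = 0.
        return value * avg_deg_log / math.log(max(degree, 1) + 1)
    if scaler == "linear":
        return value * degree / avg_deg_lin
    if scaler == "inverse_linear":
        # Degree clamped to >= 1 (assumption) to avoid dividing by zero.
        return value * avg_deg_lin / max(degree, 1)
    raise ValueError(f"Unknown scaler '{scaler}'")


def apply_scalers(aggregated: List[float], degree: int, scalers: List[str],
                  avg_deg_lin: float, avg_deg_log: float) -> List[float]:
    # One scaled copy of the aggregated features per scaler, concatenated.
    out: List[float] = []
    for s in scalers:
        out.extend(scale(v, degree, s, avg_deg_lin, avg_deg_log)
                   for v in aggregated)
    return out


# A degree-4 node with aggregated features [2.0] and an average degree of 2.0:
print(apply_scalers([2.0], 4, ["identity", "linear", "inverse_linear"],
                    2.0, math.log(3)))  # [2.0, 4.0, 1.0]
```

Because every scaler contributes its own copy of the aggregated features, the output width in this sketch grows with the number of scalers requested.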

Parameters:
  • aggr (str or [str] or Aggregation) – The aggregation scheme to use. A list combines several aggregators. See MessagePassing for more information.

  • scaler (str or [str]) – Set of scaling function identifiers, namely one or more of "identity", "amplification", "attenuation", "linear" and "inverse_linear".

  • deg (Tensor) – Histogram of in-degrees of nodes in the training set, where entry i counts the nodes with in-degree i; used by the scalers for normalization.

  • train_norm (bool, optional) – Whether the degree normalization parameters, which are initialized from deg, are trainable. (default: False)

  • aggr_kwargs (Dict[str, Any] or [Dict[str, Any]], optional) – Arguments passed to the respective aggregation function(s) in case they get automatically resolved. When aggr is a list, one dictionary is given per aggregator. (default: None)
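
Note that deg is a histogram rather than a per-node degree vector. The stdlib-only sketch below builds one from the target node of each training edge and reduces it to two averages: the mean degree and the mean of log(d + 1). That these two averages are what the scalers normalise by is our understanding of the PNA formulation, not something this page states. in_degrees, degree_histogram and normalization_constants are hypothetical helpers.

```python
import math
from collections import Counter
from typing import List, Tuple


def in_degrees(dst: List[int], num_nodes: int) -> List[int]:
    # In-degree of every node, given the target node of each edge.
    # Nodes that receive no edges keep degree 0 and still count in bin 0.
    deg = [0] * num_nodes
    for v in dst:
        deg[v] += 1
    return deg


def degree_histogram(degrees: List[int]) -> List[int]:
    # hist[d] = number of nodes whose in-degree equals d.
    counts = Counter(degrees)
    return [counts[d] for d in range(max(degrees, default=0) + 1)]


def normalization_constants(hist: List[int]) -> Tuple[float, float]:
    # Histogram-weighted averages of d and of log(d + 1) (hypothetical helper).
    n = sum(hist)
    avg_deg_lin = sum(d * c for d, c in enumerate(hist)) / n
    avg_deg_log = sum(math.log(d + 1) * c for d, c in enumerate(hist)) / n
    return avg_deg_lin, avg_deg_log


# Five nodes, edges pointing into nodes 1, 2, 2, 3, 3, 3:
degs = in_degrees([1, 2, 2, 3, 3, 3], num_nodes=5)  # [0, 1, 2, 3, 0]
hist = degree_histogram(degs)                       # [2, 1, 1, 1]
print(normalization_constants(hist))
```

With several training graphs, one would sum their histograms bin by bin after padding them to the largest degree. As we recall, PyG's PNA example does this with torch.bincount(..., minlength=...) over per-graph in-degrees.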

reset_parameters()

Resets all learnable parameters of the module.

forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) → Tensor

Forward pass.

Parameters:
  • x (torch.Tensor) – The source tensor.

  • index (torch.Tensor, optional) – The indices of elements for applying the aggregation. One of index or ptr must be defined. (default: None)

  • ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of index or ptr must be defined. (default: None)

  • dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation. (default: None)

  • dim (int, optional) – The dimension in which to aggregate. (default: -2)

  • max_num_elements (int, optional) – The maximum number of elements within a single aggregation group. (default: None)
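
To make the roles of index, ptr and dim_size concrete, here is a stdlib-only sketch of a forward pass with a single "mean" aggregator and the "identity" and "linear" scalers. Each group's degree is the number of elements that index assigns to it, and ptr is expanded into the equivalent index. forward_sketch, mean_with_degree and ptr_to_index are hypothetical helpers. The sketch also makes these simplifying assumptions: aggregation runs over the rows (the dim=-2 axis) of a list-of-lists x, empty groups yield zeros, and avg_deg_lin stands in for the value derived from deg. PyG's own implementation may require index rather than ptr for its degree computation.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def ptr_to_index(ptr: List[int]) -> List[int]:
    # Sorted CSR input: group g owns rows ptr[g] .. ptr[g + 1] - 1.
    index: List[int] = []
    for group in range(len(ptr) - 1):
        index.extend([group] * (ptr[group + 1] - ptr[group]))
    return index


def mean_with_degree(x: Matrix, index: List[int],
                     dim_size: Optional[int] = None) -> Tuple[Matrix, List[int]]:
    # Rows of x run along dim=-2 (one row per element); columns are features.
    if dim_size is None:
        dim_size = max(index, default=-1) + 1
    num_feats = len(x[0]) if x else 0
    sums = [[0.0] * num_feats for _ in range(dim_size)]
    degree = [0] * dim_size
    for row, group in zip(x, index):
        degree[group] += 1
        for f, v in enumerate(row):
            sums[group][f] += v
    # Empty groups (degree 0) produce all-zero rows.
    means = [[s / max(d, 1) for s in group_sum]
             for group_sum, d in zip(sums, degree)]
    return means, degree


def forward_sketch(x: Matrix, index: Optional[List[int]] = None,
                   ptr: Optional[List[int]] = None,
                   dim_size: Optional[int] = None,
                   avg_deg_lin: float = 1.0) -> Matrix:
    if index is None:
        if ptr is None:
            raise ValueError("One of 'index' or 'ptr' must be defined")
        index = ptr_to_index(ptr)
        if dim_size is None:
            dim_size = len(ptr) - 1
    means, degree = mean_with_degree(x, index, dim_size)
    # 'identity' and 'linear' scalers, concatenated along the feature axis.
    return [row + [v * d / avg_deg_lin for v in row]
            for row, d in zip(means, degree)]


x = [[1.0], [3.0], [5.0]]
print(forward_sketch(x, index=[0, 0, 1], avg_deg_lin=2.0))  # [[2.0, 2.0], [5.0, 2.5]]
print(forward_sketch(x, ptr=[0, 2, 3], avg_deg_lin=2.0))    # same result via CSR
```

With two scalers, the output has twice as many feature columns as the input, since each scaler contributes one scaled copy of the aggregated features. Passing dim_size=3 in the index call would append an all-zero row for the empty third group.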