torch_geometric.nn.conv.HEATConv

class HEATConv(in_channels: int, out_channels: int, num_node_types: int, num_edge_types: int, edge_type_emb_dim: int, edge_dim: int, edge_attr_emb_dim: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, root_weight: bool = True, bias: bool = True, **kwargs)

Bases: MessagePassing

The heterogeneous edge-enhanced graph attentional operator from the “Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent Trajectory Prediction” paper.

HEATConv enhances GATConv by:

  1. type-specific transformations of nodes of different types

  2. edge type and edge feature incorporation, in which edges are assumed to have different types but contain the same kind of attributes

Parameters:
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • num_node_types (int) – The number of node types.

  • num_edge_types (int) – The number of edge types.

  • edge_type_emb_dim (int) – The embedding size of edge types.

  • edge_dim (int) – Edge feature dimensionality.

  • edge_attr_emb_dim (int) – The embedding size of edge features.

  • heads (int, optional) – Number of attention heads. (default: 1)

  • concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

  • negative_slope (float, optional) – Negative slope of the LeakyReLU activation. (default: 0.2)

  • dropout (float, optional) – Dropout probability applied to the normalized attention coefficients, which exposes each node to a stochastically sampled neighborhood during training. (default: 0.0)

  • root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), node types \((|\mathcal{V}|)\), edge types \((|\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, H \cdot F_{out})\) if concat=True, otherwise \((|\mathcal{V}|, F_{out})\), where \(H\) is the number of heads
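
The effect of heads and concat on the output width can be sketched with a small helper. This is only an illustration: heat_output_shape is a hypothetical name, not part of torch_geometric, and it assumes that concatenating the heads multiplies the feature dimension by heads while averaging leaves it at out_channels.

```python
from typing import Tuple


def heat_output_shape(num_nodes: int, out_channels: int,
                      heads: int = 1, concat: bool = True) -> Tuple[int, int]:
    """Hypothetical helper: expected shape of the node features returned by the layer."""
    # Concatenation stacks the per-head outputs side by side;
    # averaging keeps a single out_channels-sized vector per node.
    feature_dim = heads * out_channels if concat else out_channels
    return (num_nodes, feature_dim)


print(heat_output_shape(4, 16, heads=2, concat=True))   # (4, 32)
print(heat_output_shape(4, 16, heads=2, concat=False))  # (4, 16)
```

This matters when stacking layers: the in_channels of the next layer has to match the width computed here.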

forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], node_type: Tensor, edge_type: Tensor, edge_attr: Optional[Tensor] = None) → Tensor

Runs the forward pass of the module.

reset_parameters()

Resets all learnable parameters of the module.