torch_geometric.nn.conv.HEATConv
- class HEATConv(in_channels: int, out_channels: int, num_node_types: int, num_edge_types: int, edge_type_emb_dim: int, edge_dim: int, edge_attr_emb_dim: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, root_weight: bool = True, bias: bool = True, **kwargs)
Bases: MessagePassing
The heterogeneous edge-enhanced graph attentional operator from the “Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent Trajectory Prediction” paper. HEATConv enhances GATConv by:
  - type-specific transformations of nodes of different types
  - edge type and edge feature incorporation, in which edges are assumed to have different types but contain the same kind of attributes
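Type-specific transformation of node features can be illustrated in plain PyTorch. The sketch below is not the library's implementation; it only shows the idea of applying a separate linear map to the nodes of each type. The names `lins`, `x`, `node_type`, and `out` are hypothetical and chosen for illustration.

```python
import torch

torch.manual_seed(0)

num_node_types, in_channels, out_channels = 2, 3, 4

# One independent linear map per node type (illustrative, not HEATConv internals).
lins = torch.nn.ModuleList(
    [torch.nn.Linear(in_channels, out_channels) for _ in range(num_node_types)]
)

x = torch.randn(5, in_channels)             # node features, shape (|V|, F_in)
node_type = torch.tensor([0, 1, 0, 1, 1])   # one type index per node, shape (|V|)

# Transform each node with the linear map that belongs to its type.
out = x.new_empty(x.size(0), out_channels)
for t, lin in enumerate(lins):
    mask = node_type == t
    out[mask] = lin(x[mask])
```

Each row of `out` depends only on the weights associated with that node's type, so nodes of different types are projected into the shared output space by different parameters.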
- Parameters:
in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.
out_channels (int) – Size of each output sample.
num_node_types (int) – The number of node types.
num_edge_types (int) – The number of edge types.
edge_type_emb_dim (int) – The embedding size of edge types.
edge_dim (int) – Edge feature dimensionality.
edge_attr_emb_dim (int) – The embedding size of edge features.
heads (int, optional) – Number of multi-head-attentions. (default: 1)
concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)
negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)
dropout (float, optional) – Dropout probability of the normalized attention coefficients which exposes each node to a stochastically sampled neighborhood during training. (default: 0)
root_weight (bool, optional) – If set to False, the layer will not add transformed root node features to the output. (default: True)
bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
**kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
- Shapes:
input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), node types \((|\mathcal{V}|)\), edge types \((|\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)
output: node features \((|\mathcal{V}|, F_{out})\)