torch_geometric.nn.conv.HANConv

class HANConv(in_channels: Union[int, Dict[str, int]], out_channels: int, metadata: Tuple[List[str], List[Tuple[str, str, str]]], heads: int = 1, negative_slope=0.2, dropout: float = 0.0, **kwargs)

Bases: MessagePassing

The Heterogeneous Graph Attention Operator from the “Heterogeneous Graph Attention Network” paper.
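HAN, as described in the paper, combines node-level attention (weighting neighbors reached through one edge type or meta-path) with semantic-level attention (weighting the edge types themselves). The plain-Python sketch below illustrates only the semantic-level step, under simplifying assumptions: the per-edge-type embeddings, the projection W, the bias b and the query vector q are hypothetical toy values rather than learned parameters, and the code is not PyG's tensorized implementation. In HANConv, the inputs to this step would roughly correspond to the per-edge-type outputs of the node-level attention.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def affine(w: Matrix, x: Vector, b: Vector) -> Vector:
    """Compute W x + b (stands in for a learnable linear layer)."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i
            for row, b_i in zip(w, b)]


def semantic_attention(zs: List[Matrix], w: Matrix, b: Vector,
                       q: Vector) -> Tuple[Matrix, Vector]:
    """Fuse per-edge-type node embeddings with semantic-level attention.

    zs[p][i] is the embedding of node i computed from edge type (meta-path) p.
    Returns the fused embeddings and one weight per edge type.
    """
    # Importance score of each edge type: mean over nodes of q . tanh(W z_i + b).
    scores = []
    for z in zs:
        per_node = [sum(q_k * math.tanh(h_k) for q_k, h_k in zip(q, affine(w, z_i, b)))
                    for z_i in z]
        scores.append(sum(per_node) / len(per_node))

    # Softmax over edge types; subtracting the max keeps exp() from overflowing.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    beta = [e / total for e in exps]

    # Each node's output is the beta-weighted sum of its per-edge-type embeddings.
    num_nodes, dim = len(zs[0]), len(zs[0][0])
    fused = [[sum(beta[p] * zs[p][i][k] for p in range(len(zs))) for k in range(dim)]
             for i in range(num_nodes)]
    return fused, beta


# Two edge types, two nodes, 2-dimensional embeddings (hypothetical toy numbers).
zs = [
    [[1.0, 0.0], [0.0, 1.0]],  # embeddings obtained through edge type 0
    [[2.0, 2.0], [2.0, 2.0]],  # embeddings obtained through edge type 1
]
identity, zero_bias, q = [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0], [1.0, 1.0]
fused, beta = semantic_attention(zs, identity, zero_bias, q)
print(beta)  # edge type 1 scores higher here, so it receives the larger weight
```

The weights are shared by all nodes of a type: the score averages over nodes, so each edge type receives one importance value rather than one per node.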

Note

For an example of using HANConv, see examples/hetero/han_imdb.py.

Parameters:
  • in_channels (int or Dict[str, int]) – Size of each input sample. An int applies the same size to every node type, while a dictionary gives the size per node type. Use -1 to derive the size(s) from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • metadata (Tuple[List[str], List[Tuple[str, str, str]]]) – The metadata of the heterogeneous graph, i.e. its node and edge types given by a list of strings and a list of string triplets, respectively. See torch_geometric.data.HeteroData.metadata() for more information.

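  • heads (int, optional) – Number of attention heads. The out_channels features are split evenly across the heads, so out_channels should be divisible by heads. (default: 1)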

  • negative_slope (float, optional) – Negative slope of the LeakyReLU applied to the node-level attention scores. (default: 0.2)

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients, which exposes each node to a stochastically sampled neighborhood during training. (default: 0.0)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

forward(x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str, str], Union[Tensor, SparseTensor]], return_semantic_attention_weights: bool = False) → Union[Dict[str, Optional[Tensor]], Tuple[Dict[str, Optional[Tensor]], Dict[str, Optional[Tensor]]]]

Runs the forward pass of the module.

Parameters:
  • x_dict (Dict[str, torch.Tensor]) – A dictionary holding node feature information for each individual node type.

  • edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor or torch_sparse.SparseTensor]) – A dictionary holding graph connectivity information for each individual edge type, either as a torch.Tensor of shape [2, num_edges] or a torch_sparse.SparseTensor.

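  • return_semantic_attention_weights (bool, optional) – If set to True, will additionally return the semantic-level attention weights for each destination node type. (default: False)

Returns:
  A dictionary mapping each node type in x_dict to its output features of shape [num_nodes, out_channels], or to None if no edge type in edge_index_dict has that node type as its destination. If return_semantic_attention_weights is True, a tuple of this dictionary and a dictionary of semantic-level attention weights is returned instead; for each destination node type, the weights form a 1-D tensor with one softmax-normalized entry per incoming edge type.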

reset_parameters()

Resets all learnable parameters of the module.