torch_geometric.nn.conv.FastHGTConv

class FastHGTConv(in_channels: Union[int, Dict[str, int]], out_channels: int, metadata: Tuple[List[str], List[Tuple[str, str, str]]], heads: int = 1, **kwargs)

Bases: MessagePassing

See HGTConv for a description of the underlying Heterogeneous Graph Transformer operator and its arguments.
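The metadata argument pairs a list of node types with a list of (source, relation, destination) edge types, and in_channels is either a single int or one input size per node type. The pure-Python sketch below illustrates that structure. build_metadata and per_type_in_channels are hypothetical helpers, not part of PyG, and the assumption that a single int is used as the input size of every node type is labeled as such in the code.

```python
from typing import Dict, List, Tuple, Union

NodeType = str
EdgeType = Tuple[str, str, str]  # (source node type, relation, destination node type)
Metadata = Tuple[List[NodeType], List[EdgeType]]


def build_metadata(node_types: List[NodeType], edge_types: List[EdgeType]) -> Metadata:
    """Hypothetical helper: assemble the (node_types, edge_types) pair used as `metadata`."""
    known = set(node_types)
    for src, rel, dst in edge_types:
        # Every edge type must connect two node types that appear in node_types.
        if src not in known or dst not in known:
            raise ValueError(f"edge type {(src, rel, dst)} references an unknown node type")
    return list(node_types), list(edge_types)


def per_type_in_channels(
    in_channels: Union[int, Dict[str, int]], metadata: Metadata
) -> Dict[str, int]:
    """Hypothetical helper: resolve one input feature size per node type."""
    node_types, _ = metadata
    if isinstance(in_channels, int):
        # Assumption: a single int applies to every node type.
        return {node_type: in_channels for node_type in node_types}
    missing = [t for t in node_types if t not in in_channels]
    if missing:
        raise ValueError(f"no input size given for node types {missing}")
    return dict(in_channels)


metadata = build_metadata(
    ["author", "paper"],
    [("author", "writes", "paper"), ("paper", "rev_writes", "author")],
)
print(per_type_in_channels(64, metadata))  # {'author': 64, 'paper': 64}
```

In practice, HeteroData.metadata() returns a tuple of this shape for a heterogeneous graph, so it can usually be passed straight to the constructor.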

forward(x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str, str], Union[Tensor, SparseTensor]]) → Dict[str, Optional[Tensor]]

Runs the forward pass of the module.

Parameters
  • x_dict (Dict[str, torch.Tensor]) – A dictionary holding input node features for each individual node type.

  • edge_index_dict (Dict[Tuple[str, str, str], Union[torch.Tensor, torch_sparse.SparseTensor]]) – A dictionary holding graph connectivity information for each individual edge type, either as a torch.Tensor of shape [2, num_edges] or a torch_sparse.SparseTensor.
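For the [2, num_edges] form, row 0 holds source node indices and row 1 holds destination node indices, each indexing into the node type named at that end of the edge type. The sketch below checks those constraints; check_edge_index_dict is a hypothetical validator, nested Python lists stand in for tensors, and SparseTensor inputs are out of scope.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]
# Stand-in for a [2, num_edges] tensor: row 0 = source indices, row 1 = destination indices.
EdgeIndex = List[List[int]]


def check_edge_index_dict(
    edge_index_dict: Dict[EdgeType, EdgeIndex],
    num_nodes: Dict[str, int],
) -> Dict[EdgeType, int]:
    """Hypothetical validator: return the number of edges per edge type."""
    num_edges = {}
    for (src, rel, dst), edge_index in edge_index_dict.items():
        if len(edge_index) != 2 or len(edge_index[0]) != len(edge_index[1]):
            raise ValueError(f"{(src, rel, dst)}: expected shape [2, num_edges]")
        # Source indices address nodes of type `src`; destination indices address type `dst`.
        if any(not 0 <= i < num_nodes[src] for i in edge_index[0]):
            raise ValueError(f"{(src, rel, dst)}: source index out of range")
        if any(not 0 <= j < num_nodes[dst] for j in edge_index[1]):
            raise ValueError(f"{(src, rel, dst)}: destination index out of range")
        num_edges[(src, rel, dst)] = len(edge_index[0])
    return num_edges


edge_index_dict = {
    ("author", "writes", "paper"): [[0, 0, 1], [0, 1, 1]],  # three author -> paper edges
}
print(check_edge_index_dict(edge_index_dict, {"author": 2, "paper": 2}))
# {('author', 'writes', 'paper'): 3}
```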

Return type

Dict[str, Optional[torch.Tensor]] – The output node embeddings for each node type. If a node type does not receive any message, its output is set to None.
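Assuming that a node type receives messages exactly when it is the destination of at least one edge type in edge_index_dict, the node types whose output will be None can be predicted from the edge types alone. The sketch below does that and drops None entries from an output dictionary before it is passed on, for example to a next layer. Both helpers are hypothetical, and strings stand in for output tensors.

```python
from typing import Dict, List, Optional, Set, Tuple

EdgeType = Tuple[str, str, str]


def unreached_node_types(node_types: List[str], edge_types: List[EdgeType]) -> Set[str]:
    """Hypothetical helper: node types never used as a destination, i.e. receiving no messages."""
    destinations = {dst for _, _, dst in edge_types}
    return {t for t in node_types if t not in destinations}


def drop_none(out_dict: Dict[str, Optional[object]]) -> Dict[str, object]:
    """Keep only node types that actually produced an embedding."""
    return {k: v for k, v in out_dict.items() if v is not None}


node_types = ["author", "paper", "venue"]
edge_types = [("author", "writes", "paper"), ("paper", "published_in", "venue")]
print(sorted(unreached_node_types(node_types, edge_types)))  # ['author']

# Strings stand in for the output tensors of a forward pass.
out_dict = {"author": None, "paper": "paper_emb", "venue": "venue_emb"}
print(drop_none(out_dict))  # {'paper': 'paper_emb', 'venue': 'venue_emb'}
```

If every node type should get an embedding, one option is to add reverse edge types (for example with the ToUndirected transform on a HeteroData object) so that each node type is the destination of at least one edge type.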

reset_parameters()

Resets all learnable parameters of the module.