class GraphUNet(in_channels: int, hidden_channels: int, out_channels: int, depth: int, pool_ratios: Union[float, List[float]] = 0.5, sum_res: bool = True, act: Union[str, Callable] = 'relu')

Bases: Module

The Graph U-Net model from the “Graph U-Nets” paper, which implements a U-Net-like architecture with graph pooling and unpooling operations.
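The pooling and unpooling operations can be illustrated with a minimal pure-Python sketch of a gPool-style top-k layer and its unpooling counterpart on a small dense adjacency matrix. The helper names `gpool` and `gunpool`, the projection vector `p`, and the tanh gate are assumptions for illustration. This is not the library's implementation, which works on sparse `edge_index` tensors and may differ in details.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def gpool(x: Matrix, adj: Matrix, p: List[float], ratio: float
          ) -> Tuple[Matrix, Matrix, List[int]]:
    """Hypothetical gPool-style layer: keep the ceil(ratio * N) best-scoring nodes."""
    norm = math.sqrt(sum(v * v for v in p))
    # Score each node by projecting its feature row onto the (learnable) vector p.
    scores = [sum(a * b for a, b in zip(row, p)) / norm for row in x]
    k = math.ceil(ratio * len(x))
    # Indices of the k highest-scoring nodes; ties keep their original order.
    idx = sorted(range(len(x)), key=lambda i: scores[i], reverse=True)[:k]
    # Gate kept features by tanh(score) so that p would receive gradients in training.
    x_pool = [[v * math.tanh(scores[i]) for v in x[i]] for i in idx]
    # Restrict the adjacency matrix to the selected nodes (induced subgraph).
    adj_pool = [[adj[i][j] for j in idx] for i in idx]
    return x_pool, adj_pool, idx


def gunpool(x_pool: Matrix, idx: List[int], num_nodes: int) -> Matrix:
    """Hypothetical unpooling: put rows back at their original node positions, zeros elsewhere."""
    out = [[0.0] * len(x_pool[0]) for _ in range(num_nodes)]
    for row, i in zip(x_pool, idx):
        out[i] = list(row)
    return out


# Tiny example: 4 nodes with 2 features each, keep half of them.
x = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, -1.0]]
adj = [[0, 1, 1, 0], [1, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]
x_pool, adj_pool, idx = gpool(x, adj, p=[1.0, 0.0], ratio=0.5)
x_up = gunpool(x_pool, idx, num_nodes=len(x))
print(idx)       # [2, 0]
print(adj_pool)  # [[0, 1], [1, 0]]
```

Unpooling restores the original node count, which is what lets the decoder path add or concatenate the encoder's features at each level.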

Parameters:

  • in_channels (int) – Size of each input sample.

  • hidden_channels (int) – Size of each hidden sample.

  • out_channels (int) – Size of each output sample.

  • depth (int) – The depth of the U-Net architecture, i.e. the number of graph pooling (and matching unpooling) operations.

  • pool_ratios (float or [float], optional) – Graph pooling ratio for each depth. (default: 0.5)

  • sum_res (bool, optional) – If set to False, skip connections are integrated by concatenation instead of summation. (default: True)

  • act (str or Callable, optional) – The non-linearity to use. (default: "relu")
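How depth, pool_ratios, and sum_res interact can be sketched in plain Python. The helpers below (`expand_ratios`, `nodes_per_level`, `merge_skip`) are hypothetical. They assume that each pooling step keeps ceil(ratio · n) nodes and that a list of ratios has exactly one entry per depth level (the library may handle mismatched lengths differently). The sketch shows how many nodes survive at each level and how a skip connection is merged: summation keeps the feature width, while concatenation (sum_res=False) doubles it, so the layer that consumes the merged features must accept twice as many input channels.

```python
import math
from typing import List, Sequence, Union


def expand_ratios(pool_ratios: Union[float, Sequence[float]], depth: int) -> List[float]:
    """Hypothetical helper: one pooling ratio per depth level."""
    if isinstance(pool_ratios, (int, float)):
        return [float(pool_ratios)] * depth  # a single float is reused at every level
    ratios = [float(r) for r in pool_ratios]
    if len(ratios) != depth:
        raise ValueError(f"expected {depth} pool ratios, got {len(ratios)}")
    return ratios


def nodes_per_level(num_nodes: int, pool_ratios, depth: int) -> List[int]:
    """Node count per level, assuming each pooling step keeps ceil(ratio * n) nodes."""
    counts = [num_nodes]
    for ratio in expand_ratios(pool_ratios, depth):
        counts.append(math.ceil(ratio * counts[-1]))
    return counts


def merge_skip(res: List[float], up: List[float], sum_res: bool) -> List[float]:
    """Integrate a skip connection for one node: sum (same width) or concat (double width)."""
    if sum_res:
        return [a + b for a, b in zip(res, up)]
    return res + up  # list concatenation, i.e. feature-wise concat


print(nodes_per_level(100, 0.5, depth=3))          # [100, 50, 25, 13]
print(nodes_per_level(100, [0.75, 0.5], depth=2))  # [100, 75, 38]
print(merge_skip([1.0, 2.0], [3.0, 4.0], sum_res=True))   # [4.0, 6.0]
print(merge_skip([1.0, 2.0], [3.0, 4.0], sum_res=False))  # [1.0, 2.0, 3.0, 4.0]
```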

forward(x: Tensor, edge_index: Tensor, batch: Optional[Tensor] = None) → Tensor
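The optional batch argument follows the usual PyTorch Geometric convention: a vector assigning each node to the graph it belongs to, so several graphs can be processed in one call. Below is a pure-Python sketch of what this means for pooling. It assumes top-k selection happens independently inside each graph and keeps ceil(ratio · n_g) nodes of a graph with n_g nodes. `topk_per_graph` is a hypothetical helper, and treating a missing batch as a single graph is an assumption of the sketch.

```python
import math
from collections import defaultdict
from typing import Dict, List, Optional


def topk_per_graph(scores: List[float], ratio: float,
                   batch: Optional[List[int]] = None) -> List[int]:
    """Hypothetical helper: keep the ceil(ratio * n_g) best-scoring nodes of every graph."""
    if batch is None:
        batch = [0] * len(scores)  # assumption: no batch vector means one single graph
    # Group node indices by the graph they belong to.
    groups: Dict[int, List[int]] = defaultdict(list)
    for node, graph in enumerate(batch):
        groups[graph].append(node)
    keep: List[int] = []
    for graph in sorted(groups):
        nodes = groups[graph]
        k = math.ceil(ratio * len(nodes))
        # Rank only within this graph, so a large graph cannot crowd out a small one.
        ranked = sorted(nodes, key=lambda i: scores[i], reverse=True)
        keep.extend(ranked[:k])
    return keep


scores = [0.1, 0.9, 0.5, 0.3, 0.8]
print(topk_per_graph(scores, 0.5, batch=[0, 0, 0, 1, 1]))  # [1, 2, 4]
print(topk_per_graph(scores, 0.5))                         # [1, 4, 2]
```

Ranking within each graph means every non-empty graph in the batch keeps at least one node whenever ratio > 0, since ceil(ratio · n_g) ≥ 1.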

reset_parameters()

Resets all learnable parameters of the module.