torch_geometric.nn.models.GraphUNet
- class GraphUNet(in_channels: int, hidden_channels: int, out_channels: int, depth: int, pool_ratios: Union[float, List[float]] = 0.5, sum_res: bool = True, act: Union[str, Callable] = 'relu')
Bases: torch.nn.Module
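The Graph U-Net model from the “Graph U-Nets” paper, which implements a U-Net-like architecture with graph pooling and unpooling operations. The encoder repeatedly coarsens the graph by keeping a top-scoring subset of nodes. The decoder then restores the pooled nodes to their original positions and merges them with skip connections from the matching encoder level.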
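To make one pooling/unpooling round trip concrete, here is a minimal pure-Python sketch. It assumes each node already has a scalar score; in the library this score is learned inside the pooling layer, while here it is supplied by hand. The helpers top_k_pool, unpool and merge_skip are hypothetical names for illustration. The tanh gate is a simplified stand-in, and the graph convolutions between levels are omitted, so this shows the idea rather than the library's implementation.

```python
import math
from typing import List, Tuple

Vector = List[float]


def top_k_pool(x: List[Vector], scores: List[float],
               ratio: float) -> Tuple[List[Vector], List[int]]:
    """Hypothetical helper: keep the ceil(ratio * N) highest-scoring nodes.

    Returns the kept (gated) features and `perm`, the original indices of
    the kept nodes, which the unpooling step needs later.
    """
    k = math.ceil(ratio * len(x))
    perm = sorted(range(len(x)), key=lambda i: scores[i], reverse=True)[:k]
    # Simplified gate: scale each kept feature vector by tanh(score).
    pooled = [[v * math.tanh(scores[i]) for v in x[i]] for i in perm]
    return pooled, perm


def unpool(pooled: List[Vector], perm: List[int],
           num_nodes: int, dim: int) -> List[Vector]:
    """Hypothetical helper: scatter pooled rows back; dropped nodes get zeros."""
    up = [[0.0] * dim for _ in range(num_nodes)]
    for row, idx in zip(pooled, perm):
        up[idx] = list(row)
    return up


def merge_skip(res: List[Vector], up: List[Vector],
               sum_res: bool) -> List[Vector]:
    """Combine skip features with unpooled features: sum, or concatenate."""
    if sum_res:
        return [[a + b for a, b in zip(r, u)] for r, u in zip(res, up)]
    return [r + u for r, u in zip(res, up)]  # concatenation doubles the width


# Four nodes with 2-dim features and hand-picked scores.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 2.0]]
scores = [0.1, 0.9, 0.5, 0.3]

pooled, perm = top_k_pool(x, scores, ratio=0.5)  # keeps the 2 best nodes
up = unpool(pooled, perm, num_nodes=len(x), dim=2)
summed = merge_skip(x, up, sum_res=True)
concatenated = merge_skip(x, up, sum_res=False)

print(perm)                                  # [1, 2]
print(len(summed[0]), len(concatenated[0]))  # 2 4
```

The feature-width difference in the last line mirrors the sum_res option. Summation keeps the skip-merged width unchanged, while concatenation doubles it.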
- Parameters:
in_channels (int) – Size of each input sample.
hidden_channels (int) – Size of each hidden sample.
out_channels (int) – Size of each output sample.
depth (int) – The depth of the U-Net architecture.
pool_ratios (float or [float], optional) – Graph pooling ratio for each depth. A single float is reused at every level; a list gives one ratio per level. (default: 0.5)
sum_res (bool, optional) – If set to False, will use concatenation instead of summation to integrate skip connections. (default: True)
act (str or Callable, optional) – The nonlinearity to use, given either by name or as a callable. (default: 'relu')