from typing import Callable, List, Union
import torch
from torch import Tensor
from torch_geometric.nn import GCNConv, TopKPooling
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.typing import OptTensor, PairTensor
from torch_geometric.utils import (
add_self_loops,
remove_self_loops,
to_torch_csr_tensor,
)
from torch_geometric.utils.repeat import repeat
class GraphUNet(torch.nn.Module):
    r"""The Graph U-Net model from the `"Graph U-Nets"
    <https://arxiv.org/abs/1905.05178>`_ paper which implements a U-Net like
    architecture with graph pooling and unpooling operations.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        out_channels (int): Size of each output sample.
        depth (int): The depth of the U-Net architecture. Must be at least
            :obj:`1`.
        pool_ratios (float or [float], optional): Graph pooling ratio for each
            depth. The value is passed through
            :obj:`repeat(pool_ratios, depth)` so that one ratio is available
            per pooling layer. (default: :obj:`0.5`)
        sum_res (bool, optional): If set to :obj:`False`, will use
            concatenation for integration of skip connections instead of
            summation. (default: :obj:`True`)
        act (str or Callable, optional): The nonlinearity to use, given
            either as a name or as a callable; it is resolved via
            :obj:`activation_resolver`. (default: :obj:`"relu"`)
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        depth: int,
        pool_ratios: Union[float, List[float]] = 0.5,
        sum_res: bool = True,
        act: Union[str, Callable] = 'relu',
    ):
        super().__init__()
        assert depth >= 1, "depth must be at least 1"

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = activation_resolver(act)
        self.sum_res = sum_res

        channels = hidden_channels

        # Encoder: one input convolution followed by `depth` stages, each
        # made of a top-k pooling layer and a convolution.
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(in_channels, channels, improved=True))
        for level in range(depth):
            self.pools.append(TopKPooling(channels, self.pool_ratios[level]))
            self.down_convs.append(GCNConv(channels, channels, improved=True))

        # Decoder: when skip connections are concatenated instead of summed,
        # every up-convolution receives twice as many input channels.
        up_in_channels = channels if sum_res else 2 * channels
        self.up_convs = torch.nn.ModuleList()
        for _ in range(depth - 1):
            self.up_convs.append(
                GCNConv(up_in_channels, channels, improved=True))
        # The last up-convolution maps to the requested output size.
        self.up_convs.append(
            GCNConv(up_in_channels, out_channels, improved=True))

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        for conv in self.down_convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        for conv in self.up_convs:
            conv.reset_parameters()

    def forward(self, x: Tensor, edge_index: Tensor,
                batch: OptTensor = None) -> Tensor:
        r"""Runs the encoder (pooling) path followed by the decoder
        (unpooling) path.

        Args:
            x (torch.Tensor): The node features; :obj:`x.size(0)` is used as
                the number of nodes. (presumably of shape
                :obj:`[num_nodes, in_channels]` -- confirm against callers)
            edge_index (torch.Tensor): The edge indices; one unit edge weight
                is created per entry along dimension :obj:`1`. (presumably of
                shape :obj:`[2, num_edges]` -- confirm against callers)
            batch (torch.Tensor, optional): The batch vector handed to the
                pooling layers. If :obj:`None`, a zero vector with one entry
                per node is used. (default: :obj:`None`)

        Returns:
            torch.Tensor: The output of the last up-convolution, to which no
            activation is applied.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        edge_weight = x.new_ones(edge_index.size(1))

        x = self.down_convs[0](x, edge_index, edge_weight)
        x = self.act(x)

        # State of every resolution level that is needed again on the way up.
        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []

        for i in range(1, self.depth + 1):
            edge_index, edge_weight = self.augment_adj(
                edge_index, edge_weight, x.size(0))
            x, edge_index, edge_weight, batch, perm, _ = self.pools[i - 1](
                x, edge_index, edge_weight, batch)

            x = self.down_convs[i](x, edge_index, edge_weight)
            x = self.act(x)

            # The deepest level is not stored: it has no skip connection and
            # is consumed directly by the first unpooling step.
            if i < self.depth:
                xs.append(x)
                edge_indices.append(edge_index)
                edge_weights.append(edge_weight)
            perms.append(perm)

        for i in range(self.depth):
            j = self.depth - 1 - i  # Walk the stored levels deepest-first.

            res = xs[j]
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            perm = perms[j]

            # Unpool: write the pooled features back to the rows selected by
            # `perm`; all other rows stay zero.
            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)

            x = self.up_convs[i](x, edge_index, edge_weight)
            # No activation after the final (output) convolution.
            if i < self.depth - 1:
                x = self.act(x)

        return x

    def augment_adj(self, edge_index: Tensor, edge_weight: Tensor,
                    num_nodes: int) -> PairTensor:
        r"""Replaces the weighted adjacency matrix by its square, computed
        with self-loops, and returns the result without self-loops.

        Args:
            edge_index (torch.Tensor): The edge indices.
            edge_weight (torch.Tensor): The edge weights.
            num_nodes (int): The number of nodes; the adjacency matrix is
                built with size :obj:`(num_nodes, num_nodes)`.

        Returns:
            (torch.Tensor, torch.Tensor): The indices and values of the
            squared adjacency matrix after removing self-loops.
        """
        # Replace any existing self-loops by freshly added ones before
        # squaring.
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        adj = to_torch_csr_tensor(edge_index, edge_weight,
                                  size=(num_nodes, num_nodes))
        # Sparse matrix product; converted to COO to read indices and values.
        adj = (adj @ adj).to_sparse_coo()
        edge_index, edge_weight = adj.indices(), adj.values()
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.hidden_channels}, {self.out_channels}, '
                f'depth={self.depth}, pool_ratios={self.pool_ratios})')