# Source code for torch_geometric.nn.models.graph_unet

from typing import Callable, List, Union

import torch
from torch import Tensor

from torch_geometric.nn import GCNConv, TopKPooling
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.typing import OptTensor, PairTensor
from torch_geometric.utils import (
    add_self_loops,
    remove_self_loops,
    to_torch_csr_tensor,
)
from torch_geometric.utils.repeat import repeat


class GraphUNet(torch.nn.Module):
    r"""The Graph U-Net model from the `"Graph U-Nets"
    <https://arxiv.org/abs/1905.05178>`_ paper which implements a U-Net like
    architecture with graph pooling and unpooling operations.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        out_channels (int): Size of each output sample.
        depth (int): The depth of the U-Net architecture. Must be at least
            :obj:`1`.
        pool_ratios (float or [float], optional): Graph pooling ratio for each
            depth. (default: :obj:`0.5`)
        sum_res (bool, optional): If set to :obj:`False`, will use
            concatenation for integration of skip connections instead of
            summation. (default: :obj:`True`)
        act (str or Callable, optional): The nonlinearity to use. The value
            is passed through :func:`activation_resolver`.
            (default: :obj:`"relu"`)
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        depth: int,
        pool_ratios: Union[float, List[float]] = 0.5,
        sum_res: bool = True,
        act: Union[str, Callable] = 'relu',
    ):
        super().__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = activation_resolver(act)
        self.sum_res = sum_res

        channels = hidden_channels

        # Encoder: one input convolution, then one (pool, conv) pair per level.
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(in_channels, channels, improved=True))
        for i in range(depth):
            self.pools.append(TopKPooling(channels, self.pool_ratios[i]))
            self.down_convs.append(GCNConv(channels, channels, improved=True))

        # Decoder: concatenated skip connections double the input width of
        # every up-convolution, summed ones leave it unchanged.
        up_in_channels = channels if sum_res else 2 * channels

        self.up_convs = torch.nn.ModuleList()
        for _ in range(depth - 1):
            self.up_convs.append(
                GCNConv(up_in_channels, channels, improved=True))
        # Only the last up-convolution maps to `out_channels`.
        self.up_convs.append(
            GCNConv(up_in_channels, out_channels, improved=True))

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        for conv in self.down_convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        for conv in self.up_convs:
            conv.reset_parameters()

    def forward(self, x: Tensor, edge_index: Tensor,
                batch: OptTensor = None) -> Tensor:
        r"""Runs the encoder (pooling) path followed by the decoder
        (unpooling) path.

        Args:
            x (torch.Tensor): The node features. The size of its first
                dimension is used as the number of nodes.
            edge_index (torch.Tensor): The edge indices. The size of its
                second dimension is used as the number of edges.
            batch (torch.Tensor, optional): The batch vector handed to the
                pooling layers. If :obj:`None`, a zero vector with one entry
                per node is used. (default: :obj:`None`)

        Returns:
            torch.Tensor: The result of the last up-convolution. No
            activation is applied to it.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        # Every input edge starts with weight one.
        edge_weight = x.new_ones(edge_index.size(1))

        x = self.down_convs[0](x, edge_index, edge_weight)
        x = self.act(x)

        # State recorded per level on the way down and replayed on the way up.
        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []

        for i in range(1, self.depth + 1):
            edge_index, edge_weight = self.augment_adj(edge_index, edge_weight,
                                                       x.size(0))
            x, edge_index, edge_weight, batch, perm, _ = self.pools[i - 1](
                x, edge_index, edge_weight, batch)

            x = self.down_convs[i](x, edge_index, edge_weight)
            x = self.act(x)

            # The deepest level has no skip connection, so its features and
            # graph are not stored; its `perm` is still needed for unpooling.
            if i < self.depth:
                xs += [x]
                edge_indices += [edge_index]
                edge_weights += [edge_weight]
            perms += [perm]

        for i in range(self.depth):
            j = self.depth - 1 - i

            res = xs[j]
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            perm = perms[j]

            # Unpool: write the coarse features back to the rows selected by
            # the pooling layer; all other rows stay zero.
            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)

            x = self.up_convs[i](x, edge_index, edge_weight)
            x = self.act(x) if i < self.depth - 1 else x

        return x

    def augment_adj(self, edge_index: Tensor, edge_weight: Tensor,
                    num_nodes: int) -> PairTensor:
        r"""Returns the edges of the squared, self-loop-augmented adjacency
        matrix.

        Existing self-loops are replaced by freshly added ones, the resulting
        sparse matrix is multiplied with itself, and self-loops are removed
        from the product. Nodes that were at most two hops apart are
        therefore directly connected in the returned graph.

        Args:
            edge_index (torch.Tensor): The edge indices.
            edge_weight (torch.Tensor): The edge weights.
            num_nodes (int): The number of nodes; the adjacency matrix is
                built with size :obj:`(num_nodes, num_nodes)`.

        Returns:
            (torch.Tensor, torch.Tensor): The edge indices and edge weights
            of the augmented graph.
        """
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        adj = to_torch_csr_tensor(edge_index, edge_weight,
                                  size=(num_nodes, num_nodes))
        adj = (adj @ adj).to_sparse_coo()
        edge_index, edge_weight = adj.indices(), adj.values()
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.hidden_channels}, {self.out_channels}, '
                f'depth={self.depth}, pool_ratios={self.pool_ratios})')