Source code for torch_geometric.nn.models.graph_unet

import torch
import torch.nn.functional as F
from torch_sparse import spspmm
from torch_geometric.nn import TopKPooling, GCNConv
from torch_geometric.utils import (add_self_loops, sort_edge_index,
                                   remove_self_loops)
from torch_geometric.utils.repeat import repeat


class GraphUNet(torch.nn.Module):
    r"""The Graph U-Net model from the `"Graph U-Nets"
    <https://arxiv.org/abs/1905.05178>`_ paper, which implements a U-Net-like
    architecture with graph pooling and unpooling operations.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        out_channels (int): Size of each output sample.
        depth (int): The depth of the U-Net architecture.
        pool_ratios (float or [float], optional): Graph pooling ratio for each
            depth. (default: :obj:`0.5`)
        sum_res (bool, optional): If set to :obj:`False`, will use
            concatenation for integration of skip connections instead of
            summation. (default: :obj:`True`)
        act (torch.nn.functional, optional): The nonlinearity to use.
            (default: :obj:`torch.nn.functional.relu`)
    """
    def __init__(self, in_channels, hidden_channels, out_channels, depth,
                 pool_ratios=0.5, sum_res=True, act=F.relu):
        super(GraphUNet, self).__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = act
        self.sum_res = sum_res

        channels = hidden_channels

        # Encoder: an initial GCN layer followed by (pooling, GCN) blocks,
        # one block per depth level.
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(in_channels, channels, improved=True))
        for i in range(depth):
            self.pools.append(TopKPooling(channels, self.pool_ratios[i]))
            self.down_convs.append(GCNConv(channels, channels, improved=True))

        # Concatenated skip connections double the decoder input size.
        in_channels = channels if sum_res else 2 * channels

        # Decoder: GCN layers applied to the unpooled graphs.
        self.up_convs = torch.nn.ModuleList()
        for i in range(depth - 1):
            self.up_convs.append(GCNConv(in_channels, channels, improved=True))
        self.up_convs.append(GCNConv(in_channels, out_channels, improved=True))

        self.reset_parameters()
    def reset_parameters(self):
        for conv in self.down_convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        for conv in self.up_convs:
            conv.reset_parameters()
    def forward(self, x, edge_index, batch=None):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        edge_weight = x.new_ones(edge_index.size(1))

        x = self.down_convs[0](x, edge_index, edge_weight)
        x = self.act(x)

        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []

        # Down (encoder) pass: augment the adjacency, pool the graph, and
        # apply a GCN layer at each depth, caching intermediate graphs and
        # features for the skip connections.
        for i in range(1, self.depth + 1):
            edge_index, edge_weight = self.augment_adj(edge_index, edge_weight,
                                                       x.size(0))
            x, edge_index, edge_weight, batch, perm, _ = self.pools[i - 1](
                x, edge_index, edge_weight, batch)

            x = self.down_convs[i](x, edge_index, edge_weight)
            x = self.act(x)

            if i < self.depth:
                xs += [x]
                edge_indices += [edge_index]
                edge_weights += [edge_weight]
            perms += [perm]

        # Up (decoder) pass: unpool by scattering the coarse node features
        # back to their original positions, merge with the skip connection,
        # and apply a GCN layer.
        for i in range(self.depth):
            j = self.depth - 1 - i

            res = xs[j]
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            perm = perms[j]

            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)

            x = self.up_convs[i](x, edge_index, edge_weight)
            x = self.act(x) if i < self.depth - 1 else x

        return x
    def augment_adj(self, edge_index, edge_weight, num_nodes):
        # Augment the graph connectivity with two-hop neighbors by squaring
        # the self-loop-augmented adjacency matrix via sparse-sparse matrix
        # multiplication, then dropping the resulting self-loops.
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
                                                  num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes,
                                         num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight
    def __repr__(self):
        return '{}({}, {}, {}, depth={}, pool_ratios={})'.format(
            self.__class__.__name__, self.in_channels, self.hidden_channels,
            self.out_channels, self.depth, self.pool_ratios)
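
A minimal usage sketch (not part of the library source): it builds a toy random graph and runs a two-level GraphUNet over it to produce per-node outputs. The node count, feature sizes, and hyperparameters below are illustrative assumptions.

import torch
from torch_geometric.nn import GraphUNet

x = torch.randn(20, 16)                      # node features: [num_nodes, in_channels]
edge_index = torch.randint(0, 20, (2, 100))  # toy COO connectivity: [2, num_edges]

model = GraphUNet(in_channels=16, hidden_channels=32, out_channels=7,
                  depth=2, pool_ratios=0.5)
out = model(x, edge_index)                   # per-node output: [20, 7]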