# Source code for torch_geometric.nn.models.rect

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear
from torch_scatter import scatter

from torch_geometric.nn import GCNConv
from torch_geometric.typing import Adj, OptTensor

class RECT_L(torch.nn.Module):
    r"""The RECT model, *i.e.* its supervised RECT-L part, from the
    "Network Embedding with Completely-imbalanced Labels" paper.
    In particular, a GCN model is trained that reconstructs semantic class
    knowledge.

    The module is a single :class:`GCNConv` layer followed by dropout and a
    :class:`torch.nn.Linear` layer that maps the hidden representation back
    to :obj:`in_channels` features.

    .. note::

        For an example of using RECT, see the ``examples/`` directory.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Intermediate size of each sample.
        normalize (bool, optional): Whether to add self-loops and compute
            symmetric normalization coefficients on the fly.
            (default: :obj:`True`)
        dropout (float, optional): The dropout probability.
            (default: :obj:`0.0`)
    """
    def __init__(self, in_channels: int, hidden_channels: int,
                 normalize: bool = True, dropout: float = 0.0):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.dropout = dropout

        self.conv = GCNConv(in_channels, hidden_channels, normalize=normalize)
        # Projects hidden features back to the input feature size.
        self.lin = Linear(hidden_channels, in_channels)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""Resets all learnable parameters of the module."""
        self.conv.reset_parameters()
        self.lin.reset_parameters()
        # Override the default :class:`Linear` weight initialization with
        # Xavier-uniform; the bias keeps the value set by the call above.
        torch.nn.init.xavier_uniform_(self.lin.weight.data)

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        r"""Runs the GCN layer, applies dropout and projects the result back
        to :obj:`in_channels` features.

        Args:
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_weight (torch.Tensor, optional): The edge weights.
                (default: :obj:`None`)
        """
        x = self.conv(x, edge_index, edge_weight)
        # Dropout must only be active in training mode; `self.training` is
        # toggled by `train()` / `eval()`.
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)

    @torch.no_grad()
    def embed(self, x: Tensor, edge_index: Adj,
              edge_weight: OptTensor = None) -> Tensor:
        r"""Returns the output of the GCN layer (no dropout, no linear
        projection), computed without tracking gradients.

        Args:
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_weight (torch.Tensor, optional): The edge weights.
                (default: :obj:`None`)
        """
        return self.conv(x, edge_index, edge_weight)

    @torch.no_grad()
    def get_semantic_labels(self, x: Tensor, y: Tensor,
                            mask: Tensor) -> Tensor:
        r"""Replaces the original labels by their class-centers.

        Only the entries selected by :obj:`mask` are considered: the features
        :obj:`x[mask]` are averaged per label in :obj:`y[mask]`, and every
        selected node is assigned the mean of its own class.

        Args:
            x (torch.Tensor): The node features.
            y (torch.Tensor): The node labels, used as group indices.
            mask (torch.Tensor): Selects the nodes whose labels are used.

        Returns:
            One class-center row per node selected by :obj:`mask`.
        """
        y = y[mask]
        # Row `c` of `mean` is the average feature vector of class `c`.
        mean = scatter(x[mask], y, dim=0, reduce='mean')
        return mean[y]

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.hidden_channels})')