Source code for torch_geometric.graphgym.models.encoder

import torch

import torch_geometric.graphgym.register as register


[docs]class IntegerFeatureEncoder(torch.nn.Module): """ Provides an encoder for integer node features. Args: emb_dim (int): Output embedding dimension num_classes (int): the number of classes for the embedding mapping to learn from """ def __init__(self, emb_dim, num_classes=None): super(IntegerFeatureEncoder, self).__init__() self.encoder = torch.nn.Embedding(num_classes, emb_dim) torch.nn.init.xavier_uniform_(self.encoder.weight.data) def forward(self, batch): """""" # Encode just the first dimension if more exist batch.x = self.encoder(batch.x[:, 0]) return batch
[docs]class AtomEncoder(torch.nn.Module): """ The atom Encoder used in OGB molecule dataset. Args: emb_dim (int): Output embedding dimension num_classes: None """ def __init__(self, emb_dim, num_classes=None): super(AtomEncoder, self).__init__() from ogb.utils.features import get_atom_feature_dims self.atom_embedding_list = torch.nn.ModuleList() for i, dim in enumerate(get_atom_feature_dims()): emb = torch.nn.Embedding(dim, emb_dim) torch.nn.init.xavier_uniform_(emb.weight.data) self.atom_embedding_list.append(emb) def forward(self, batch): """""" encoded_features = 0 for i in range(batch.x.shape[1]): encoded_features += self.atom_embedding_list[i](batch.x[:, i]) batch.x = encoded_features return batch
[docs]class BondEncoder(torch.nn.Module): """ The bond Encoder used in OGB molecule dataset. Args: emb_dim (int): Output edge embedding dimension """ def __init__(self, emb_dim): super(BondEncoder, self).__init__() from ogb.utils.features import get_bond_feature_dims self.bond_embedding_list = torch.nn.ModuleList() for i, dim in enumerate(get_bond_feature_dims()): emb = torch.nn.Embedding(dim, emb_dim) torch.nn.init.xavier_uniform_(emb.weight.data) self.bond_embedding_list.append(emb) def forward(self, batch): """""" bond_embedding = 0 for i in range(batch.edge_attr.shape[1]): edge_attr = batch.edge_attr bond_embedding += self.bond_embedding_list[i](edge_attr[:, i]) batch.edge_attr = bond_embedding return batch
node_encoder_dict = {'Integer': IntegerFeatureEncoder, 'Atom': AtomEncoder} register.node_encoder_dict = { **register.node_encoder_dict, **node_encoder_dict } edge_encoder_dict = {'Bond': BondEncoder} register.edge_encoder_dict = { **register.edge_encoder_dict, **edge_encoder_dict }