# Source code for torch_geometric.nn.models.dimenet

import os
import os.path as osp
from math import pi as PI
from math import sqrt
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor
from torch.nn import Embedding, Linear
from torch_scatter import scatter
from torch_sparse import SparseTensor

from torch_geometric.data import Dataset, download_url
from torch_geometric.data.makedirs import makedirs
from torch_geometric.nn import radius_graph
from torch_geometric.nn.inits import glorot_orthogonal
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.typing import OptTensor

# Maps a QM9 target index to the name used for the sub-directory (both in the
# remote URL and in the local cache) that holds the pretrained checkpoint for
# that target; see `from_qm9_pretrained`.
# Index 4 has no entry: `from_qm9_pretrained` explicitly rejects target 4.
qm9_target_dict = {
    0: 'mu',
    1: 'alpha',
    2: 'homo',
    3: 'lumo',
    5: 'r2',
    6: 'zpve',
    7: 'U0',
    8: 'U',
    9: 'H',
    10: 'G',
    11: 'Cv',
}


class Envelope(torch.nn.Module):
    """Polynomial envelope that is forced to zero for inputs ``>= 1``.

    With ``p = exponent + 1`` the module evaluates
    ``1 / x + a * x^(p-1) + b * x^p + c * x^(p+1)`` and multiplies the result
    by the indicator ``x < 1``.

    Args:
        exponent (int): Controls the polynomial degree (``p = exponent + 1``).
    """
    def __init__(self, exponent: int):
        super().__init__()
        p = exponent + 1
        self.p = p
        # Polynomial coefficients, all derived from `p`.
        self.a = -(p + 1) * (p + 2) / 2
        self.b = p * (p + 2)
        self.c = -p * (p + 1) / 2

    def forward(self, x: Tensor) -> Tensor:
        """Evaluates the envelope element-wise on :obj:`x`."""
        # Build the three consecutive powers incrementally.
        low = x.pow(self.p - 1)
        mid = low * x
        high = mid * x
        poly = 1. / x + self.a * low + self.b * mid + self.c * high
        # Hard cutoff: everything at or beyond 1 is zeroed out.
        inside = (x < 1.0).to(x.dtype)
        return poly * inside


class BesselBasisLayer(torch.nn.Module):
    """Radial basis built from sine waves with learnable frequencies.

    Args:
        num_radial (int): Number of frequencies (size of the output's last
            dimension).
        cutoff (float, optional): Distances are divided by this value before
            the basis is evaluated. (default: :obj:`5.0`)
        envelope_exponent (int, optional): Exponent forwarded to
            :class:`Envelope`. (default: :obj:`5`)
    """
    def __init__(self, num_radial: int, cutoff: float = 5.0,
                 envelope_exponent: int = 5):
        super().__init__()
        self.cutoff = cutoff
        self.envelope = Envelope(envelope_exponent)

        # Allocated uninitialized; filled by `reset_parameters` below.
        self.freq = torch.nn.Parameter(torch.Tensor(num_radial))

        self.reset_parameters()

    def reset_parameters(self):
        """Resets the frequencies to ``pi, 2 * pi, ..., num_radial * pi``."""
        num_freqs = self.freq.numel()
        with torch.no_grad():
            # Write 1..num_freqs directly into the parameter, then scale.
            torch.arange(1, num_freqs + 1, out=self.freq)
            self.freq.mul_(PI)
        self.freq.requires_grad_()

    def forward(self, dist: Tensor) -> Tensor:
        """Expands :obj:`dist` by one trailing dimension of basis values."""
        scaled = dist.unsqueeze(-1) / self.cutoff
        damping = self.envelope(scaled)
        return damping * (self.freq * scaled).sin()


class SphericalBasisLayer(torch.nn.Module):
    """Combined radial/angular basis evaluated on edge triplets.

    The symbolic expressions returned by :obj:`bessel_basis` and
    :obj:`real_sph_harm` are compiled into callables operating on tensors.

    Args:
        num_spherical (int): Number of angular basis functions.
        num_radial (int): Number of radial basis functions per angular one
            (must not exceed 64).
        cutoff (float, optional): Distances are divided by this value before
            the basis is evaluated. (default: :obj:`5.0`)
        envelope_exponent (int, optional): Exponent forwarded to
            :class:`Envelope`. (default: :obj:`5`)
    """
    def __init__(self, num_spherical: int, num_radial: int,
                 cutoff: float = 5.0, envelope_exponent: int = 5):
        super().__init__()
        # Imported lazily so that sympy is only needed when this layer is
        # actually instantiated.
        import sympy as sym

        from torch_geometric.nn.models.dimenet_utils import (
            bessel_basis,
            real_sph_harm,
        )

        assert num_radial <= 64
        self.num_spherical = num_spherical
        self.num_radial = num_radial
        self.cutoff = cutoff
        self.envelope = Envelope(envelope_exponent)

        bessel_forms = bessel_basis(num_spherical, num_radial)
        sph_harm_forms = real_sph_harm(num_spherical)
        self.sph_funcs = []
        self.bessel_funcs = []

        x, theta = sym.symbols('x theta')
        # Make the generated functions call torch's `sin`/`cos`.
        torch_backend = {'sin': torch.sin, 'cos': torch.cos}

        def compile_(symbol, expression):
            return sym.lambdify([symbol], expression, torch_backend)

        for order in range(num_spherical):
            harmonic = compile_(theta, sph_harm_forms[order][0])
            if order == 0:
                # The first harmonic is evaluated once at theta = 0 and then
                # broadcast to the input's shape (presumably because it does
                # not depend on the angle -- TODO confirm).
                constant = harmonic(0)
                self.sph_funcs.append(
                    lambda angle: torch.zeros_like(angle) + constant)
            else:
                self.sph_funcs.append(harmonic)

            self.bessel_funcs.extend(
                compile_(x, bessel_forms[order][index])
                for index in range(num_radial))

    def forward(self, dist: Tensor, angle: Tensor, idx_kj: Tensor) -> Tensor:
        """Returns one row of ``num_spherical * num_radial`` values per angle.

        Args:
            dist (Tensor): Distances; rows are selected with :obj:`idx_kj`.
            angle (Tensor): Angles, one per triplet.
            idx_kj (Tensor): Index into :obj:`dist` for every triplet.
        """
        scaled = dist / self.cutoff
        radial = torch.stack([fn(scaled) for fn in self.bessel_funcs], dim=1)
        radial = self.envelope(scaled).unsqueeze(-1) * radial

        angular = torch.stack([fn(angle) for fn in self.sph_funcs], dim=1)

        n, k = self.num_spherical, self.num_radial
        # Pair every angular function with its `k` radial functions.
        product = radial[idx_kj].view(-1, n, k) * angular.view(-1, n, 1)
        return product.view(-1, n * k)


class EmbeddingBlock(torch.nn.Module):
    """Builds initial per-edge embeddings from node types and radial features.

    Args:
        num_radial (int): Size of the last dimension of the radial input.
        hidden_channels (int): Size of the produced embeddings.
        act (Callable): Activation applied after each linear layer.
    """
    def __init__(self, num_radial: int, hidden_channels: int, act: Callable):
        super().__init__()
        self.act = act

        # Lookup table with 95 rows; node inputs are used as row indices.
        self.emb = Embedding(95, hidden_channels)
        self.lin_rbf = Linear(num_radial, hidden_channels)
        # Consumes the concatenation [x_i, x_j, rbf], hence 3x the width.
        self.lin = Linear(3 * hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initializes the embedding table and both linear layers."""
        bound = sqrt(3)
        self.emb.weight.data.uniform_(-bound, bound)
        for linear in (self.lin_rbf, self.lin):
            linear.reset_parameters()

    def forward(self, x: Tensor, rbf: Tensor, i: Tensor, j: Tensor) -> Tensor:
        """Embeds every edge ``(i, j)``.

        Args:
            x (Tensor): Indices into the embedding table, one per node.
            rbf (Tensor): Radial features, one row per edge.
            i (Tensor): Node index of the first endpoint of every edge.
            j (Tensor): Node index of the second endpoint of every edge.
        """
        node_emb = self.emb(x)
        radial_emb = self.act(self.lin_rbf(rbf))
        edge_input = torch.cat([node_emb[i], node_emb[j], radial_emb], dim=-1)
        return self.act(self.lin(edge_input))


class ResidualLayer(torch.nn.Module):
    """Two activated linear layers wrapped in an additive skip connection.

    Args:
        hidden_channels (int): Input and output feature size.
        act (Callable): Activation applied after each linear layer.
    """
    def __init__(self, hidden_channels: int, act: Callable):
        super().__init__()
        self.act = act
        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initializes both weights with :obj:`glorot_orthogonal` and
        zeroes both biases."""
        for linear in (self.lin1, self.lin2):
            glorot_orthogonal(linear.weight, scale=2.0)
            linear.bias.data.fill_(0)

    def forward(self, x: Tensor) -> Tensor:
        """Returns ``x + act(lin2(act(lin1(x))))``."""
        hidden = self.act(self.lin1(x))
        update = self.act(self.lin2(hidden))
        return x + update


class InteractionBlock(torch.nn.Module):
    """Interaction block of :class:`DimeNet` based on a bilinear layer.

    Args:
        hidden_channels (int): Size of the edge embeddings.
        num_bilinear (int): Size of the middle dimension of the bilinear
            weight tensor.
        num_spherical (int): Number of angular basis functions.
        num_radial (int): Number of radial basis functions.
        num_before_skip (int): Number of residual layers before the skip
            connection.
        num_after_skip (int): Number of residual layers after the skip
            connection.
        act (Callable): Activation function.
    """
    def __init__(self, hidden_channels: int, num_bilinear: int,
                 num_spherical: int, num_radial: int, num_before_skip: int,
                 num_after_skip: int, act: Callable):
        super().__init__()
        self.act = act

        # Projections of the radial and the spherical basis.
        self.lin_rbf = Linear(num_radial, hidden_channels, bias=False)
        self.lin_sbf = Linear(num_spherical * num_radial, num_bilinear,
                              bias=False)

        # Dense transformations of input messages.
        self.lin_kj = Linear(hidden_channels, hidden_channels)
        self.lin_ji = Linear(hidden_channels, hidden_channels)

        # Bilinear weight; allocated uninitialized and filled in
        # `reset_parameters`.
        self.W = torch.nn.Parameter(
            torch.Tensor(hidden_channels, num_bilinear, hidden_channels))

        self.layers_before_skip = torch.nn.ModuleList([
            ResidualLayer(hidden_channels, act) for _ in range(num_before_skip)
        ])
        self.lin = Linear(hidden_channels, hidden_channels)
        self.layers_after_skip = torch.nn.ModuleList([
            ResidualLayer(hidden_channels, act) for _ in range(num_after_skip)
        ])

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initializes every learnable tensor of the block."""
        for basis_lin in (self.lin_rbf, self.lin_sbf):
            glorot_orthogonal(basis_lin.weight, scale=2.0)
        for message_lin in (self.lin_kj, self.lin_ji):
            glorot_orthogonal(message_lin.weight, scale=2.0)
            message_lin.bias.data.fill_(0)
        self.W.data.normal_(mean=0, std=2 / self.W.size(0))
        for residual in self.layers_before_skip:
            residual.reset_parameters()
        glorot_orthogonal(self.lin.weight, scale=2.0)
        self.lin.bias.data.fill_(0)
        for residual in self.layers_after_skip:
            residual.reset_parameters()

    def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor,
                idx_ji: Tensor) -> Tensor:
        """Updates the edge embeddings :obj:`x`.

        Args:
            x (Tensor): Edge embeddings.
            rbf (Tensor): Radial basis values, one row per edge.
            sbf (Tensor): Spherical basis values, one row per triplet.
            idx_kj (Tensor): For every triplet, the row of :obj:`x` that is
                gathered as the incoming message.
            idx_ji (Tensor): For every triplet, the row of the output that
                the transformed message is accumulated into.
        """
        radial = self.lin_rbf(rbf)
        angular = self.lin_sbf(sbf)

        msg_ji = self.act(self.lin_ji(x))
        msg_kj = self.act(self.lin_kj(x)) * radial
        # Bilinear combination of angular features and gathered messages.
        msg_kj = torch.einsum('wj,wl,ijl->wi', angular, msg_kj[idx_kj],
                              self.W)
        msg_kj = scatter(msg_kj, idx_ji, dim=0, dim_size=x.size(0))

        h = msg_ji + msg_kj
        for residual in self.layers_before_skip:
            h = residual(h)
        # Skip connection back to the block input.
        h = self.act(self.lin(h)) + x
        for residual in self.layers_after_skip:
            h = residual(h)

        return h


class InteractionPPBlock(torch.nn.Module):
    """Interaction block of :class:`DimeNetPlusPlus`.

    Messages are down-projected to :obj:`int_emb_size` channels, modulated by
    the spherical basis, aggregated per target edge and up-projected again.

    Args:
        hidden_channels (int): Size of the edge embeddings.
        int_emb_size (int): Size of the embeddings used while combining
            messages with the spherical basis.
        basis_emb_size (int): Size of the intermediate basis embeddings.
        num_spherical (int): Number of angular basis functions.
        num_radial (int): Number of radial basis functions.
        num_before_skip (int): Number of residual layers before the skip
            connection.
        num_after_skip (int): Number of residual layers after the skip
            connection.
        act (Callable): Activation function.
    """
    def __init__(self, hidden_channels: int, int_emb_size: int,
                 basis_emb_size: int, num_spherical: int, num_radial: int,
                 num_before_skip: int, num_after_skip: int, act: Callable):
        super().__init__()
        self.act = act

        # Transformation of Bessel and spherical basis representations:
        self.lin_rbf1 = Linear(num_radial, basis_emb_size, bias=False)
        self.lin_rbf2 = Linear(basis_emb_size, hidden_channels, bias=False)

        self.lin_sbf1 = Linear(num_spherical * num_radial, basis_emb_size,
                               bias=False)
        self.lin_sbf2 = Linear(basis_emb_size, int_emb_size, bias=False)

        # Hidden transformation of input message:
        self.lin_kj = Linear(hidden_channels, hidden_channels)
        self.lin_ji = Linear(hidden_channels, hidden_channels)

        # Embedding projections for interaction triplets:
        self.lin_down = Linear(hidden_channels, int_emb_size, bias=False)
        self.lin_up = Linear(int_emb_size, hidden_channels, bias=False)

        # Residual layers before and after skip connection:
        self.layers_before_skip = torch.nn.ModuleList([
            ResidualLayer(hidden_channels, act) for _ in range(num_before_skip)
        ])
        self.lin = Linear(hidden_channels, hidden_channels)
        # NOTE: sized by `num_after_skip`; sizing it by `num_before_skip`
        # silently ignored the argument.
        self.layers_after_skip = torch.nn.ModuleList([
            ResidualLayer(hidden_channels, act) for _ in range(num_after_skip)
        ])

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initializes every learnable tensor of the block."""
        glorot_orthogonal(self.lin_rbf1.weight, scale=2.0)
        glorot_orthogonal(self.lin_rbf2.weight, scale=2.0)
        glorot_orthogonal(self.lin_sbf1.weight, scale=2.0)
        glorot_orthogonal(self.lin_sbf2.weight, scale=2.0)

        glorot_orthogonal(self.lin_kj.weight, scale=2.0)
        self.lin_kj.bias.data.fill_(0)
        glorot_orthogonal(self.lin_ji.weight, scale=2.0)
        self.lin_ji.bias.data.fill_(0)

        glorot_orthogonal(self.lin_down.weight, scale=2.0)
        glorot_orthogonal(self.lin_up.weight, scale=2.0)

        for res_layer in self.layers_before_skip:
            res_layer.reset_parameters()
        glorot_orthogonal(self.lin.weight, scale=2.0)
        self.lin.bias.data.fill_(0)
        # Reset the layers *after* the skip connection as well (they were
        # previously skipped while the "before" layers were reset twice).
        for res_layer in self.layers_after_skip:
            res_layer.reset_parameters()

    def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor,
                idx_ji: Tensor) -> Tensor:
        """Updates the edge embeddings :obj:`x`.

        Args:
            x (Tensor): Edge embeddings.
            rbf (Tensor): Radial basis values, one row per edge.
            sbf (Tensor): Spherical basis values, one row per triplet.
            idx_kj (Tensor): For every triplet, the row of the edge messages
                that is gathered as the incoming message.
            idx_ji (Tensor): For every triplet, the row of the output that
                the transformed message is accumulated into.
        """
        # Initial transformation:
        x_ji = self.act(self.lin_ji(x))
        x_kj = self.act(self.lin_kj(x))

        # Transformation via Bessel basis:
        rbf = self.lin_rbf1(rbf)
        rbf = self.lin_rbf2(rbf)
        x_kj = x_kj * rbf

        # Down project embedding and generating triple-interactions:
        x_kj = self.act(self.lin_down(x_kj))

        # Transform via 2D spherical basis:
        sbf = self.lin_sbf1(sbf)
        sbf = self.lin_sbf2(sbf)
        x_kj = x_kj[idx_kj] * sbf

        # Aggregate interactions and up-project embeddings:
        x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0))
        x_kj = self.act(self.lin_up(x_kj))

        h = x_ji + x_kj
        for layer in self.layers_before_skip:
            h = layer(h)
        # Skip connection back to the block input.
        h = self.act(self.lin(h)) + x
        for layer in self.layers_after_skip:
            h = layer(h)

        return h


class OutputBlock(torch.nn.Module):
    """Turns edge embeddings into per-node outputs for :class:`DimeNet`.

    Args:
        num_radial (int): Number of radial basis functions.
        hidden_channels (int): Size of the edge embeddings.
        out_channels (int): Size of each output row.
        num_layers (int): Number of activated hidden linear layers.
        act (Callable): Activation function.
    """
    def __init__(self, num_radial: int, hidden_channels: int,
                 out_channels: int, num_layers: int, act: Callable):
        super().__init__()
        self.act = act

        self.lin_rbf = Linear(num_radial, hidden_channels, bias=False)
        self.lins = torch.nn.ModuleList([
            Linear(hidden_channels, hidden_channels)
            for _ in range(num_layers)
        ])
        self.lin = Linear(hidden_channels, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initializes all layers; the final projection starts at zero."""
        glorot_orthogonal(self.lin_rbf.weight, scale=2.0)
        for hidden_lin in self.lins:
            glorot_orthogonal(hidden_lin.weight, scale=2.0)
            hidden_lin.bias.data.fill_(0)
        self.lin.weight.data.fill_(0)

    def forward(self, x: Tensor, rbf: Tensor, i: Tensor,
                num_nodes: Optional[int] = None) -> Tensor:
        """Aggregates edge embeddings per node and projects them.

        Args:
            x (Tensor): Edge embeddings.
            rbf (Tensor): Radial basis values, one row per edge.
            i (Tensor): For every edge, the row of the output it is
                accumulated into.
            num_nodes (int, optional): Number of output rows, forwarded to
                :obj:`scatter` as :obj:`dim_size`. (default: :obj:`None`)
        """
        weighted = self.lin_rbf(rbf) * x
        out = scatter(weighted, i, dim=0, dim_size=num_nodes)
        for hidden_lin in self.lins:
            out = self.act(hidden_lin(out))
        return self.lin(out)


class OutputPPBlock(torch.nn.Module):
    """Turns edge embeddings into per-node outputs for
    :class:`DimeNetPlusPlus`.

    Args:
        num_radial (int): Number of radial basis functions.
        hidden_channels (int): Size of the edge embeddings.
        out_emb_channels (int): Size of the embeddings after the
            up-projection.
        out_channels (int): Size of each output row.
        num_layers (int): Number of activated hidden linear layers.
        act (Callable): Activation function.
    """
    def __init__(self, num_radial: int, hidden_channels: int,
                 out_emb_channels: int, out_channels: int, num_layers: int,
                 act: Callable):
        super().__init__()
        self.act = act

        self.lin_rbf = Linear(num_radial, hidden_channels, bias=False)

        # The up-projection layer:
        self.lin_up = Linear(hidden_channels, out_emb_channels, bias=False)
        self.lins = torch.nn.ModuleList([
            Linear(out_emb_channels, out_emb_channels)
            for _ in range(num_layers)
        ])
        self.lin = Linear(out_emb_channels, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initializes all layers; the final projection starts at zero."""
        for projection in (self.lin_rbf, self.lin_up):
            glorot_orthogonal(projection.weight, scale=2.0)
        for hidden_lin in self.lins:
            glorot_orthogonal(hidden_lin.weight, scale=2.0)
            hidden_lin.bias.data.fill_(0)
        self.lin.weight.data.fill_(0)

    def forward(self, x: Tensor, rbf: Tensor, i: Tensor,
                num_nodes: Optional[int] = None) -> Tensor:
        """Aggregates edge embeddings per node, up-projects and projects them.

        Args:
            x (Tensor): Edge embeddings.
            rbf (Tensor): Radial basis values, one row per edge.
            i (Tensor): For every edge, the row of the output it is
                accumulated into.
            num_nodes (int, optional): Number of output rows, forwarded to
                :obj:`scatter` as :obj:`dim_size`. (default: :obj:`None`)
        """
        weighted = self.lin_rbf(rbf) * x
        out = scatter(weighted, i, dim=0, dim_size=num_nodes)
        out = self.lin_up(out)
        for hidden_lin in self.lins:
            out = self.act(hidden_lin(out))
        return self.lin(out)


class DimeNet(torch.nn.Module):
    r"""The directional message passing neural network (DimeNet) from the
    `"Directional Message Passing for Molecular Graphs"
    <https://arxiv.org/abs/2003.03123>`_ paper.
    DimeNet transforms messages based on the angle between them in a
    rotation-equivariant fashion.

    .. note::

        For an example of using a pretrained DimeNet variant, see
        `examples/qm9_pretrained_dimenet.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        qm9_pretrained_dimenet.py>`_.

    Args:
        hidden_channels (int): Hidden embedding size.
        out_channels (int): Size of each output sample.
        num_blocks (int): Number of building blocks.
        num_bilinear (int): Size of the bilinear layer tensor.
        num_spherical (int): Number of spherical harmonics.
        num_radial (int): Number of radial basis functions.
        cutoff (float, optional): Cutoff distance for interatomic
            interactions. (default: :obj:`5.0`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            collect for each node within the :attr:`cutoff` distance.
            (default: :obj:`32`)
        envelope_exponent (int, optional): Shape of the smooth cutoff.
            (default: :obj:`5`)
        num_before_skip (int, optional): Number of residual layers in the
            interaction blocks before the skip connection. (default: :obj:`1`)
        num_after_skip (int, optional): Number of residual layers in the
            interaction blocks after the skip connection. (default: :obj:`2`)
        num_output_layers (int, optional): Number of linear layers for the
            output blocks. (default: :obj:`3`)
        act (str or Callable, optional): The activation function.
            (default: :obj:`"swish"`)
    """

    url = ('https://github.com/klicperajo/dimenet/raw/master/pretrained/'
           'dimenet')

    def __init__(
        self,
        hidden_channels: int,
        out_channels: int,
        num_blocks: int,
        num_bilinear: int,
        num_spherical: int,
        num_radial: int,
        cutoff: float = 5.0,
        max_num_neighbors: int = 32,
        envelope_exponent: int = 5,
        num_before_skip: int = 1,
        num_after_skip: int = 2,
        num_output_layers: int = 3,
        act: Union[str, Callable] = 'swish',
    ):
        super().__init__()

        if num_spherical < 2:
            raise ValueError("num_spherical should be greater than 1")

        act = activation_resolver(act)

        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors
        self.num_blocks = num_blocks

        self.rbf = BesselBasisLayer(num_radial, cutoff, envelope_exponent)
        self.sbf = SphericalBasisLayer(num_spherical, num_radial, cutoff,
                                       envelope_exponent)

        self.emb = EmbeddingBlock(num_radial, hidden_channels, act)

        # One output block for the embedding plus one per interaction block.
        self.output_blocks = torch.nn.ModuleList([
            OutputBlock(num_radial, hidden_channels, out_channels,
                        num_output_layers, act) for _ in range(num_blocks + 1)
        ])

        self.interaction_blocks = torch.nn.ModuleList([
            InteractionBlock(hidden_channels, num_bilinear, num_spherical,
                             num_radial, num_before_skip, num_after_skip, act)
            for _ in range(num_blocks)
        ])

        self.reset_parameters()

    def reset_parameters(self):
        """Resets all learnable parameters of the module."""
        self.rbf.reset_parameters()
        self.emb.reset_parameters()
        for out in self.output_blocks:
            out.reset_parameters()
        for interaction in self.interaction_blocks:
            interaction.reset_parameters()

    @classmethod
    def from_qm9_pretrained(
        cls,
        root: str,
        dataset: Dataset,
        target: int,
    ) -> Tuple['DimeNet', Tuple[Dataset, Dataset, Dataset]]:
        """Builds a model from the pretrained TensorFlow checkpoint of a QM9
        target and returns it together with a fixed dataset split.

        The checkpoint is downloaded below :obj:`root` unless it is already
        present. Requires :obj:`tensorflow` to read the checkpoint.

        Args:
            root (str): Directory in which checkpoints are cached.
            dataset (Dataset): Dataset to split; it is indexed with a
                permutation of 130831 indices.
            target (int): Index of the QM9 target; must be a key of
                :obj:`qm9_target_dict`.

        Returns:
            A tuple :obj:`(model, (train_dataset, val_dataset,
            test_dataset))`.
        """
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        import tensorflow as tf

        # Only targets with a published checkpoint name are valid; this also
        # rejects `target == 12`, which used to fail later with a `KeyError`.
        assert target in qm9_target_dict, (
            f"No pretrained checkpoint for target {target}")

        root = osp.expanduser(osp.normpath(root))
        path = osp.join(root, 'pretrained_dimenet', qm9_target_dict[target])

        makedirs(path)
        url = f'{cls.url}/{qm9_target_dict[target]}'

        if not osp.exists(osp.join(path, 'checkpoint')):
            download_url(f'{url}/checkpoint', path)
            download_url(f'{url}/ckpt.data-00000-of-00002', path)
            download_url(f'{url}/ckpt.data-00001-of-00002', path)
            download_url(f'{url}/ckpt.index', path)

        path = osp.join(path, 'ckpt')
        reader = tf.train.load_checkpoint(path)

        model = cls(
            hidden_channels=128,
            out_channels=1,
            num_blocks=6,
            num_bilinear=8,
            num_spherical=7,
            num_radial=6,
            cutoff=5.0,
            envelope_exponent=5,
            num_before_skip=1,
            num_after_skip=2,
            num_output_layers=3,
        )

        def copy_(src, name):
            # Copies the checkpoint variable `name` into the tensor `src`.
            init = reader.get_tensor(f'{name}/.ATTRIBUTES/VARIABLE_VALUE')
            init = torch.from_numpy(init)
            if name.endswith('kernel'):
                # Dense kernels are stored transposed relative to
                # `torch.nn.Linear.weight`.
                init = init.t()
            src.data.copy_(init)

        def copy_residual_(layer, prefix):
            # Copies both dense layers of a `ResidualLayer`.
            copy_(layer.lin1.weight, f'{prefix}/dense_1/kernel')
            copy_(layer.lin1.bias, f'{prefix}/dense_1/bias')
            copy_(layer.lin2.weight, f'{prefix}/dense_2/kernel')
            copy_(layer.lin2.bias, f'{prefix}/dense_2/bias')

        copy_(model.rbf.freq, 'rbf_layer/frequencies')
        copy_(model.emb.emb.weight, 'emb_block/embeddings')
        copy_(model.emb.lin_rbf.weight, 'emb_block/dense_rbf/kernel')
        copy_(model.emb.lin_rbf.bias, 'emb_block/dense_rbf/bias')
        copy_(model.emb.lin.weight, 'emb_block/dense/kernel')
        copy_(model.emb.lin.bias, 'emb_block/dense/bias')

        for i, block in enumerate(model.output_blocks):
            copy_(block.lin_rbf.weight, f'output_blocks/{i}/dense_rbf/kernel')
            for j, lin in enumerate(block.lins):
                copy_(lin.weight, f'output_blocks/{i}/dense_layers/{j}/kernel')
                copy_(lin.bias, f'output_blocks/{i}/dense_layers/{j}/bias')
            copy_(block.lin.weight, f'output_blocks/{i}/dense_final/kernel')

        for i, block in enumerate(model.interaction_blocks):
            copy_(block.lin_rbf.weight, f'int_blocks/{i}/dense_rbf/kernel')
            copy_(block.lin_sbf.weight, f'int_blocks/{i}/dense_sbf/kernel')
            copy_(block.lin_kj.weight, f'int_blocks/{i}/dense_kj/kernel')
            copy_(block.lin_kj.bias, f'int_blocks/{i}/dense_kj/bias')
            copy_(block.lin_ji.weight, f'int_blocks/{i}/dense_ji/kernel')
            copy_(block.lin_ji.bias, f'int_blocks/{i}/dense_ji/bias')
            copy_(block.W, f'int_blocks/{i}/bilinear')
            for j, layer in enumerate(block.layers_before_skip):
                copy_residual_(layer,
                               f'int_blocks/{i}/layers_before_skip/{j}')
            copy_(block.lin.weight, f'int_blocks/{i}/final_before_skip/kernel')
            copy_(block.lin.bias, f'int_blocks/{i}/final_before_skip/bias')
            for j, layer in enumerate(block.layers_after_skip):
                copy_residual_(layer, f'int_blocks/{i}/layers_after_skip/{j}')

        # Use the same random seed as the official DimeNet implementation.
        random_state = np.random.RandomState(seed=42)
        perm = torch.from_numpy(random_state.permutation(np.arange(130831)))
        train_idx = perm[:110000]
        val_idx = perm[110000:120000]
        test_idx = perm[120000:]

        return model, (dataset[train_idx], dataset[val_idx], dataset[test_idx])

    def triplets(
        self,
        edge_index: Tensor,
        num_nodes: int,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Enumerates all triplets ``k->j->i`` (with ``k != i``) of a graph.

        Args:
            edge_index (Tensor): Edge indices; the first row holds :obj:`j`
                and the second row holds :obj:`i` of every edge ``j->i``.
            num_nodes (int): Number of nodes in the graph.

        Returns:
            The tuple :obj:`(col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji)`,
            where :obj:`col` and :obj:`row` are the rows of
            :obj:`edge_index` in swapped order, :obj:`idx_i`, :obj:`idx_j`
            and :obj:`idx_k` are the node indices of every triplet, and
            :obj:`idx_kj` and :obj:`idx_ji` are the edge indices of its two
            edges.
        """
        row, col = edge_index  # j->i

        # Store the edge id as the value of every adjacency entry so that it
        # can be recovered after row selection.
        value = torch.arange(row.size(0), device=row.device)
        adj_t = SparseTensor(row=col, col=row, value=value,
                             sparse_sizes=(num_nodes, num_nodes))
        adj_t_row = adj_t[row]
        num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

        # Node indices (k->j->i) for triplets.
        idx_i = col.repeat_interleave(num_triplets)
        idx_j = row.repeat_interleave(num_triplets)
        idx_k = adj_t_row.storage.col()
        mask = idx_i != idx_k  # Remove i == k triplets.
        idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

        # Edge indices (k->j, j->i) for triplets.
        idx_kj = adj_t_row.storage.value()[mask]
        idx_ji = adj_t_row.storage.row()[mask]

        return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji

    def forward(
        self,
        z: Tensor,
        pos: Tensor,
        batch: OptTensor = None,
    ) -> Tensor:
        """Runs the model on a (batch of) point cloud(s).

        Args:
            z (Tensor): Per-node indices into the embedding table of the
                embedding block (presumably atomic numbers -- confirm
                against the dataset).
            pos (Tensor): Node positions of shape
                :obj:`[num_nodes, 3]`.
            batch (Tensor, optional): Assigns every node to an example. If
                :obj:`None`, all nodes are summed into a single output.
                (default: :obj:`None`)
        """
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch,
                                  max_num_neighbors=self.max_num_neighbors)

        i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
            edge_index, num_nodes=z.size(0))

        # Calculate distances.
        dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

        # Calculate angles.
        pos_i = pos[idx_i]
        pos_ji, pos_ki = pos[idx_j] - pos_i, pos[idx_k] - pos_i
        a = (pos_ji * pos_ki).sum(dim=-1)
        # `dim` must be explicit: without it, `torch.cross` picks the first
        # dimension of size 3, which is the wrong one for exactly 3 triplets.
        b = torch.cross(pos_ji, pos_ki, dim=1).norm(dim=-1)
        angle = torch.atan2(b, a)

        rbf = self.rbf(dist)
        sbf = self.sbf(dist, angle, idx_kj)

        # Every output block must produce one row per node; otherwise the
        # inferred size shrinks when trailing nodes have no edges and the
        # summation below fails.
        num_nodes = pos.size(0)

        # Embedding block.
        x = self.emb(z, rbf, i, j)
        P = self.output_blocks[0](x, rbf, i, num_nodes=num_nodes)

        # Interaction blocks.
        for interaction_block, output_block in zip(self.interaction_blocks,
                                                   self.output_blocks[1:]):
            x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
            P = P + output_block(x, rbf, i, num_nodes=num_nodes)

        return P.sum(dim=0) if batch is None else scatter(P, batch, dim=0)
class DimeNetPlusPlus(DimeNet):
    r"""The DimeNet++ from the `"Fast and Uncertainty-Aware Directional
    Message Passing for Non-Equilibrium Molecules"
    <https://arxiv.org/abs/2011.14115>`_ paper.

    :class:`DimeNetPlusPlus` is an upgrade to the :class:`DimeNet` model that
    is 8x faster and 10% more accurate than :class:`DimeNet`.

    Args:
        hidden_channels (int): Hidden embedding size.
        out_channels (int): Size of each output sample.
        num_blocks (int): Number of building blocks.
        int_emb_size (int): Size of embedding in the interaction block.
        basis_emb_size (int): Size of basis embedding in the interaction
            block.
        out_emb_channels (int): Size of embedding in the output block.
        num_spherical (int): Number of spherical harmonics.
        num_radial (int): Number of radial basis functions.
        cutoff (float, optional): Cutoff distance for interatomic
            interactions. (default: :obj:`5.0`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            collect for each node within the :attr:`cutoff` distance.
            (default: :obj:`32`)
        envelope_exponent (int, optional): Shape of the smooth cutoff.
            (default: :obj:`5`)
        num_before_skip (int, optional): Number of residual layers in the
            interaction blocks before the skip connection. (default: :obj:`1`)
        num_after_skip (int, optional): Number of residual layers in the
            interaction blocks after the skip connection. (default: :obj:`2`)
        num_output_layers (int, optional): Number of linear layers for the
            output blocks. (default: :obj:`3`)
        act (str or Callable, optional): The activation function.
            (default: :obj:`"swish"`)
    """

    url = ('https://raw.githubusercontent.com/gasteigerjo/dimenet/'
           'master/pretrained/dimenet_pp')

    def __init__(
        self,
        hidden_channels: int,
        out_channels: int,
        num_blocks: int,
        int_emb_size: int,
        basis_emb_size: int,
        out_emb_channels: int,
        num_spherical: int,
        num_radial: int,
        cutoff: float = 5.0,
        max_num_neighbors: int = 32,
        envelope_exponent: int = 5,
        num_before_skip: int = 1,
        num_after_skip: int = 2,
        num_output_layers: int = 3,
        act: Union[str, Callable] = 'swish',
    ):
        act = activation_resolver(act)

        super().__init__(
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_blocks=num_blocks,
            num_bilinear=1,
            num_spherical=num_spherical,
            num_radial=num_radial,
            cutoff=cutoff,
            max_num_neighbors=max_num_neighbors,
            envelope_exponent=envelope_exponent,
            num_before_skip=num_before_skip,
            num_after_skip=num_after_skip,
            num_output_layers=num_output_layers,
            act=act,
        )

        # We are re-using the RBF, SBF and embedding layers of `DimeNet` and
        # redefine output_block and interaction_block in DimeNet++.
        # Hence, it is to be noted that in the above initialization, the
        # variable `num_bilinear` does not have any purpose as it is used
        # solely in the `InteractionBlock` of DimeNet, which is replaced
        # below:
        self.output_blocks = torch.nn.ModuleList([
            OutputPPBlock(num_radial, hidden_channels, out_emb_channels,
                          out_channels, num_output_layers, act)
            for _ in range(num_blocks + 1)
        ])

        self.interaction_blocks = torch.nn.ModuleList([
            InteractionPPBlock(hidden_channels, int_emb_size, basis_emb_size,
                               num_spherical, num_radial, num_before_skip,
                               num_after_skip, act) for _ in range(num_blocks)
        ])

        # The parent constructor reset the blocks that were just replaced,
        # so reset once more to cover the new ones.
        self.reset_parameters()

    @classmethod
    def from_qm9_pretrained(
        cls,
        root: str,
        dataset: Dataset,
        target: int,
    ) -> Tuple['DimeNetPlusPlus', Tuple[Dataset, Dataset, Dataset]]:
        """Builds a model from the pretrained TensorFlow checkpoint of a QM9
        target and returns it together with a fixed dataset split.

        The checkpoint is downloaded below :obj:`root` unless it is already
        present. Requires :obj:`tensorflow` to read the checkpoint.

        Args:
            root (str): Directory in which checkpoints are cached.
            dataset (Dataset): Dataset to split; it is indexed with a
                permutation of 130831 indices.
            target (int): Index of the QM9 target; must be a key of
                :obj:`qm9_target_dict`.

        Returns:
            A tuple :obj:`(model, (train_dataset, val_dataset,
            test_dataset))`.
        """
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        import tensorflow as tf

        # Only targets with a published checkpoint name are valid; this also
        # rejects `target == 12`, which used to fail later with a `KeyError`.
        assert target in qm9_target_dict, (
            f"No pretrained checkpoint for target {target}")

        root = osp.expanduser(osp.normpath(root))
        path = osp.join(root, 'pretrained_dimenet_pp', qm9_target_dict[target])

        makedirs(path)
        url = f'{cls.url}/{qm9_target_dict[target]}'

        if not osp.exists(osp.join(path, 'checkpoint')):
            download_url(f'{url}/checkpoint', path)
            download_url(f'{url}/ckpt.data-00000-of-00002', path)
            download_url(f'{url}/ckpt.data-00001-of-00002', path)
            download_url(f'{url}/ckpt.index', path)

        path = osp.join(path, 'ckpt')
        reader = tf.train.load_checkpoint(path)

        # Configuration from DimeNet++:
        # https://github.com/gasteigerjo/dimenet/blob/master/config_pp.yaml
        model = cls(
            hidden_channels=128,
            out_channels=1,
            num_blocks=4,
            int_emb_size=64,
            basis_emb_size=8,
            out_emb_channels=256,
            num_spherical=7,
            num_radial=6,
            cutoff=5.0,
            max_num_neighbors=32,
            envelope_exponent=5,
            num_before_skip=1,
            num_after_skip=2,
            num_output_layers=3,
        )

        def copy_(src, name):
            # Copies the checkpoint variable `name` into the tensor `src`.
            init = reader.get_tensor(f'{name}/.ATTRIBUTES/VARIABLE_VALUE')
            init = torch.from_numpy(init)
            if name.endswith('kernel'):
                # Dense kernels are stored transposed relative to
                # `torch.nn.Linear.weight`.
                init = init.t()
            src.data.copy_(init)

        def copy_residual_(layer, prefix):
            # Copies both dense layers of a `ResidualLayer`.
            copy_(layer.lin1.weight, f'{prefix}/dense_1/kernel')
            copy_(layer.lin1.bias, f'{prefix}/dense_1/bias')
            copy_(layer.lin2.weight, f'{prefix}/dense_2/kernel')
            copy_(layer.lin2.bias, f'{prefix}/dense_2/bias')

        copy_(model.rbf.freq, 'rbf_layer/frequencies')
        copy_(model.emb.emb.weight, 'emb_block/embeddings')
        copy_(model.emb.lin_rbf.weight, 'emb_block/dense_rbf/kernel')
        copy_(model.emb.lin_rbf.bias, 'emb_block/dense_rbf/bias')
        copy_(model.emb.lin.weight, 'emb_block/dense/kernel')
        copy_(model.emb.lin.bias, 'emb_block/dense/bias')

        for i, block in enumerate(model.output_blocks):
            copy_(block.lin_rbf.weight, f'output_blocks/{i}/dense_rbf/kernel')
            copy_(block.lin_up.weight,
                  f'output_blocks/{i}/up_projection/kernel')
            for j, lin in enumerate(block.lins):
                copy_(lin.weight, f'output_blocks/{i}/dense_layers/{j}/kernel')
                copy_(lin.bias, f'output_blocks/{i}/dense_layers/{j}/bias')
            copy_(block.lin.weight, f'output_blocks/{i}/dense_final/kernel')

        for i, block in enumerate(model.interaction_blocks):
            copy_(block.lin_rbf1.weight, f'int_blocks/{i}/dense_rbf1/kernel')
            copy_(block.lin_rbf2.weight, f'int_blocks/{i}/dense_rbf2/kernel')
            copy_(block.lin_sbf1.weight, f'int_blocks/{i}/dense_sbf1/kernel')
            copy_(block.lin_sbf2.weight, f'int_blocks/{i}/dense_sbf2/kernel')

            copy_(block.lin_ji.weight, f'int_blocks/{i}/dense_ji/kernel')
            copy_(block.lin_ji.bias, f'int_blocks/{i}/dense_ji/bias')
            copy_(block.lin_kj.weight, f'int_blocks/{i}/dense_kj/kernel')
            copy_(block.lin_kj.bias, f'int_blocks/{i}/dense_kj/bias')

            copy_(block.lin_down.weight,
                  f'int_blocks/{i}/down_projection/kernel')
            copy_(block.lin_up.weight, f'int_blocks/{i}/up_projection/kernel')

            for j, layer in enumerate(block.layers_before_skip):
                copy_residual_(layer,
                               f'int_blocks/{i}/layers_before_skip/{j}')

            copy_(block.lin.weight, f'int_blocks/{i}/final_before_skip/kernel')
            copy_(block.lin.bias, f'int_blocks/{i}/final_before_skip/bias')

            for j, layer in enumerate(block.layers_after_skip):
                copy_residual_(layer, f'int_blocks/{i}/layers_after_skip/{j}')

        # Fixed seed so that the split is reproducible across calls.
        random_state = np.random.RandomState(seed=42)
        perm = torch.from_numpy(random_state.permutation(np.arange(130831)))
        train_idx = perm[:110000]
        val_idx = perm[110000:120000]
        test_idx = perm[120000:]

        return model, (dataset[train_idx], dataset[val_idx], dataset[test_idx])