Source code for torch_geometric.nn.models.dimenet

import os
import os.path as osp
from functools import partial
from math import pi as PI
from math import sqrt
from typing import Callable, Dict, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor
from torch.nn import Embedding, Linear

from torch_geometric.data import Dataset, download_url
from torch_geometric.nn import radius_graph
from torch_geometric.nn.inits import glorot_orthogonal
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.typing import OptTensor, SparseTensor
from torch_geometric.utils import scatter

# QM9 target indices mapped to the name used for the corresponding pretrained
# checkpoint directory/URL in `from_qm9_pretrained`. Index 4 has no entry.
qm9_target_dict: Dict[int, str] = dict([
    (0, 'mu'),
    (1, 'alpha'),
    (2, 'homo'),
    (3, 'lumo'),
    (5, 'r2'),
    (6, 'zpve'),
    (7, 'U0'),
    (8, 'U'),
    (9, 'H'),
    (10, 'G'),
    (11, 'Cv'),
])


class Envelope(torch.nn.Module):
    r"""Polynomial cutoff envelope.

    With :math:`p = \text{exponent} + 1`, evaluates

    .. math::
        u(x) = \frac{1}{x} + a x^{p-1} + b x^{p} + c x^{p+1}

    for :math:`x < 1` and returns zero elsewhere, where
    :math:`a = -(p+1)(p+2)/2`, :math:`b = p(p+2)` and
    :math:`c = -p(p+1)/2`.

    Args:
        exponent (int): Envelope exponent; the polynomial degree offset is
            ``exponent + 1``.
    """
    def __init__(self, exponent: int):
        super().__init__()
        p = exponent + 1
        self.p = p
        self.a = -(p + 1) * (p + 2) / 2
        self.b = p * (p + 2)
        self.c = -p * (p + 1) / 2

    def forward(self, x: Tensor) -> Tensor:
        # Build the three consecutive powers incrementally from x^(p-1).
        low = x.pow(self.p - 1)
        mid = low * x
        high = mid * x
        poly = 1.0 / x + self.a * low + self.b * mid + self.c * high
        # Multiplicative mask: entries with x >= 1 are zeroed.
        inside = (x < 1.0).to(x.dtype)
        return poly * inside


class BesselBasisLayer(torch.nn.Module):
    r"""Radial basis :math:`u(d / c) \cdot \sin(f_n \, d / c)` with learnable
    frequencies :math:`f_n`, initialised to :math:`n \pi` for
    :math:`n = 1, \dots, \text{num\_radial}`.

    Args:
        num_radial (int): Number of radial basis functions.
        cutoff (float, optional): Distance scale ``c``; distances are divided
            by it. (default: :obj:`5.0`)
        envelope_exponent (int, optional): Exponent of the :class:`Envelope`
            applied to the scaled distances. (default: :obj:`5`)
    """
    def __init__(self, num_radial: int, cutoff: float = 5.0,
                 envelope_exponent: int = 5):
        super().__init__()
        self.cutoff = cutoff
        self.envelope = Envelope(envelope_exponent)

        self.freq = torch.nn.Parameter(torch.empty(num_radial))

        self.reset_parameters()

    def reset_parameters(self):
        n = self.freq.numel()
        with torch.no_grad():
            init = torch.arange(1, n + 1, dtype=self.freq.dtype,
                                device=self.freq.device).mul(PI)
            self.freq.copy_(init)
        self.freq.requires_grad_()

    def forward(self, dist: Tensor) -> Tensor:
        # Append a trailing axis so every distance meets every frequency.
        scaled = dist.unsqueeze(-1) / self.cutoff
        waves = (self.freq * scaled).sin()
        return self.envelope(scaled) * waves


class SphericalBasisLayer(torch.nn.Module):
    r"""Combined radial/angular basis evaluated per triplet.

    At construction time, symbolic forms returned by ``bessel_basis`` and
    ``real_sph_harm`` (from ``torch_geometric.nn.models.dimenet_utils``) are
    converted into torch-callable functions with :func:`sympy.lambdify`,
    using :func:`torch.sin` / :func:`torch.cos` for the trigonometric terms.
    Only the ``[i][0]`` entry of each spherical-harmonic order is used.

    Args:
        num_spherical (int): Number of angular orders ``n``.
        num_radial (int): Number of radial functions ``k`` per order
            (must be at most 64).
        cutoff (float, optional): Distances are divided by this value.
            (default: :obj:`5.0`)
        envelope_exponent (int, optional): Exponent of the :class:`Envelope`
            applied to the scaled distances. (default: :obj:`5`)
    """
    def __init__(
        self,
        num_spherical: int,
        num_radial: int,
        cutoff: float = 5.0,
        envelope_exponent: int = 5,
    ):
        super().__init__()
        # Imported lazily so sympy is only required when this layer is built.
        import sympy as sym

        from torch_geometric.nn.models.dimenet_utils import (
            bessel_basis,
            real_sph_harm,
        )

        assert num_radial <= 64
        self.num_spherical = num_spherical
        self.num_radial = num_radial
        self.cutoff = cutoff
        self.envelope = Envelope(envelope_exponent)

        bessel_forms = bessel_basis(num_spherical, num_radial)
        sph_harm_forms = real_sph_harm(num_spherical)
        # Plain Python lists (not registered buffers/modules) of callables.
        self.sph_funcs = []
        self.bessel_funcs = []

        x, theta = sym.symbols('x theta')
        modules = {'sin': torch.sin, 'cos': torch.cos}
        for i in range(num_spherical):
            if i == 0:
                # The order-0 form is evaluated once at theta=0 and then
                # broadcast to the input's shape by `_sph_to_tensor`
                # (presumably because it does not depend on theta).
                sph1 = sym.lambdify([theta], sph_harm_forms[i][0], modules)(0)
                self.sph_funcs.append(partial(self._sph_to_tensor, sph1))
            else:
                sph = sym.lambdify([theta], sph_harm_forms[i][0], modules)
                self.sph_funcs.append(sph)
            # Radial functions are appended order-major (i outer, j inner);
            # `forward` relies on this when reshaping to (-1, n, k).
            for j in range(num_radial):
                bessel = sym.lambdify([x], bessel_forms[i][j], modules)
                self.bessel_funcs.append(bessel)

    @staticmethod
    def _sph_to_tensor(sph, x: Tensor) -> Tensor:
        # Broadcast the scalar `sph` to a tensor shaped like `x`.
        return torch.zeros_like(x) + sph

    def forward(self, dist: Tensor, angle: Tensor, idx_kj: Tensor) -> Tensor:
        """Evaluate the basis.

        Args:
            dist (Tensor): Per-edge distances.
            angle (Tensor): Angles; must have one entry per element of
                ``idx_kj`` for the final broadcast to line up.
            idx_kj (Tensor): For each angle, the index into ``dist`` of the
                edge whose radial basis is used.

        Returns:
            Tensor with ``num_spherical * num_radial`` columns and one row per
            entry of ``idx_kj``.
        """
        dist = dist / self.cutoff
        # [num_edges, n * k], scaled by the smooth cutoff envelope.
        rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1)
        rbf = self.envelope(dist).unsqueeze(-1) * rbf

        # [num_angles, n]
        cbf = torch.stack([f(angle) for f in self.sph_funcs], dim=1)

        n, k = self.num_spherical, self.num_radial
        # Each angular order i multiplies its own k radial functions.
        out = (rbf[idx_kj].view(-1, n, k) * cbf.view(-1, n, 1)).view(-1, n * k)
        return out


class EmbeddingBlock(torch.nn.Module):
    r"""Initial per-edge embedding.

    Each edge ``(i, j)`` gets ``act(lin([emb(x)[i], emb(x)[j],
    act(lin_rbf(rbf))]))``.

    Args:
        num_radial (int): Number of radial basis features per edge.
        hidden_channels (int): Embedding size.
        act (Callable): Activation function.
    """
    def __init__(self, num_radial: int, hidden_channels: int, act: Callable):
        super().__init__()
        self.act = act

        # 95-row lookup table, indexed by the integer node types in `x`
        # (atomic numbers when called from DimeNet).
        self.emb = Embedding(95, hidden_channels)
        self.lin_rbf = Linear(num_radial, hidden_channels)
        self.lin = Linear(3 * hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        bound = sqrt(3)
        self.emb.weight.data.uniform_(-bound, bound)
        self.lin_rbf.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, x: Tensor, rbf: Tensor, i: Tensor, j: Tensor) -> Tensor:
        node_feat = self.emb(x)
        rbf_feat = self.act(self.lin_rbf(rbf))
        edge_in = torch.cat([node_feat[i], node_feat[j], rbf_feat], dim=-1)
        return self.act(self.lin(edge_in))


class ResidualLayer(torch.nn.Module):
    r"""Residual unit computing ``x + act(lin2(act(lin1(x))))``.

    Args:
        hidden_channels (int): Input and output feature size.
        act (Callable): Activation function.
    """
    def __init__(self, hidden_channels: int, act: Callable):
        super().__init__()
        self.act = act
        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        # Orthogonal-Glorot weights and zero biases for both layers.
        for lin in (self.lin1, self.lin2):
            glorot_orthogonal(lin.weight, scale=2.0)
            lin.bias.data.fill_(0)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.act(self.lin1(x))
        update = self.act(self.lin2(hidden))
        return x + update


class InteractionBlock(torch.nn.Module):
    r"""DimeNet interaction block.

    Messages ``k->j`` are gated by the projected radial basis, combined with
    the projected spherical basis through the bilinear tensor ``W``, summed
    onto their ``j->i`` row (via ``idx_ji``), added to the ``j->i`` message,
    and passed through residual layers around a skip connection to ``x``.

    Args:
        hidden_channels (int): Embedding size.
        num_bilinear (int): Size of the bilinear layer tensor.
        num_spherical (int): Number of spherical harmonics.
        num_radial (int): Number of radial basis functions.
        num_before_skip (int): Residual layers before the skip connection.
        num_after_skip (int): Residual layers after the skip connection.
        act (Callable): Activation function.
    """
    def __init__(
        self,
        hidden_channels: int,
        num_bilinear: int,
        num_spherical: int,
        num_radial: int,
        num_before_skip: int,
        num_after_skip: int,
        act: Callable,
    ):
        super().__init__()
        self.act = act

        # Basis projections (no bias).
        self.lin_rbf = Linear(num_radial, hidden_channels, bias=False)
        self.lin_sbf = Linear(num_spherical * num_radial, num_bilinear,
                              bias=False)

        # Dense transformations of input messages.
        self.lin_kj = Linear(hidden_channels, hidden_channels)
        self.lin_ji = Linear(hidden_channels, hidden_channels)

        # W[i, j, l]: output channel i, bilinear index j, input channel l.
        self.W = torch.nn.Parameter(
            torch.empty(hidden_channels, num_bilinear, hidden_channels))

        def residual_stack(count: int) -> torch.nn.ModuleList:
            return torch.nn.ModuleList(
                [ResidualLayer(hidden_channels, act) for _ in range(count)])

        self.layers_before_skip = residual_stack(num_before_skip)
        self.lin = Linear(hidden_channels, hidden_channels)
        self.layers_after_skip = residual_stack(num_after_skip)

        self.reset_parameters()

    def reset_parameters(self):
        for lin in (self.lin_rbf, self.lin_sbf):
            glorot_orthogonal(lin.weight, scale=2.0)
        for lin in (self.lin_kj, self.lin_ji):
            glorot_orthogonal(lin.weight, scale=2.0)
            lin.bias.data.fill_(0)
        self.W.data.normal_(mean=0, std=2 / self.W.size(0))
        for res_layer in self.layers_before_skip:
            res_layer.reset_parameters()
        glorot_orthogonal(self.lin.weight, scale=2.0)
        self.lin.bias.data.fill_(0)
        for res_layer in self.layers_after_skip:
            res_layer.reset_parameters()

    def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor,
                idx_ji: Tensor) -> Tensor:
        rbf_proj = self.lin_rbf(rbf)
        sbf_proj = self.lin_sbf(sbf)

        msg_ji = self.act(self.lin_ji(x))
        msg_kj = self.act(self.lin_kj(x)) * rbf_proj

        # Per triplet w: out[w, i] = sum_{j, l} sbf[w, j] * m[w, l] * W[i, j, l]
        triplet_msg = torch.einsum('wj,wl,ijl->wi', sbf_proj,
                                   msg_kj[idx_kj], self.W)
        agg = scatter(triplet_msg, idx_ji, dim=0, dim_size=x.size(0),
                      reduce='sum')

        h = msg_ji + agg
        for layer in self.layers_before_skip:
            h = layer(h)
        h = self.act(self.lin(h)) + x
        for layer in self.layers_after_skip:
            h = layer(h)

        return h


class InteractionPPBlock(torch.nn.Module):
    r"""DimeNet++ interaction block.

    Messages ``k->j`` are gated by a two-stage radial projection,
    down-projected to ``int_emb_size``, multiplied per triplet with a
    two-stage spherical projection, summed onto their ``j->i`` row (via
    ``idx_ji``), up-projected, and added to the ``j->i`` message before the
    residual/skip stack.

    Args:
        hidden_channels (int): Embedding size.
        int_emb_size (int): Size of the down-projected triplet embedding.
        basis_emb_size (int): Size of the intermediate basis projections.
        num_spherical (int): Number of spherical harmonics.
        num_radial (int): Number of radial basis functions.
        num_before_skip (int): Residual layers before the skip connection.
        num_after_skip (int): Residual layers after the skip connection.
        act (Callable): Activation function.
    """
    def __init__(
        self,
        hidden_channels: int,
        int_emb_size: int,
        basis_emb_size: int,
        num_spherical: int,
        num_radial: int,
        num_before_skip: int,
        num_after_skip: int,
        act: Callable,
    ):
        super().__init__()
        self.act = act

        # Transformation of Bessel and spherical basis representations:
        self.lin_rbf1 = Linear(num_radial, basis_emb_size, bias=False)
        self.lin_rbf2 = Linear(basis_emb_size, hidden_channels, bias=False)

        self.lin_sbf1 = Linear(num_spherical * num_radial, basis_emb_size,
                               bias=False)
        self.lin_sbf2 = Linear(basis_emb_size, int_emb_size, bias=False)

        # Hidden transformation of input message:
        self.lin_kj = Linear(hidden_channels, hidden_channels)
        self.lin_ji = Linear(hidden_channels, hidden_channels)

        # Embedding projections for interaction triplets:
        self.lin_down = Linear(hidden_channels, int_emb_size, bias=False)
        self.lin_up = Linear(int_emb_size, hidden_channels, bias=False)

        def residual_stack(count: int) -> torch.nn.ModuleList:
            return torch.nn.ModuleList(
                [ResidualLayer(hidden_channels, act) for _ in range(count)])

        # Residual layers before and after skip connection:
        self.layers_before_skip = residual_stack(num_before_skip)
        self.lin = Linear(hidden_channels, hidden_channels)
        self.layers_after_skip = residual_stack(num_after_skip)

        self.reset_parameters()

    def reset_parameters(self):
        for lin in (self.lin_rbf1, self.lin_rbf2, self.lin_sbf1,
                    self.lin_sbf2):
            glorot_orthogonal(lin.weight, scale=2.0)

        for lin in (self.lin_kj, self.lin_ji):
            glorot_orthogonal(lin.weight, scale=2.0)
            lin.bias.data.fill_(0)

        for lin in (self.lin_down, self.lin_up):
            glorot_orthogonal(lin.weight, scale=2.0)

        for res_layer in self.layers_before_skip:
            res_layer.reset_parameters()
        glorot_orthogonal(self.lin.weight, scale=2.0)
        self.lin.bias.data.fill_(0)
        for res_layer in self.layers_after_skip:
            res_layer.reset_parameters()

    def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor,
                idx_ji: Tensor) -> Tensor:
        msg_ji = self.act(self.lin_ji(x))
        msg_kj = self.act(self.lin_kj(x))

        # Gate by the radial basis, then shrink to the triplet width.
        rbf_gate = self.lin_rbf2(self.lin_rbf1(rbf))
        msg_kj = self.act(self.lin_down(msg_kj * rbf_gate))

        # Modulate each triplet by its projected spherical basis.
        sbf_gate = self.lin_sbf2(self.lin_sbf1(sbf))
        triplet_msg = msg_kj[idx_kj] * sbf_gate

        # Sum triplets onto their j->i row and expand back to hidden size.
        agg = scatter(triplet_msg, idx_ji, dim=0, dim_size=x.size(0),
                      reduce='sum')
        agg = self.act(self.lin_up(agg))

        h = msg_ji + agg
        for layer in self.layers_before_skip:
            h = layer(h)
        h = self.act(self.lin(h)) + x
        for layer in self.layers_after_skip:
            h = layer(h)

        return h


class OutputBlock(torch.nn.Module):
    r"""DimeNet output block.

    Gates edge embeddings with a radial projection, sums them per target
    index ``i`` into ``num_nodes`` rows, then applies ``num_layers`` dense
    layers with activation and a bias-free final projection.

    Args:
        num_radial (int): Number of radial basis functions.
        hidden_channels (int): Embedding size.
        out_channels (int): Output size.
        num_layers (int): Number of hidden dense layers.
        act (Callable): Activation function.
        output_initializer (str, optional): ``'zeros'`` or
            ``'glorot_orthogonal'`` for the final layer's weight.
            (default: :obj:`'zeros'`)
    """
    def __init__(
        self,
        num_radial: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int,
        act: Callable,
        output_initializer: str = 'zeros',
    ):
        assert output_initializer in {'zeros', 'glorot_orthogonal'}

        super().__init__()

        self.act = act
        self.output_initializer = output_initializer

        self.lin_rbf = Linear(num_radial, hidden_channels, bias=False)
        self.lins = torch.nn.ModuleList([
            Linear(hidden_channels, hidden_channels)
            for _ in range(num_layers)
        ])
        self.lin = Linear(hidden_channels, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        glorot_orthogonal(self.lin_rbf.weight, scale=2.0)
        for hidden in self.lins:
            glorot_orthogonal(hidden.weight, scale=2.0)
            hidden.bias.data.fill_(0)

        init = self.output_initializer
        if init == 'zeros':
            self.lin.weight.data.fill_(0)
        elif init == 'glorot_orthogonal':
            glorot_orthogonal(self.lin.weight, scale=2.0)

    def forward(self, x: Tensor, rbf: Tensor, i: Tensor,
                num_nodes: Optional[int] = None) -> Tensor:
        gated = self.lin_rbf(rbf) * x
        out = scatter(gated, i, dim=0, dim_size=num_nodes, reduce='sum')
        for hidden in self.lins:
            out = self.act(hidden(out))
        return self.lin(out)


class OutputPPBlock(torch.nn.Module):
    r"""DimeNet++ output block.

    Like :class:`OutputBlock`, but after the per-node sum the features are
    up-projected to ``out_emb_channels`` before the dense layers.

    Args:
        num_radial (int): Number of radial basis functions.
        hidden_channels (int): Embedding size.
        out_emb_channels (int): Width of the up-projected output embedding.
        out_channels (int): Output size.
        num_layers (int): Number of hidden dense layers.
        act (Callable): Activation function.
        output_initializer (str, optional): ``'zeros'`` or
            ``'glorot_orthogonal'`` for the final layer's weight.
            (default: :obj:`'zeros'`)
    """
    def __init__(
        self,
        num_radial: int,
        hidden_channels: int,
        out_emb_channels: int,
        out_channels: int,
        num_layers: int,
        act: Callable,
        output_initializer: str = 'zeros',
    ):
        assert output_initializer in {'zeros', 'glorot_orthogonal'}

        super().__init__()

        self.act = act
        self.output_initializer = output_initializer

        self.lin_rbf = Linear(num_radial, hidden_channels, bias=False)

        # The up-projection layer:
        self.lin_up = Linear(hidden_channels, out_emb_channels, bias=False)
        self.lins = torch.nn.ModuleList([
            Linear(out_emb_channels, out_emb_channels)
            for _ in range(num_layers)
        ])
        self.lin = Linear(out_emb_channels, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        for proj in (self.lin_rbf, self.lin_up):
            glorot_orthogonal(proj.weight, scale=2.0)
        for hidden in self.lins:
            glorot_orthogonal(hidden.weight, scale=2.0)
            hidden.bias.data.fill_(0)

        init = self.output_initializer
        if init == 'zeros':
            self.lin.weight.data.fill_(0)
        elif init == 'glorot_orthogonal':
            glorot_orthogonal(self.lin.weight, scale=2.0)

    def forward(self, x: Tensor, rbf: Tensor, i: Tensor,
                num_nodes: Optional[int] = None) -> Tensor:
        gated = self.lin_rbf(rbf) * x
        pooled = scatter(gated, i, dim=0, dim_size=num_nodes, reduce='sum')
        out = self.lin_up(pooled)
        for hidden in self.lins:
            out = self.act(hidden(out))
        return self.lin(out)


def triplets(
    edge_index: Tensor,
    num_nodes: int,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    r"""Enumerate edge pairs ``(k->j, j->i)`` that meet at node ``j``.

    Args:
        edge_index (Tensor): Edge indices whose first row is the source ``j``
            and second row the target ``i`` of each edge ``j->i``.
        num_nodes (int): Number of nodes; used as the adjacency size.

    Returns:
        ``(col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji)``: the target and
        source node of every edge; the node indices of every triplet
        ``k->j->i`` with ``i != k``; and, per triplet, the index (into
        ``edge_index``) of edge ``k->j`` and of edge ``j->i``.
    """
    row, col = edge_index  # j->i

    # Edge ids 0..E-1, stored as the sparse values.
    value = torch.arange(row.size(0), device=row.device)
    # Transposed adjacency: entry (i, j) holds the id of edge j->i.
    adj_t = SparseTensor(row=col, col=row, value=value,
                         sparse_sizes=(num_nodes, num_nodes))
    # One row per edge j->i: the adj_t row of its source j, i.e. every
    # edge k->j that ends in j.
    adj_t_row = adj_t[row]
    # Entries per row. `set_value(None)` drops the edge ids, so the sum
    # presumably counts stored entries rather than adding ids.
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (k->j->i) for triplets.
    idx_i = col.repeat_interleave(num_triplets)
    idx_j = row.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    mask = idx_i != idx_k  # Remove i == k triplets.
    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

    # Edge indices (k->j, j->i) for triplets: the stored value is the id of
    # k->j, and the row position in `adj_t_row` is the id of j->i.
    idx_kj = adj_t_row.storage.value()[mask]
    idx_ji = adj_t_row.storage.row()[mask]

    return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji


class DimeNet(torch.nn.Module):
    r"""The directional message passing neural network (DimeNet) from the
    `"Directional Message Passing for Molecular Graphs"
    <https://arxiv.org/abs/2003.03123>`_ paper.
    DimeNet transforms messages based on the angle between them in a
    rotation-equivariant fashion.

    .. note::

        For an example of using a pretrained DimeNet variant, see
        `examples/qm9_pretrained_dimenet.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        qm9_pretrained_dimenet.py>`_.

    Args:
        hidden_channels (int): Hidden embedding size.
        out_channels (int): Size of each output sample.
        num_blocks (int): Number of building blocks.
        num_bilinear (int): Size of the bilinear layer tensor.
        num_spherical (int): Number of spherical harmonics.
        num_radial (int): Number of radial basis functions.
        cutoff (float, optional): Cutoff distance for interatomic
            interactions. (default: :obj:`5.0`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            collect for each node within the :attr:`cutoff` distance.
            (default: :obj:`32`)
        envelope_exponent (int, optional): Shape of the smooth cutoff.
            (default: :obj:`5`)
        num_before_skip (int, optional): Number of residual layers in the
            interaction blocks before the skip connection. (default: :obj:`1`)
        num_after_skip (int, optional): Number of residual layers in the
            interaction blocks after the skip connection. (default: :obj:`2`)
        num_output_layers (int, optional): Number of linear layers for the
            output blocks. (default: :obj:`3`)
        act (str or Callable, optional): The activation function.
            (default: :obj:`"swish"`)
        output_initializer (str, optional): The initialization method for the
            output layer (:obj:`"zeros"`, :obj:`"glorot_orthogonal"`).
            (default: :obj:`"zeros"`)
    """

    url = ('https://github.com/klicperajo/dimenet/raw/master/pretrained/'
           'dimenet')

    def __init__(
        self,
        hidden_channels: int,
        out_channels: int,
        num_blocks: int,
        num_bilinear: int,
        num_spherical: int,
        num_radial: int,
        cutoff: float = 5.0,
        max_num_neighbors: int = 32,
        envelope_exponent: int = 5,
        num_before_skip: int = 1,
        num_after_skip: int = 2,
        num_output_layers: int = 3,
        act: Union[str, Callable] = 'swish',
        output_initializer: str = 'zeros',
    ):
        super().__init__()

        if num_spherical < 2:
            raise ValueError("'num_spherical' should be greater than 1")

        act = activation_resolver(act)

        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors
        self.num_blocks = num_blocks

        self.rbf = BesselBasisLayer(num_radial, cutoff, envelope_exponent)
        self.sbf = SphericalBasisLayer(num_spherical, num_radial, cutoff,
                                       envelope_exponent)

        self.emb = EmbeddingBlock(num_radial, hidden_channels, act)

        # One output block for the embedding plus one per interaction block.
        self.output_blocks = torch.nn.ModuleList([
            OutputBlock(
                num_radial,
                hidden_channels,
                out_channels,
                num_output_layers,
                act,
                output_initializer,
            ) for _ in range(num_blocks + 1)
        ])

        self.interaction_blocks = torch.nn.ModuleList([
            InteractionBlock(
                hidden_channels,
                num_bilinear,
                num_spherical,
                num_radial,
                num_before_skip,
                num_after_skip,
                act,
            ) for _ in range(num_blocks)
        ])

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.rbf.reset_parameters()
        self.emb.reset_parameters()
        for out in self.output_blocks:
            out.reset_parameters()
        for interaction in self.interaction_blocks:
            interaction.reset_parameters()

    @classmethod
    def from_qm9_pretrained(  # pragma: no cover
        cls,
        root: str,
        dataset: Dataset,
        target: int,
    ) -> Tuple['DimeNet', Tuple[Dataset, Dataset, Dataset]]:
        r"""Returns a pre-trained :class:`DimeNet` model on the
        :class:`~torch_geometric.datasets.QM9` dataset, trained on the
        specified target :obj:`target`, together with a
        :obj:`(train_dataset, val_dataset, test_dataset)` split of
        :obj:`dataset`.

        Args:
            root (str): Directory under which the checkpoint is downloaded
                (into ``pretrained_dimenet/<target name>``).
            dataset (Dataset): Dataset to split. The split sizes are
                hard-coded for 130,831 examples.
            target (int): QM9 target index; must be a key of
                :obj:`qm9_target_dict`.
        """
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        import tensorflow as tf

        # Validate against the checkpoint table itself. A range check alone
        # would let `target=12` through to a `KeyError` below.
        assert target in qm9_target_dict, (
            f"No pretrained DimeNet checkpoint for QM9 target {target} "
            f"(available: {sorted(qm9_target_dict)})")

        root = osp.expanduser(osp.normpath(root))
        path = osp.join(root, 'pretrained_dimenet', qm9_target_dict[target])

        os.makedirs(path, exist_ok=True)
        url = f'{cls.url}/{qm9_target_dict[target]}'

        if not osp.exists(osp.join(path, 'checkpoint')):
            download_url(f'{url}/checkpoint', path)
            download_url(f'{url}/ckpt.data-00000-of-00002', path)
            download_url(f'{url}/ckpt.data-00001-of-00002', path)
            download_url(f'{url}/ckpt.index', path)

        path = osp.join(path, 'ckpt')
        reader = tf.train.load_checkpoint(path)

        model = cls(
            hidden_channels=128,
            out_channels=1,
            num_blocks=6,
            num_bilinear=8,
            num_spherical=7,
            num_radial=6,
            cutoff=5.0,
            envelope_exponent=5,
            num_before_skip=1,
            num_after_skip=2,
            num_output_layers=3,
        )

        def copy_(src, name):
            # TF dense kernels are stored as [in, out]; torch Linear weights
            # are [out, in], so kernels are transposed.
            init = reader.get_tensor(f'{name}/.ATTRIBUTES/VARIABLE_VALUE')
            init = torch.from_numpy(init)
            if name.endswith('kernel'):
                init = init.t()
            src.data.copy_(init)

        copy_(model.rbf.freq, 'rbf_layer/frequencies')
        copy_(model.emb.emb.weight, 'emb_block/embeddings')
        copy_(model.emb.lin_rbf.weight, 'emb_block/dense_rbf/kernel')
        copy_(model.emb.lin_rbf.bias, 'emb_block/dense_rbf/bias')
        copy_(model.emb.lin.weight, 'emb_block/dense/kernel')
        copy_(model.emb.lin.bias, 'emb_block/dense/bias')

        for i, block in enumerate(model.output_blocks):
            copy_(block.lin_rbf.weight, f'output_blocks/{i}/dense_rbf/kernel')
            for j, lin in enumerate(block.lins):
                copy_(lin.weight, f'output_blocks/{i}/dense_layers/{j}/kernel')
                copy_(lin.bias, f'output_blocks/{i}/dense_layers/{j}/bias')
            copy_(block.lin.weight, f'output_blocks/{i}/dense_final/kernel')

        for i, block in enumerate(model.interaction_blocks):
            copy_(block.lin_rbf.weight, f'int_blocks/{i}/dense_rbf/kernel')
            copy_(block.lin_sbf.weight, f'int_blocks/{i}/dense_sbf/kernel')
            copy_(block.lin_kj.weight, f'int_blocks/{i}/dense_kj/kernel')
            copy_(block.lin_kj.bias, f'int_blocks/{i}/dense_kj/bias')
            copy_(block.lin_ji.weight, f'int_blocks/{i}/dense_ji/kernel')
            copy_(block.lin_ji.bias, f'int_blocks/{i}/dense_ji/bias')
            copy_(block.W, f'int_blocks/{i}/bilinear')
            for j, layer in enumerate(block.layers_before_skip):
                prefix = f'int_blocks/{i}/layers_before_skip/{j}'
                copy_(layer.lin1.weight, f'{prefix}/dense_1/kernel')
                copy_(layer.lin1.bias, f'{prefix}/dense_1/bias')
                copy_(layer.lin2.weight, f'{prefix}/dense_2/kernel')
                copy_(layer.lin2.bias, f'{prefix}/dense_2/bias')
            copy_(block.lin.weight, f'int_blocks/{i}/final_before_skip/kernel')
            copy_(block.lin.bias, f'int_blocks/{i}/final_before_skip/bias')
            for j, layer in enumerate(block.layers_after_skip):
                prefix = f'int_blocks/{i}/layers_after_skip/{j}'
                copy_(layer.lin1.weight, f'{prefix}/dense_1/kernel')
                copy_(layer.lin1.bias, f'{prefix}/dense_1/bias')
                copy_(layer.lin2.weight, f'{prefix}/dense_2/kernel')
                copy_(layer.lin2.bias, f'{prefix}/dense_2/bias')

        # Use the same random seed as the official DimeNet implementation.
        random_state = np.random.RandomState(seed=42)
        perm = torch.from_numpy(random_state.permutation(np.arange(130831)))
        perm = perm.long()
        train_idx = perm[:110000]
        val_idx = perm[110000:120000]
        test_idx = perm[120000:]

        return model, (dataset[train_idx], dataset[val_idx], dataset[test_idx])

    def forward(
        self,
        z: Tensor,
        pos: Tensor,
        batch: OptTensor = None,
    ) -> Tensor:
        r"""Forward pass.

        Args:
            z (torch.Tensor): Atomic number of each atom with shape
                :obj:`[num_atoms]`.
            pos (torch.Tensor): Coordinates of each atom with shape
                :obj:`[num_atoms, 3]`.
            batch (torch.Tensor, optional): Batch indices assigning each atom
                to a separate molecule with shape :obj:`[num_atoms]`.
                (default: :obj:`None`)
        """
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch,
                                  max_num_neighbors=self.max_num_neighbors)

        i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(
            edge_index, num_nodes=z.size(0))

        # Calculate distances.
        dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

        # Calculate angles. The subclass is checked first because every
        # DimeNetPlusPlus is also a DimeNet; the two variants measure the
        # angle between different vector pairs, and their pretrained
        # weights depend on it.
        if isinstance(self, DimeNetPlusPlus):
            pos_jk, pos_ij = pos[idx_j] - pos[idx_k], pos[idx_i] - pos[idx_j]
            a = (pos_ij * pos_jk).sum(dim=-1)
            b = torch.cross(pos_ij, pos_jk, dim=1).norm(dim=-1)
        elif isinstance(self, DimeNet):
            pos_ji, pos_ki = pos[idx_j] - pos[idx_i], pos[idx_k] - pos[idx_i]
            a = (pos_ji * pos_ki).sum(dim=-1)
            b = torch.cross(pos_ji, pos_ki, dim=1).norm(dim=-1)
        angle = torch.atan2(b, a)

        rbf = self.rbf(dist)
        sbf = self.sbf(dist, angle, idx_kj)

        # Embedding block.
        x = self.emb(z, rbf, i, j)
        P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0))

        # Interaction blocks.
        for interaction_block, output_block in zip(self.interaction_blocks,
                                                   self.output_blocks[1:]):
            x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
            P = P + output_block(x, rbf, i, num_nodes=pos.size(0))

        if batch is None:
            return P.sum(dim=0)
        else:
            return scatter(P, batch, dim=0, reduce='sum')
class DimeNetPlusPlus(DimeNet):
    r"""The DimeNet++ from the `"Fast and Uncertainty-Aware Directional
    Message Passing for Non-Equilibrium Molecules"
    <https://arxiv.org/abs/2011.14115>`_ paper.

    :class:`DimeNetPlusPlus` is an upgrade to the :class:`DimeNet` model that
    is 8x faster and 10% more accurate than :class:`DimeNet`.

    Args:
        hidden_channels (int): Hidden embedding size.
        out_channels (int): Size of each output sample.
        num_blocks (int): Number of building blocks.
        int_emb_size (int): Size of embedding in the interaction block.
        basis_emb_size (int): Size of basis embedding in the interaction block.
        out_emb_channels (int): Size of embedding in the output block.
        num_spherical (int): Number of spherical harmonics.
        num_radial (int): Number of radial basis functions.
        cutoff (float, optional): Cutoff distance for interatomic
            interactions. (default: :obj:`5.0`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            collect for each node within the :attr:`cutoff` distance.
            (default: :obj:`32`)
        envelope_exponent (int, optional): Shape of the smooth cutoff.
            (default: :obj:`5`)
        num_before_skip (int, optional): Number of residual layers in the
            interaction blocks before the skip connection. (default: :obj:`1`)
        num_after_skip (int, optional): Number of residual layers in the
            interaction blocks after the skip connection. (default: :obj:`2`)
        num_output_layers (int, optional): Number of linear layers for the
            output blocks. (default: :obj:`3`)
        act (str or Callable, optional): The activation function.
            (default: :obj:`"swish"`)
        output_initializer (str, optional): The initialization method for the
            output layer (:obj:`"zeros"`, :obj:`"glorot_orthogonal"`).
            (default: :obj:`"zeros"`)
    """

    url = ('https://raw.githubusercontent.com/gasteigerjo/dimenet/'
           'master/pretrained/dimenet_pp')

    def __init__(
        self,
        hidden_channels: int,
        out_channels: int,
        num_blocks: int,
        int_emb_size: int,
        basis_emb_size: int,
        out_emb_channels: int,
        num_spherical: int,
        num_radial: int,
        cutoff: float = 5.0,
        max_num_neighbors: int = 32,
        envelope_exponent: int = 5,
        num_before_skip: int = 1,
        num_after_skip: int = 2,
        num_output_layers: int = 3,
        act: Union[str, Callable] = 'swish',
        output_initializer: str = 'zeros',
    ):
        act = activation_resolver(act)

        super().__init__(
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_blocks=num_blocks,
            num_bilinear=1,
            num_spherical=num_spherical,
            num_radial=num_radial,
            cutoff=cutoff,
            max_num_neighbors=max_num_neighbors,
            envelope_exponent=envelope_exponent,
            num_before_skip=num_before_skip,
            num_after_skip=num_after_skip,
            num_output_layers=num_output_layers,
            act=act,
            output_initializer=output_initializer,
        )

        # We are re-using the RBF, SBF and embedding layers of `DimeNet` and
        # redefine `output_blocks` and `interaction_blocks` for DimeNet++.
        # The `num_bilinear` value passed above therefore has no effect: it
        # is only consumed by DimeNet's `InteractionBlock`s, which are
        # replaced below.
        self.output_blocks = torch.nn.ModuleList([
            OutputPPBlock(
                num_radial,
                hidden_channels,
                out_emb_channels,
                out_channels,
                num_output_layers,
                act,
                output_initializer,
            ) for _ in range(num_blocks + 1)
        ])

        self.interaction_blocks = torch.nn.ModuleList([
            InteractionPPBlock(
                hidden_channels,
                int_emb_size,
                basis_emb_size,
                num_spherical,
                num_radial,
                num_before_skip,
                num_after_skip,
                act,
            ) for _ in range(num_blocks)
        ])

        self.reset_parameters()

    @classmethod
    def from_qm9_pretrained(  # pragma: no cover
        cls,
        root: str,
        dataset: Dataset,
        target: int,
    ) -> Tuple['DimeNetPlusPlus', Tuple[Dataset, Dataset, Dataset]]:
        r"""Returns a pre-trained :class:`DimeNetPlusPlus` model on the
        :class:`~torch_geometric.datasets.QM9` dataset, trained on the
        specified target :obj:`target`, together with a
        :obj:`(train_dataset, val_dataset, test_dataset)` split of
        :obj:`dataset`.

        Args:
            root (str): Directory under which the checkpoint is downloaded
                (into ``pretrained_dimenet_pp/<target name>``).
            dataset (Dataset): Dataset to split. The split sizes are
                hard-coded for 130,831 examples.
            target (int): QM9 target index; must be a key of
                :obj:`qm9_target_dict`.
        """
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        import tensorflow as tf

        # Validate against the checkpoint table itself. A range check alone
        # would let `target=12` through to a `KeyError` below.
        assert target in qm9_target_dict, (
            f"No pretrained DimeNet++ checkpoint for QM9 target {target} "
            f"(available: {sorted(qm9_target_dict)})")

        root = osp.expanduser(osp.normpath(root))
        path = osp.join(root, 'pretrained_dimenet_pp', qm9_target_dict[target])

        os.makedirs(path, exist_ok=True)
        url = f'{cls.url}/{qm9_target_dict[target]}'

        if not osp.exists(osp.join(path, 'checkpoint')):
            download_url(f'{url}/checkpoint', path)
            download_url(f'{url}/ckpt.data-00000-of-00002', path)
            download_url(f'{url}/ckpt.data-00001-of-00002', path)
            download_url(f'{url}/ckpt.index', path)

        path = osp.join(path, 'ckpt')
        reader = tf.train.load_checkpoint(path)

        # Configuration from DimeNet++:
        # https://github.com/gasteigerjo/dimenet/blob/master/config_pp.yaml
        model = cls(
            hidden_channels=128,
            out_channels=1,
            num_blocks=4,
            int_emb_size=64,
            basis_emb_size=8,
            out_emb_channels=256,
            num_spherical=7,
            num_radial=6,
            cutoff=5.0,
            max_num_neighbors=32,
            envelope_exponent=5,
            num_before_skip=1,
            num_after_skip=2,
            num_output_layers=3,
        )

        def copy_(src, name):
            # TF dense kernels are stored as [in, out]; torch Linear weights
            # are [out, in], so kernels are transposed.
            init = reader.get_tensor(f'{name}/.ATTRIBUTES/VARIABLE_VALUE')
            init = torch.from_numpy(init)
            if name.endswith('kernel'):
                init = init.t()
            src.data.copy_(init)

        copy_(model.rbf.freq, 'rbf_layer/frequencies')
        copy_(model.emb.emb.weight, 'emb_block/embeddings')
        copy_(model.emb.lin_rbf.weight, 'emb_block/dense_rbf/kernel')
        copy_(model.emb.lin_rbf.bias, 'emb_block/dense_rbf/bias')
        copy_(model.emb.lin.weight, 'emb_block/dense/kernel')
        copy_(model.emb.lin.bias, 'emb_block/dense/bias')

        for i, block in enumerate(model.output_blocks):
            copy_(block.lin_rbf.weight, f'output_blocks/{i}/dense_rbf/kernel')
            copy_(block.lin_up.weight,
                  f'output_blocks/{i}/up_projection/kernel')
            for j, lin in enumerate(block.lins):
                copy_(lin.weight, f'output_blocks/{i}/dense_layers/{j}/kernel')
                copy_(lin.bias, f'output_blocks/{i}/dense_layers/{j}/bias')
            copy_(block.lin.weight, f'output_blocks/{i}/dense_final/kernel')

        for i, block in enumerate(model.interaction_blocks):
            copy_(block.lin_rbf1.weight, f'int_blocks/{i}/dense_rbf1/kernel')
            copy_(block.lin_rbf2.weight, f'int_blocks/{i}/dense_rbf2/kernel')
            copy_(block.lin_sbf1.weight, f'int_blocks/{i}/dense_sbf1/kernel')
            copy_(block.lin_sbf2.weight, f'int_blocks/{i}/dense_sbf2/kernel')

            copy_(block.lin_ji.weight, f'int_blocks/{i}/dense_ji/kernel')
            copy_(block.lin_ji.bias, f'int_blocks/{i}/dense_ji/bias')
            copy_(block.lin_kj.weight, f'int_blocks/{i}/dense_kj/kernel')
            copy_(block.lin_kj.bias, f'int_blocks/{i}/dense_kj/bias')

            copy_(block.lin_down.weight,
                  f'int_blocks/{i}/down_projection/kernel')
            copy_(block.lin_up.weight, f'int_blocks/{i}/up_projection/kernel')

            for j, layer in enumerate(block.layers_before_skip):
                prefix = f'int_blocks/{i}/layers_before_skip/{j}'
                copy_(layer.lin1.weight, f'{prefix}/dense_1/kernel')
                copy_(layer.lin1.bias, f'{prefix}/dense_1/bias')
                copy_(layer.lin2.weight, f'{prefix}/dense_2/kernel')
                copy_(layer.lin2.bias, f'{prefix}/dense_2/bias')

            copy_(block.lin.weight, f'int_blocks/{i}/final_before_skip/kernel')
            copy_(block.lin.bias, f'int_blocks/{i}/final_before_skip/bias')

            for j, layer in enumerate(block.layers_after_skip):
                prefix = f'int_blocks/{i}/layers_after_skip/{j}'
                copy_(layer.lin1.weight, f'{prefix}/dense_1/kernel')
                copy_(layer.lin1.bias, f'{prefix}/dense_1/bias')
                copy_(layer.lin2.weight, f'{prefix}/dense_2/kernel')
                copy_(layer.lin2.bias, f'{prefix}/dense_2/bias')

        # Use the same random seed as the official DimeNet implementation.
        random_state = np.random.RandomState(seed=42)
        perm = torch.from_numpy(random_state.permutation(np.arange(130831)))
        perm = perm.long()
        train_idx = perm[:110000]
        val_idx = perm[110000:120000]
        test_idx = perm[120000:]

        return model, (dataset[train_idx], dataset[val_idx], dataset[test_idx])