# Source code for torch_geometric.nn.models.attentive_fp

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import GRUCell, Linear, Parameter

from torch_geometric.nn import GATConv, MessagePassing, global_add_pool
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import softmax


class GATEConv(MessagePassing):
    r"""Edge-feature-aware attention layer with additive aggregation.

    For every edge, an attention logit is built from two terms: ``x_j``
    concatenated with the edge features, passed through ``lin1`` and
    projected by ``att_l``; and ``x_i`` projected by ``att_r``. The logits
    are normalized with :func:`softmax` and used to weight ``lin2(x_j)``
    before the messages are summed and a bias is added.

    NOTE(review): ``lin2`` expects ``out_channels`` input features but is
    applied to ``x_j`` directly, so this layer presumably requires
    ``in_channels == out_channels`` -- confirm before using it elsewhere.

    Args:
        in_channels (int): Size of each input node feature vector.
        out_channels (int): Size of each output node feature vector.
        edge_dim (int): Edge feature dimensionality.
        dropout (float, optional): Dropout probability applied to the
            normalized attention coefficients. (default: :obj:`0.0`)
    """
    def __init__(self, in_channels: int, out_channels: int, edge_dim: int,
                 dropout: float = 0.0):
        super().__init__(aggr='add', node_dim=0)

        self.dropout = dropout

        # Attention vectors, each stored as a single row.
        self.att_l = Parameter(torch.empty(1, out_channels))
        self.att_r = Parameter(torch.empty(1, in_channels))

        # Both projections are created without a bias term.
        self.lin1 = Linear(in_channels + edge_dim, out_channels, bias=False)
        self.lin2 = Linear(out_channels, out_channels, bias=False)

        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-initializes the attention vectors, both weight matrices and
        the bias.
        """
        glorot_targets = (
            self.att_l,
            self.att_r,
            self.lin1.weight,
            self.lin2.weight,
        )
        for weight in glorot_targets:
            glorot(weight)
        zeros(self.bias)

    def forward(self, x: Tensor, edge_index: Adj, edge_attr: Tensor) -> Tensor:
        r"""Computes per-edge attention coefficients, aggregates the weighted
        messages and adds the bias.
        """
        # The two type-hint comments below are kept verbatim; presumably
        # they are consumed by PyG tooling -- do not reword them.

        # edge_updater_type: (x: Tensor, edge_attr: Tensor)
        alpha = self.edge_updater(edge_index, x=x, edge_attr=edge_attr)

        # propagate_type: (x: Tensor, alpha: Tensor)
        aggregated = self.propagate(edge_index, x=x, alpha=alpha)
        return aggregated + self.bias

    def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor,
                    index: Tensor, ptr: OptTensor,
                    size_i: Optional[int]) -> Tensor:
        r"""Returns the dropout-regularized, softmax-normalized attention
        coefficient of every edge.
        """
        h_j = torch.cat([x_j, edge_attr], dim=-1)
        h_j = F.leaky_relu_(self.lin1(h_j))

        score_j = (h_j @ self.att_l.t()).squeeze(-1)
        score_i = (x_i @ self.att_r.t()).squeeze(-1)

        # The sum is a fresh tensor, so the in-place activation is safe.
        alpha = F.leaky_relu_(score_j + score_i)
        alpha = softmax(alpha, index, ptr, size_i)
        return F.dropout(alpha, p=self.dropout, training=self.training)

    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
        r"""Scales the ``lin2``-transformed ``x_j`` by its attention
        coefficient.
        """
        transformed = self.lin2(x_j)
        return transformed * alpha.unsqueeze(-1)


class AttentiveFP(torch.nn.Module):
    r"""The Attentive FP model for molecular representation learning from the
    `"Pushing the Boundaries of Molecular Representation for Drug Discovery
    with the Graph Attention Mechanism"
    <https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959>`_ paper, based on
    graph attention mechanisms.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Hidden node feature dimensionality.
        out_channels (int): Size of each output sample.
        edge_dim (int): Edge feature dimensionality.
        num_layers (int): Number of GNN layers.
        num_timesteps (int): Number of iterative refinement steps for global
            readout.
        dropout (float, optional): Dropout probability. (default: :obj:`0.0`)
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        edge_dim: int,
        num_layers: int,
        num_timesteps: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.edge_dim = edge_dim
        self.num_layers = num_layers
        self.num_timesteps = num_timesteps
        self.dropout = dropout

        self.lin1 = Linear(in_channels, hidden_channels)

        # First layer: the only one that consumes edge features.
        self.gate_conv = GATEConv(hidden_channels, hidden_channels, edge_dim,
                                  dropout)
        self.gru = GRUCell(hidden_channels, hidden_channels)

        # The remaining `num_layers - 1` layers, each paired with a GRU cell.
        self.atom_convs = torch.nn.ModuleList()
        self.atom_grus = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            conv = GATConv(hidden_channels, hidden_channels, dropout=dropout,
                           add_self_loops=False, negative_slope=0.01)
            self.atom_convs.append(conv)
            self.atom_grus.append(GRUCell(hidden_channels, hidden_channels))

        # Readout layer; the same weights are reused at every timestep.
        self.mol_conv = GATConv(hidden_channels, hidden_channels,
                                dropout=dropout, add_self_loops=False,
                                negative_slope=0.01)
        self.mol_conv.explain = False  # Cannot explain global pooling.
        self.mol_gru = GRUCell(hidden_channels, hidden_channels)

        self.lin2 = Linear(hidden_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.lin1.reset_parameters()
        self.gate_conv.reset_parameters()
        self.gru.reset_parameters()
        for conv, gru in zip(self.atom_convs, self.atom_grus):
            conv.reset_parameters()
            gru.reset_parameters()
        self.mol_conv.reset_parameters()
        self.mol_gru.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor,
                batch: Tensor) -> Tensor:
        r"""Runs the forward pass of the module.

        Args:
            x (Tensor): Node features with :obj:`in_channels` columns.
            edge_index (Tensor): Edge indices handed to the convolution
                layers.
            edge_attr (Tensor): Edge features; only used by the first
                (:class:`GATEConv`) layer.
            batch (Tensor): One entry per node; presumably the index of the
                graph that node belongs to (it is used as the pooling index
                and as the target side of the readout edges).

        Returns:
            Tensor: The output of :obj:`lin2` applied to the pooled and
            refined embeddings, with :obj:`out_channels` columns.
        """
        # Atom Embedding:
        x = F.leaky_relu_(self.lin1(x))

        h = F.elu_(self.gate_conv(x, edge_index, edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        x = self.gru(h, x).relu_()

        for conv, gru in zip(self.atom_convs, self.atom_grus):
            h = conv(x, edge_index)
            h = F.elu(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = gru(h, x).relu()

        # Molecule Embedding:
        # Connect node `i` to entry `batch[i]`, so that `mol_conv` can pass
        # messages from the nodes (`x`) to the pooled embeddings (`out`).
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        for _ in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.lin2(out)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'in_channels={self.in_channels}, '
                f'hidden_channels={self.hidden_channels}, '
                f'out_channels={self.out_channels}, '
                f'edge_dim={self.edge_dim}, '
                f'num_layers={self.num_layers}, '
                f'num_timesteps={self.num_timesteps}'
                f')')