# Source code for torch_geometric.nn.models.attentive_fp

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import GRUCell, Linear, Parameter

from torch_geometric.nn import GATConv, MessagePassing, global_add_pool
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import softmax

from ..inits import glorot, zeros


class GATEConv(MessagePassing):
    """Edge-aware attention convolution with sum aggregation.

    For every edge, the source node features are concatenated with the edge
    features and projected by ``lin1``. An attention logit is formed from
    that projected message (weighted by ``att_l``) and the target node
    features (weighted by ``att_r``), normalized with ``softmax`` over the
    ``index`` grouping, and used to scale the ``lin2``-transformed message.
    Messages are aggregated with ``aggr='add'`` and a learnable bias is
    added to the result.

    Args:
        in_channels (int): Feature size of the input nodes.
        out_channels (int): Feature size of the output nodes.
        edge_dim (int): Feature size of the edge attributes.
        dropout (float, optional): Dropout probability applied to the
            normalized attention coefficients. (default: ``0.0``)
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        edge_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__(aggr='add', node_dim=0)

        self.dropout = dropout

        # Allocate with the `torch.empty` factory rather than the legacy
        # `torch.Tensor(sizes...)` constructor; the values are uninitialized
        # either way and are filled in by `reset_parameters()` below.
        # `att_l` scores the projected message (out_channels wide), `att_r`
        # scores the raw target node features (in_channels wide).
        self.att_l = Parameter(torch.empty(1, out_channels))
        self.att_r = Parameter(torch.empty(1, in_channels))

        # `lin1` consumes [source features || edge features]; both linear
        # layers are bias-free, a single bias is added after aggregation.
        self.lin1 = Linear(in_channels + edge_dim, out_channels, False)
        self.lin2 = Linear(out_channels, out_channels, False)

        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize all weights with ``glorot`` and the bias with
        ``zeros``."""
        glorot(self.att_l)
        glorot(self.att_r)
        glorot(self.lin1.weight)
        glorot(self.lin2.weight)
        zeros(self.bias)

    def forward(self, x: Tensor, edge_index: Adj, edge_attr: Tensor) -> Tensor:
        """Run message passing over ``edge_index`` and add the bias.

        Args:
            x (Tensor): Node features; the last dimension must be
                ``in_channels`` wide.
            edge_index (Adj): Graph connectivity handed to ``propagate``.
            edge_attr (Tensor): Edge features; the last dimension must be
                ``edge_dim`` wide.

        Returns:
            Tensor: The aggregated messages plus ``self.bias``.
        """
        # NOTE(review): the typed-argument hint on the next line looks like a
        # directive consumed by the TorchScript conversion used in
        # `AttentiveFP.jittable` -- keep it verbatim. TODO confirm.
        # propagate_type: (x: Tensor, edge_attr: Tensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)
        out = out + self.bias
        return out

    def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        """Compute the attention-weighted message for every edge.

        The parameter names are left untouched because the message-passing
        base class supplies these arguments (presumably matched by name --
        verify against ``MessagePassing``).

        Returns:
            Tensor: ``lin2(x_j)`` scaled per edge by its attention
            coefficient.
        """
        # Fuse neighbor and edge information into one projected message.
        x_j = F.leaky_relu_(self.lin1(torch.cat([x_j, edge_attr], dim=-1)))

        # Additive attention logit: one scalar per edge from each side.
        alpha_j = (x_j * self.att_l).sum(dim=-1)
        alpha_i = (x_i * self.att_r).sum(dim=-1)
        alpha = alpha_j + alpha_i
        alpha = F.leaky_relu_(alpha)

        # Normalize the logits within each `index` group, then apply dropout
        # to the resulting coefficients (active only in training mode).
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # Broadcast the per-edge scalar over the feature dimension.
        return self.lin2(x_j) * alpha.unsqueeze(-1)


class AttentiveFP(torch.nn.Module):
    r"""The Attentive FP model for molecular representation learning from the
    `"Pushing the Boundaries of Molecular Representation for Drug Discovery
    with the Graph Attention Mechanism"
    <https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959>`_ paper, based on
    graph attention mechanisms.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Hidden node feature dimensionality.
        out_channels (int): Size of each output sample.
        edge_dim (int): Edge feature dimensionality.
        num_layers (int): Number of GNN layers.
        num_timesteps (int): Number of iterative refinement steps for global
            readout.
        dropout (float, optional): Dropout probability. (default: :obj:`0.0`)

    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        edge_dim: int,
        num_layers: int,
        num_timesteps: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        # Keep the constructor arguments around for `forward` and `__repr__`.
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.edge_dim = edge_dim
        self.num_layers = num_layers
        self.num_timesteps = num_timesteps
        self.dropout = dropout

        # Input projection into the hidden space.
        self.lin1 = Linear(in_channels, hidden_channels)

        # First atom layer: the only one that consumes edge features.
        self.gate_conv = GATEConv(hidden_channels, hidden_channels, edge_dim,
                                  dropout)
        self.gru = GRUCell(hidden_channels, hidden_channels)

        # Remaining `num_layers - 1` atom layers, each a GATConv paired with
        # its own GRU cell.
        self.atom_convs = torch.nn.ModuleList()
        self.atom_grus = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            conv = GATConv(hidden_channels, hidden_channels, dropout=dropout,
                           add_self_loops=False, negative_slope=0.01)
            self.atom_convs.append(conv)
            self.atom_grus.append(GRUCell(hidden_channels, hidden_channels))

        # Molecule-level readout: one conv/GRU pair reused for every
        # refinement step in `forward`.
        self.mol_conv = GATConv(hidden_channels, hidden_channels,
                                dropout=dropout, add_self_loops=False,
                                negative_slope=0.01)
        self.mol_gru = GRUCell(hidden_channels, hidden_channels)

        # Output head.
        self.lin2 = Linear(hidden_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the parameters of every submodule."""
        self.lin1.reset_parameters()
        self.gate_conv.reset_parameters()
        self.gru.reset_parameters()
        for conv, gru in zip(self.atom_convs, self.atom_grus):
            conv.reset_parameters()
            gru.reset_parameters()
        self.mol_conv.reset_parameters()
        self.mol_gru.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor,
                batch: Tensor) -> Tensor:
        """Embed the atoms, pool them per graph and predict the output.

        Args:
            x (Tensor): Node features; the last dimension must be
                ``in_channels`` wide.
            edge_index (Tensor): Graph connectivity passed to the atom-level
                convolutions.
            edge_attr (Tensor): Edge features, consumed only by the first
                (``gate_conv``) layer.
            batch (Tensor): Per-node index handed to ``global_add_pool``; it
                is also used as the target side of the node-to-graph edges
                built for the readout.

        Returns:
            Tensor: The pooled representation after ``lin2``; the last
            dimension is ``out_channels`` wide.
        """
        # Atom Embedding:
        x = F.leaky_relu_(self.lin1(x))

        h = F.elu_(self.gate_conv(x, edge_index, edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        x = self.gru(h, x).relu_()

        for conv, gru in zip(self.atom_convs, self.atom_grus):
            h = F.elu_(conv(x, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = gru(h, x).relu_()

        # Molecule Embedding:
        # Connect node `i` to entry `batch[i]` so that `mol_conv` can attend
        # from the nodes (source side) to the pooled state (target side).
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        for _ in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.lin2(out)

    def jittable(self) -> 'AttentiveFP':
        """Swap the message-passing submodules for their ``jittable()``
        versions in place and return ``self``."""
        self.gate_conv = self.gate_conv.jittable()
        self.atom_convs = torch.nn.ModuleList(
            [conv.jittable() for conv in self.atom_convs])
        self.mol_conv = self.mol_conv.jittable()
        return self

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'in_channels={self.in_channels}, '
                f'hidden_channels={self.hidden_channels}, '
                f'out_channels={self.out_channels}, '
                f'edge_dim={self.edge_dim}, '
                f'num_layers={self.num_layers}, '
                f'num_timesteps={self.num_timesteps}'
                f')')