from typing import Optional
from torch_geometric.typing import Adj, OptTensor
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Linear, Parameter, GRUCell
from torch_geometric.utils import softmax
from torch_geometric.nn import GATConv, MessagePassing, global_add_pool
from ..inits import glorot, zeros
class GATEConv(MessagePassing):
    r"""Edge-feature-aware attention layer used as the first atom-level
    convolution of :class:`AttentiveFP`.

    Each message is built from the source node features concatenated with
    the edge features, weighted by an attention coefficient, and the
    weighted messages are aggregated with :obj:`aggr='add'`.

    Args:
        in_channels (int): Size of each input node feature vector.
        out_channels (int): Size of each output node feature vector.
        edge_dim (int): Edge feature dimensionality.
        dropout (float, optional): Dropout probability applied to the
            attention coefficients. (default: :obj:`0.0`)
    """
    def __init__(self, in_channels: int, out_channels: int, edge_dim: int,
                 dropout: float = 0.0):
        super(GATEConv, self).__init__(aggr='add', node_dim=0)

        self.dropout = dropout

        # `att_l` scores the transformed source/edge message (width
        # `out_channels`); `att_r` scores the raw target features (width
        # `in_channels`).
        self.att_l = Parameter(torch.Tensor(1, out_channels))
        self.att_r = Parameter(torch.Tensor(1, in_channels))

        # Both projections are bias-free; a single bias is added once in
        # `forward` after aggregation.
        self.lin1 = Linear(in_channels + edge_dim, out_channels, bias=False)
        self.lin2 = Linear(out_channels, out_channels, bias=False)

        self.bias = Parameter(torch.Tensor(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the attention vectors and projection weights with
        ``glorot`` and the bias with ``zeros``."""
        glorot_weights = (
            self.att_l,
            self.att_r,
            self.lin1.weight,
            self.lin2.weight,
        )
        for weight in glorot_weights:
            glorot(weight)
        zeros(self.bias)

    def forward(self, x: Tensor, edge_index: Adj, edge_attr: Tensor) -> Tensor:
        """Propagate messages along ``edge_index`` and add the bias.

        Args:
            x (Tensor): Node features.
            edge_index (Adj): Graph connectivity handed to ``propagate``.
            edge_attr (Tensor): Edge features, concatenated with the source
                node features inside :meth:`message`.

        Returns:
            Tensor: The aggregated messages plus ``self.bias``.
        """
        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        # In-place add on the tensor returned by `propagate`.
        aggregated += self.bias
        return aggregated

    def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        """Compute the attention-weighted message for every edge.

        The argument names are kept as-is because they are supplied by
        ``propagate``.
        """
        # Joint source-node / edge representation.
        joint = torch.cat([x_j, edge_attr], dim=-1)
        h_j = F.leaky_relu_(self.lin1(joint))

        # Unnormalised score: one scalar per edge from each side.
        score_src = (h_j * self.att_l).sum(dim=-1)
        score_dst = (x_i * self.att_r).sum(dim=-1)
        score = F.leaky_relu_(score_src + score_dst)

        # Normalise with `softmax` using the grouping information in
        # `index` / `ptr` / `size_i`, then apply attention dropout.
        weight = softmax(score, index, ptr, size_i)
        weight = F.dropout(weight, p=self.dropout, training=self.training)

        return self.lin2(h_j) * weight.unsqueeze(-1)
class AttentiveFP(torch.nn.Module):
    r"""The Attentive FP model for molecular representation learning from the
    `"Pushing the Boundaries of Molecular Representation for Drug Discovery
    with the Graph Attention Mechanism"
    <https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959>`_ paper, based on
    graph attention mechanisms.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Hidden node feature dimensionality.
        out_channels (int): Size of each output sample.
        edge_dim (int): Edge feature dimensionality.
        num_layers (int): Number of GNN layers.
        num_timesteps (int): Number of iterative refinement steps for global
            readout.
        dropout (float, optional): Dropout probability. (default: :obj:`0.0`)
    """
    def __init__(self, in_channels: int, hidden_channels: int,
                 out_channels: int, edge_dim: int, num_layers: int,
                 num_timesteps: int, dropout: float = 0.0):
        super().__init__()

        self.num_layers = num_layers
        self.num_timesteps = num_timesteps
        self.dropout = dropout

        self.lin1 = Linear(in_channels, hidden_channels)

        # The first atom-level layer is the only one that consumes edge
        # features; it is always created, regardless of `num_layers`.
        conv = GATEConv(hidden_channels, hidden_channels, edge_dim, dropout)
        gru = GRUCell(hidden_channels, hidden_channels)
        self.atom_convs = torch.nn.ModuleList([conv])
        self.atom_grus = torch.nn.ModuleList([gru])
        # The remaining `num_layers - 1` atom-level layers are plain
        # `GATConv` layers without edge features.
        for _ in range(num_layers - 1):
            conv = GATConv(hidden_channels, hidden_channels, dropout=dropout,
                           add_self_loops=False, negative_slope=0.01)
            self.atom_convs.append(conv)
            self.atom_grus.append(GRUCell(hidden_channels, hidden_channels))

        # Readout attention between atoms and their graph-level embedding.
        self.mol_conv = GATConv(hidden_channels, hidden_channels,
                                dropout=dropout, add_self_loops=False,
                                negative_slope=0.01)
        self.mol_gru = GRUCell(hidden_channels, hidden_channels)

        self.lin2 = Linear(hidden_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parameters of every submodule owned by the model."""
        self.lin1.reset_parameters()
        for conv, gru in zip(self.atom_convs, self.atom_grus):
            conv.reset_parameters()
            gru.reset_parameters()
        self.mol_conv.reset_parameters()
        self.mol_gru.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj, edge_attr: Tensor,
                batch: Tensor) -> Tensor:
        """Compute one prediction per graph in the batch.

        Args:
            x (Tensor): Node features, fed to ``lin1``.
            edge_index (Adj): Graph connectivity for the atom-level layers.
            edge_attr (Tensor): Edge features, used only by the first
                atom-level layer.
            batch (Tensor): One entry per node giving the index of the graph
                that node belongs to (it is stacked against
                ``arange(batch.size(0))`` and passed to ``global_add_pool``).

        Returns:
            Tensor: Output of ``lin2`` applied to the graph-level embeddings.
        """
        # Atom Embedding:
        x = F.leaky_relu_(self.lin1(x))

        h = F.elu_(self.atom_convs[0](x, edge_index, edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        x = self.atom_grus[0](h, x).relu_()

        for conv, gru in zip(self.atom_convs[1:], self.atom_grus[1:]):
            h = F.elu_(conv(x, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = gru(h, x).relu_()

        # Molecule Embedding:
        # Build edges that connect node i (source) to its graph index
        # batch[i] (target), so `mol_conv` attends from atoms to graphs.
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        for _ in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.lin2(out)