from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import GRUCell, Linear, Parameter

from torch_geometric.nn import GATConv, MessagePassing, global_add_pool
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import softmax

class GATEConv(MessagePassing):
    """Edge-aware GAT-style convolution used as the first atom-embedding
    layer of :class:`AttentiveFP`, incorporating edge features into the
    attention computation."""
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        edge_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__(aggr='add', node_dim=0)

        self.dropout = dropout

        self.att_l = Parameter(torch.empty(1, out_channels))
        self.att_r = Parameter(torch.empty(1, in_channels))

        self.lin1 = Linear(in_channels + edge_dim, out_channels, False)
        self.lin2 = Linear(out_channels, out_channels, False)

        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.att_l)
        glorot(self.att_r)
        glorot(self.lin1.weight)
        glorot(self.lin2.weight)
        zeros(self.bias)

    def forward(self, x: Tensor, edge_index: Adj, edge_attr: Tensor) -> Tensor:
        # edge_updater_type: (x: Tensor, edge_attr: Tensor)
        alpha = self.edge_updater(edge_index, x=x, edge_attr=edge_attr)

        # propagate_type: (x: Tensor, alpha: Tensor)
        out = self.propagate(edge_index, x=x, alpha=alpha)
        out = out + self.bias
        return out

    def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor,
                    index: Tensor, ptr: OptTensor,
                    size_i: Optional[int]) -> Tensor:
        # Transform source node and edge features jointly, then score the
        # transformed source (via `att_l`) and the raw target (via `att_r`):
        x_j = F.leaky_relu_(self.lin1(torch.cat([x_j, edge_attr], dim=-1)))
        alpha_j = (x_j @ self.att_l.t()).squeeze(-1)
        alpha_i = (x_i @ self.att_r.t()).squeeze(-1)
        alpha = alpha_j + alpha_i
        alpha = F.leaky_relu_(alpha)
        alpha = softmax(alpha, index, ptr, size_i)  # Normalize per target node.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return alpha

    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
        return self.lin2(x_j) * alpha.unsqueeze(-1)
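
# A minimal usage sketch for GATEConv (illustrative only; the sizes below
# are arbitrary assumptions, not part of the original source):
#
#   conv = GATEConv(in_channels=16, out_channels=32, edge_dim=8)
#   x = torch.randn(4, 16)                             # 4 nodes, 16 features.
#   edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # 3 directed edges.
#   edge_attr = torch.randn(3, 8)                      # 3 edges, 8 features.
#   out = conv(x, edge_index, edge_attr)               # -> shape [4, 32]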

class AttentiveFP(torch.nn.Module):
    r"""The Attentive FP model for molecular representation learning from the
    `"Pushing the Boundaries of Molecular Representation for Drug Discovery
    with the Graph Attention Mechanism"
    <https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959>`_ paper, based on
    graph attention mechanisms.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Hidden node feature dimensionality.
        out_channels (int): Size of each output sample.
        edge_dim (int): Edge feature dimensionality.
        num_layers (int): Number of GNN layers.
        num_timesteps (int): Number of iterative refinement steps for global
            readout.
        dropout (float, optional): Dropout probability. (default: :obj:`0.0`)
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        edge_dim: int,
        num_layers: int,
        num_timesteps: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.edge_dim = edge_dim
        self.num_layers = num_layers
        self.num_timesteps = num_timesteps
        self.dropout = dropout

        self.lin1 = Linear(in_channels, hidden_channels)

        self.gate_conv = GATEConv(hidden_channels, hidden_channels, edge_dim,
                                  dropout)
        self.gru = GRUCell(hidden_channels, hidden_channels)

        self.atom_convs = torch.nn.ModuleList()
        self.atom_grus = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            conv = GATConv(hidden_channels, hidden_channels, dropout=dropout,
                           add_self_loops=False, negative_slope=0.01)
            self.atom_convs.append(conv)
            self.atom_grus.append(GRUCell(hidden_channels, hidden_channels))

        self.mol_conv = GATConv(hidden_channels, hidden_channels,
                                dropout=dropout, add_self_loops=False,
                                negative_slope=0.01)
        self.mol_conv.explain = False  # Cannot explain global pooling.
        self.mol_gru = GRUCell(hidden_channels, hidden_channels)

        self.lin2 = Linear(hidden_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.lin1.reset_parameters()
        self.gate_conv.reset_parameters()
        self.gru.reset_parameters()
        for conv, gru in zip(self.atom_convs, self.atom_grus):
            conv.reset_parameters()
            gru.reset_parameters()
        self.mol_conv.reset_parameters()
        self.mol_gru.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor,
                batch: Tensor) -> Tensor:
        r"""Computes graph-level embeddings for each molecule in the batch."""
        # Atom Embedding:
        x = F.leaky_relu_(self.lin1(x))

        h = F.elu_(self.gate_conv(x, edge_index, edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        x = self.gru(h, x).relu_()

        for conv, gru in zip(self.atom_convs, self.atom_grus):
            h = conv(x, edge_index)
            h = F.elu(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = gru(h, x).relu()

        # Molecule Embedding (a bipartite graph connecting every atom to its
        # molecule node, so that readout can attend over atoms):
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        for t in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.lin2(out)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'in_channels={self.in_channels}, '
                f'hidden_channels={self.hidden_channels}, '
                f'out_channels={self.out_channels}, '
                f'edge_dim={self.edge_dim}, '
                f'num_layers={self.num_layers}, '
                f'num_timesteps={self.num_timesteps}'
                f')')
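

# A minimal smoke test for AttentiveFP (a sketch; all sizes below are
# arbitrary assumptions for illustration, not values from the paper):
if __name__ == '__main__':
    model = AttentiveFP(in_channels=39, hidden_channels=64, out_channels=1,
                        edge_dim=10, num_layers=2, num_timesteps=2,
                        dropout=0.2)

    x = torch.randn(6, 39)  # 6 atoms with 39 features each.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5],
                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]])
    edge_attr = torch.randn(10, 10)  # One feature vector per directed edge.
    batch = torch.zeros(6, dtype=torch.long)  # All atoms in one molecule.

    out = model(x, edge_index, edge_attr, batch)
    print(out.shape)  # Expected: torch.Size([1, 1])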