# Source code for torch_geometric.nn.models.neural_fingerprint

from typing import Optional

import torch
from torch import Tensor

from torch_geometric.nn import Linear, MFConv, global_add_pool
from torch_geometric.typing import Adj


class NeuralFingerprint(torch.nn.Module):
    r"""The Neural Fingerprint model from the
    `"Convolutional Networks on Graphs for Learning Molecular Fingerprints"
    <https://arxiv.org/abs/1509.09292>`__ paper to generate fingerprints
    of molecules.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        out_channels (int): Size of each output fingerprint.
        num_layers (int): Number of layers. Must be at least :obj:`1`.
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MFConv`.

    Raises:
        ValueError: If :obj:`num_layers` is smaller than :obj:`1`.
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int,
        **kwargs,
    ):
        super().__init__()

        # With no layers, `forward` would sum an empty list and return the
        # Python int `0` instead of a tensor, so reject this up front.
        if num_layers < 1:
            raise ValueError(
                f"'num_layers' must be at least 1 (got {num_layers})")

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_layers = num_layers

        # Only the first convolution consumes the raw input features; every
        # later one operates on the previous layer's hidden representation.
        self.convs = torch.nn.ModuleList([
            MFConv(in_channels if i == 0 else hidden_channels,
                   hidden_channels, **kwargs) for i in range(num_layers)
        ])

        # One bias-free projection per layer, mapping that layer's hidden
        # node states to `out_channels` fingerprint entries.
        self.lins = torch.nn.ModuleList([
            Linear(hidden_channels, out_channels, bias=False)
            for _ in range(num_layers)
        ])

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        for conv in self.convs:
            conv.reset_parameters()
        for lin in self.lins:
            lin.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Adj,
        batch: Optional[Tensor] = None,
        batch_size: Optional[int] = None,
    ) -> Tensor:
        r"""Computes the fingerprint of each graph in the batch.

        Args:
            x (torch.Tensor): The input node features.
            edge_index (Adj): The graph connectivity, passed unchanged to
                every :class:`~torch_geometric.nn.conv.MFConv` layer.
            batch (torch.Tensor, optional): Batch assignment vector,
                forwarded to
                :meth:`~torch_geometric.nn.pool.global_add_pool`.
                (default: :obj:`None`)
            batch_size (int, optional): Number of graphs, forwarded to
                :meth:`~torch_geometric.nn.pool.global_add_pool`.
                (default: :obj:`None`)

        Returns:
            torch.Tensor: The element-wise sum, over all layers, of the
            graph-level pooled softmax outputs (last dimension of size
            :obj:`out_channels`).
        """
        outs = []
        for conv, lin in zip(self.convs, self.lins):
            # Each layer feeds its sigmoid-activated output to the next one.
            x = conv(x, edge_index).sigmoid()
            # Normalize every node's projection over the last dimension, then
            # add up node contributions per graph.
            y = lin(x).softmax(dim=-1)
            outs.append(global_add_pool(y, batch, batch_size))
        # `outs` is non-empty because `num_layers >= 1` is enforced in
        # `__init__`, so this always yields a tensor.
        return sum(outs)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, num_layers={self.num_layers})')