Source code for torch_geometric.graphgym.init
import torch
[docs]def init_weights(m):
r"""Performs weight initialization.
Args:
m (nn.Module): PyTorch module
"""
if (isinstance(m, torch.nn.BatchNorm2d)
or isinstance(m, torch.nn.BatchNorm1d)):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
elif isinstance(m, torch.nn.Linear):
m.weight.data = torch.nn.init.xavier_uniform_(
m.weight.data, gain=torch.nn.init.calculate_gain('relu'))
if m.bias is not None:
m.bias.data.zero_()