Initializing network parameters in PyTorch

Single layer

To initialize the weights of a single layer, use a function from torch.nn.init. For instance:

conv1 = torch.nn.Conv2d(…)
torch.nn.init.xavier_uniform_(conv1.weight)
Alternatively, you can modify the parameters by writing to conv1.weight.data (which is a torch.Tensor). Example:

conv1.weight.data.fill_(0.01)
The same applies to biases:

conv1.bias.data.fill_(0.01)
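
Put together, the single-layer case might look like the sketch below. The layer shape (3 input channels, 8 output channels, 3x3 kernel) is an assumption made only so the snippet runs; the original leaves the Conv2d arguments elided. It also uses torch.no_grad() and torch.nn.init.constant_ as alternatives to writing through .data, which is a substitution on my part rather than something the text above prescribes.

```python
import torch
import torch.nn as nn

# Hypothetical layer shape, chosen only for illustration.
conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

# In-place initializers in torch.nn.init end with an underscore;
# the variants without it are deprecated.
nn.init.xavier_uniform_(conv1.weight)

# Fill the bias with a constant without going through .data.
with torch.no_grad():
    conv1.bias.fill_(0.01)

# The same constant fill, expressed with a torch.nn.init helper.
nn.init.constant_(conv1.bias, 0.01)

print(conv1.weight.shape, conv1.bias[:3])
```

Either form of the constant fill leaves the tensor registered as a parameter of the layer, so an optimizer created afterwards still sees it.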

nn.Sequential or a custom nn.Module

Pass an initialization function to torch.nn.Module.apply. It will initialize the weights in the entire nn.Module recursively.

apply(fn): Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).
Example:

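import torch
import torch.nn as nn

def init_weights(m):
    # Only touch fully connected layers; other module types are left as-is.
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)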

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)
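
The same pattern carries over to a custom nn.Module with nested submodules. The sketch below is only an illustration: TinyNet, its layer sizes, and the choice of Kaiming initialization for the convolution are my assumptions, not part of the original answer. It also records the order in which apply visits modules, which shows that children are handled before their parent.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(4 * 8 * 8, 10)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x.flatten(1))

def init_weights(m):
    # Pick an initializer per layer type; skip modules without weights.
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")  # assumed choice
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
    else:
        return
    # Layers built with bias=False have m.bias set to None.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

model = TinyNet()

# Record the visiting order: each module's children first, then the module.
visited = []
model.apply(lambda m: visited.append(type(m).__name__))
print(visited)  # ['Conv2d', 'ReLU', 'Sequential', 'Linear', 'TinyNet']

model.apply(init_weights)
```

Because apply returns the module it was called on, the last line can also be written as model = TinyNet().apply(init_weights) when the visiting order is not needed.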
