net.apply(init_weights)

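import numpy as np
import torch
import torch.nn as nn

def init_weights(m):
    # apply() calls this once per submodule (children first), then on the container itself
    print('m', m)
    if isinstance(m, nn.Linear):
        # Overwrite the weights in place; no_grad keeps this op out of the autograd graph
        with torch.no_grad():
            m.weight.fill_(1.0)
        print('mw', m.weight)

# Two bias-free Linear layers: 3 -> 2 -> 3
net = nn.Sequential(nn.Linear(3, 2, bias=False), nn.Linear(2, 3, bias=False))

# One input sample of shape (1, 3), float32 to match the layer weights
aa = np.array([2., 3., 4.], dtype=np.float32).reshape(1, 3)
bb = torch.from_numpy(aa)

net.apply(init_weights)
c = net(bb)
d = c.detach().numpy()
print(d)  # [[18. 18. 18.]]  (2+3+4 = 9 per hidden unit, then 9+9 = 18 per output)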


net.apply(fn) recursively walks every submodule of the network (each module's children first, then the module itself) and calls fn on each one. Because init_weights checks the module type, this single call initializes every nn.Linear layer in the model, however deeply it is nested.
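A minimal sketch of the recursion, assuming a hypothetical model with one Linear layer nested inside an inner Sequential. It records the order in which apply visits modules (children before parents), then initializes every Linear layer. Note that it swaps the constant fill above for nn.init.xavier_uniform_ on the weights and nn.init.zeros_ on the biases; that is a common practical choice, not the method used in the code above.

```python
import torch
import torch.nn as nn

def init_weights(m: nn.Module) -> None:
    # Called by apply() on every module in the tree, containers included,
    # so the type check decides which modules are actually touched
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Hypothetical model: the last Linear sits two levels deep
net = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2)),
)

# Record the visiting order: each module's children are handled before the module itself
visited = []
net.apply(lambda m: visited.append(type(m).__name__))
print(visited)
# ['Linear', 'ReLU', 'Linear', 'ReLU', 'Linear', 'Sequential', 'Sequential']

# One call reaches all three Linear layers, including the nested ones
net.apply(init_weights)
```

Because the inner Sequential appears in the list before the outer one, a function passed to apply can rely on a container's children already having been processed when it sees the container.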
