pytorch 如何以最方便的方式固定前置网络(特征提取网络)的参数

 

已经训练好了cnn1,现在想要用cnn1输出的结果来训练cnn2,但是要完全 固定cnn1的参数:

self.step1=cnn1()
for p in cnn1.parameters():
    p.requires_grad = False
self.init_=torch.load('cnn1.weight.path.pth.tar')
self.step1.load_state_dict(self.init_['state_dict'])

self.cnn2=~~~~
self.cnn3=~~~~~


def forward(x)
    x=self.step1(x)
    x=self.cnn2(x)
    return self.cnn3(x)

 

 

你可能感兴趣的:(pytorch,python)