pytorch系列3:在预训练好的模型上进行微调

在训练网络时经常要进行微调,原来用caffe比较多,相对来说caffe的微调要简单,那么pytorch是怎么实现网络的微调呢,我的方法是:

# 导入原来的网络
import modules_original
# 导入新的网络
import modules


pre_net = modules_original()     # 
pre_net.load_state_dict(torch.load('pre_model.pth'))

new_net = modules.net()
pre_dict = pre_net.state_dict()
new_dict = new_net.state_dict()

pre_dict = {k: v for k, v in pre_dict.items() if k in new_dict}

new_dict.update(pre_dict)
new_net.load_state_dict(new_dict)

如果新的网络和原来的网络一样的话参考上面的代码就可以了,如果中间修改了某些层,可以对不同的层设置不同的学习率,比如新的层的学习率大一些,实现方法参考上一篇博客。

你可能感兴趣的:(python,深度学习,Pytorch)