使用PyTorch冻结模型参数的方法

前言

在深度学习领域,经常需要使用其他人已训练好的模型进行改进或微调,这个时候我们会加载已有的预训练模型文件的参数,如果网络结构不变,希望使用新数据微调部分网络参数。这时我们则需要冻结部分参数,禁止其更新。

在这里插入图片描述

方法

(1)通过遍历网络结构,设置梯度更新requires_grad = False。

 # 冻结network1的全部参数和network2的部分参数
 for name, parameter in network1.named_parameters():
     parameter.requires_grad = False

 for name, parameter in network2.named_parameters():
     if 'key' in name:
         parameter.requires_grad = False

(2)优化器中过滤filter冻结的参数

optimizer_network2 = torch.optim.Adam(filter(lambda p: p.requires_grad, network2.parameters()), lr=0.005, betas=(0.5, 0.999))

其他

结合加载模型部分参数的情况,优化器需要按如下设置:

   optimizer_network2 = torch.optim.Adam([{'params': filter(lambda p: p.requires_grad, network2.parameters()), 'initial_lr': 0.0002}], lr=0.005, betas=(0.5, 0.999))

使用PyTorch冻结模型参数的方法_第1张图片

参考资料

[1] csdn - 使用PyTorch加载模型部分参数方法
[2] 知乎 - Pytorch自由载入部分模型参数并冻结

你可能感兴趣的:(#,深度学习框架,#,Python,人工智能,深度学习,PyTorch)