Pytorch网络模型参数修改

       最近在研究HRank的论文和代码,发现里面的模型结构化剪枝后,模型大小没有变小,打印最终训练完后的网络模型中的参数,发现它只是把一些channel的参数全部稀疏化为0,并没有把channel的个数裁剪下来。

       于是想自己写代码重新构建裁剪channel后的小网络,然后把通道稀疏化(也就是裁剪)后的网络参数直接给到小网络中,这样做就可以用小网络模型做一模一样的推理了。关于为什么可以直接用小网络获取大网络稀疏化通道裁剪后的网络参数作推理,其实原因很简单的,后续再写HRank论文的总结的博文时,我会画图详细说明。

       现在就是有一个大的网络模型,然后训练好了,需要把每层卷积中不为0的通道赋值给另外一个小网络中,这种就遇到Pytorch网络模型参数修改的问题。在网络通过搜索和学习后,终于知道

       需要先把 model.state_dict()给到model_dict ,然后再修改model_dict ,最后再把model_dict通过model.load_state_dict(model_dict) 加载进模型中去,这样才能成功修改网络模型的参数,比如如果有两个网络net和model,net已经通过net.load_state_dict加载过参数,现在想通过手动把net的参数给到一模一样的model中,  参考代码如下:                  

net_dict = net.state_dict()  //net是已经通过net.load_state_dict加载过参数的模型
model_dict  =   model.state_dict()  //model是跟net一样的网络,但是没有加载模型参数
for par1, par2 in zip(net_dict,model_dict):
    model_dict[par2] = net_dict[par1]    //这里赋值,也可以在这里修改model_dict的网络参数的值
model.load_state_dict(model_dict)        //重新load下,这里model的参数就跟net中一样了

        之所以要这样修改,是因为load_state_dict函数中调用model_state_dict()函数返回的只是module类内部state dict对象的一个copy,所以修改里面的值并不能影响模型中真正的参数。

     

你可能感兴趣的:(深度学习)