Deeplearning/Pytorch,中只导入部分层权重的方法

1. 问题描述:

现有已训练断点保存的权重,但是网络某处修改,导致现有权重与模型网络参数 无法完全匹配:

#原始版本
    # checkpoint = torch.load(best_train_weight_checkpoint_path_from_BCE)
    # net.load_state_dict(checkpoint['net'])

Deeplearning/Pytorch,中只导入部分层权重的方法_第1张图片

 正确做法:

Deeplearning/Pytorch,中只导入部分层权重的方法_第2张图片

 

#备注

循环取出pretrain_net['net'].items() 里面的键值对,如果k 在net中 那就取出来保存在 stat_dict中

然后再net_dict中更新原本pretrain中的内容,这样就将添加了原本pretrain中缺少的ape参数 。

总结:

        1. 

  1. 基本步骤:
    1. 取出pretrain中 在netdict中有的参数
    2. 然后用netdict update
  2. 如果pretrain中多了:还是一样的道理, 一样的步骤

你可能感兴趣的:(pytorch,深度学习,人工智能)