pytorch学习007- -预训练中的权重加载(完全导入,部分导入)

文章目录

  • 更新
  • 问题
  • 方案
    • PyTorch文档
    • 模型对应,完全导入
    • 模型不完全对应
      • 只有部分对应
      • A属于B
      • B属于A

更新

2022.04.12更新
导入权重的用法相当普遍,但是可以导入吗?导入有什么影响?
首先一定是可以导入的,但是导入之后是否有效果?那应该分以下情况讨论。

  • 网络模型完全对应:这种情况可以导入,而且微调效果更好
  • 网络模型不完全对应(小心这种情况)
    • 只是输出层有部分变化,可以导入
    • 中间层有变化,不建议导入

问题

  1. 预训练后的权重如何导入另一个网络模型?
  2. 预训练对应的网络模型A与未训练的网络模型结构B不对应?
    2.1 两个网络模型A和B只有部分对应
    2.2 集合关系上A属于B
    2.3 集合关系上B属于A

方案

PyTorch文档

  • torch.nn.modules.module.Module def load_state_dict(self,
    state_dict: Dict[str, Tensor] | OrderedDict[str, Tensor],
    strict: bool = …) -> None
  • 说明:将 state_dict 中的参数和缓冲区复制到此模块及其后代中。
    • 如果 strict 为 True,则 state_dict 的键必须与此模块的torch.nn.Module.state_dict 函数返回的键完全匹配
  • 参数
    state_dict – 包含参数和持久缓冲区的字典。
    strict – 是否严格强制:
    • attr:state_dict 中的键与该模块的 :meth:~torch.nn.Module.state_dict 函数返回的键匹配。 默认值:“真”
  • 返回值:
    • missing_keys 是包含缺失键的 str 列表
    • unexpected_keys 是包含意外键的 str 列表

模型对应,完全导入

# demo1 完全加载权重
model = NET1()
state_dict = model.state_dict()
weights = torch.load(weights_path)['model_state_dict'] #读取预训练模型权重
model.load_state_dict(weights)

模型不完全对应

此一种情况经常出现在要修改预训练网络模型中某些层时,可能增加若干层,可能减少若干层,或上述两种情况皆有。

只有部分对应

pytorch学习007- -预训练中的权重加载(完全导入,部分导入)_第1张图片
两个模型中有部分是对应的,此种情况建议使用PyTorch中的load_state_dict所提供的参数:strict
将strict设置为False,可以在两个模型不同的情况下,仅加载相同键值部分。(保证各层的名字相同)

# demo2
model = NET2()
state_dict = model.state_dict()
weights = torch.load(weights_path)['model_state_dict']	#读取预训练模型权重
model.load_state_dict(weights, strict=False)	#strict

A属于B

pytorch学习007- -预训练中的权重加载(完全导入,部分导入)_第2张图片
此种情况常见于,在网上download别人的预训练模型后,需要根据自己的任务,添加若干个层,而其他层保持不变。

# demo3
*****待测试

B属于A

pytorch学习007- -预训练中的权重加载(完全导入,部分导入)_第3张图片
此种情况常见于从网上download别人的预训练模型后,因为某些限制,需要对模型进行精简,只删除若干个层,其他层保持不变。

# demo4
*****待测试

你可能感兴趣的:(深度学习,pytorch,深度学习)