Reference: torch.nn.Module.load_state_dict(state_dict, strict=True)
For more on saving and loading models, see: a discussion of mismatched parameters when loading a model in PyTorch
Original text with notes:
load_state_dict(state_dict, strict=True)
Method: load_state_dict(state_dict, strict=True)
Copies parameters and buffers from state_dict into this module
and its descendants. If strict is True, then the keys of
state_dict must exactly match the keys returned by this
module’s state_dict() function.
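In other words, the parameters and buffers held in the state_dict argument are copied
into this module and all of its submodules. When strict is True, the keys of state_dict
must match the keys returned by this module's state_dict() method exactly, with none missing and none extra.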
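The effect of strict can be seen on a very small model. The sketch below is only an illustration: make_model and the layer sizes are hypothetical and not part of the experiment in this post, and it assumes a reasonably recent PyTorch.

```python
import torch
import torch.nn as nn


def make_model():
    # Hypothetical two-layer model, used only for this illustration.
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


source = make_model()
target = make_model()

# Identical architectures produce identical keys, so strict loading succeeds
# and the weights of `target` become copies of the weights of `source`.
target.load_state_dict(source.state_dict(), strict=True)
same_weights = torch.equal(target[0].weight, source[0].weight)

# Delete one entry to create a key mismatch.
partial = source.state_dict()
del partial["2.bias"]

# strict=True (the default) raises RuntimeError when a key is missing or unexpected.
try:
    make_model().load_state_dict(partial, strict=True)
    strict_raised = False
except RuntimeError:
    strict_raised = True

# strict=False loads the keys that do match and reports the others.
result = make_model().load_state_dict(partial, strict=False)
print(same_weights, strict_raised, result.missing_keys)
```

With strict=False the call returns normally, and the key that could not be filled shows up in result.missing_keys.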
Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
Note: state_dict is a dictionary that maps each key name to a parameter or persistent buffer.
strict (bool, optional) – whether to strictly enforce that
the keys in state_dict match the keys returned by this
module’s state_dict() function. Default: True
Note: strict (bool, optional) says whether matching is enforced, that is, whether
the keys in state_dict must match the keys returned by this module's state_dict()
method exactly. The default is True.
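To make "parameters and persistent buffers" concrete, here is a minimal sketch of which attributes end up as keys in state_dict(). The Tiny class and its attribute names are hypothetical, chosen only for this illustration.

```python
import torch
import torch.nn as nn


class Tiny(nn.Module):
    # Hypothetical module, used only to show which attributes become keys.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 2)                       # parameters: fc.weight, fc.bias
        self.scale = nn.Parameter(torch.ones(1))        # parameter: scale
        self.register_buffer("steps", torch.zeros(1))   # persistent buffer: steps
        self.note = torch.zeros(1)                      # plain tensor attribute, not registered


keys = list(Tiny().state_dict().keys())
print(keys)
```

The plain tensor attribute note is not registered as a parameter or a buffer, so it contributes no key. The same applies to attribute4cxq_xxx in the experiment code in this post.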
Returns
missing_keys is a list of str containing the missing keys
Note: these are the keys this module expects but state_dict does not provide.
unexpected_keys is a list of str containing the unexpected keys
Note: these are the extra keys, present in state_dict but not expected
by this module (translator's note).
Return type
NamedTuple with missing_keys and unexpected_keys fields
Note: a named tuple with two fields, missing_keys and unexpected_keys.
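Because the return value is a named tuple, its two lists can be read by field name or by ordinary tuple unpacking. A minimal sketch, where the "extra" key is a made-up entry added only to trigger an unexpected key:

```python
import torch.nn as nn

model = nn.Linear(3, 1)

state = model.state_dict()
state["extra"] = state["bias"].clone()  # hypothetical key the module does not define
del state["weight"]                     # drop a key the module does expect

result = model.load_state_dict(state, strict=False)

# Access by field name ...
print(result.missing_keys)
print(result.unexpected_keys)

# ... or unpack it like an ordinary two-element tuple.
missing, unexpected = result
```

Logging both lists after a non-strict load is a cheap way to notice a checkpoint that matched fewer layers than intended.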
Experiment code:
import torch
import torch.nn as nn

torch.manual_seed(seed=20200910)


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = torch.nn.Sequential(  # input: torch.Size([64, 1, 28, 28])
            torch.nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),  # output: torch.Size([64, 64, 28, 28])
            torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),  # output: torch.Size([64, 128, 28, 28])
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(stride=2, kernel_size=2)  # output: torch.Size([64, 128, 14, 14])
        )
        self.dense = torch.nn.Sequential(  # input: torch.Size([64, 14*14*128])
            torch.nn.Linear(14*14*128, 1024),  # output: torch.Size([64, 1024])
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(1024, 10)  # output: torch.Size([64, 10])
        )
        # Extra layers and parameters that forward() never uses;
        # they are here only to add more keys to the state_dict.
        self.layer4cxq1 = torch.nn.Conv2d(2, 33, 4, 4)
        self.layer4cxq2 = torch.nn.ReLU()
        self.layer4cxq3 = torch.nn.MaxPool2d(stride=2, kernel_size=2)
        self.layer4cxq4 = torch.nn.Linear(14*14*128, 1024)
        self.layer4cxq5 = torch.nn.Dropout(p=0.8)
        self.attribute4cxq = nn.Parameter(torch.tensor(20200910.0))
        self.attribute4lzq = nn.Parameter(torch.tensor([2.0, 3.0, 4.0, 5.0]))
        self.attribute4hh = nn.Parameter(torch.randn(3, 4, 5, 6))
        self.attribute4wyf = nn.Parameter(torch.randn(7, 8, 9, 10))

    def forward(self, x):  # torch.Size([64, 1, 28, 28])
        x = self.conv1(x)  # output: torch.Size([64, 128, 14, 14])
        x = x.view(-1, 14*14*128)  # torch.Size([64, 14*14*128])
        x = self.dense(x)  # output: torch.Size([64, 10])
        return x

print('CUDA (GPU) available:', torch.cuda.is_available())
print('torch version:', torch.__version__)

model = Model()  # .cuda()
print("Testing the model (CPU)".center(100, "-"))
print(type(model))

print("Testing the state_dict(destination=None, prefix='', keep_vars=False) method".center(100, "-"))
myDict = model.state_dict()
print('Type returned by state_dict:', type(myDict))

# Same architecture on both sides: every key matches, so strict=True succeeds.
model_1 = Model()
NamedTuple_1 = model_1.load_state_dict(myDict, strict=True)
print('Type returned by load_state_dict:', type(NamedTuple_1))
missing_keys, unexpected_keys = NamedTuple_1
print('Type of missing_keys:', type(missing_keys), 'missing_keys:', missing_keys)
print('Type of unexpected_keys:', type(unexpected_keys), 'unexpected_keys:', unexpected_keys)

print('Testing mismatched model parameters'.center(100, "-"))
model_2 = Model()
# Give model_2 modules and a parameter that myDict does not contain (missing keys).
model_2.layer4cxq2020 = torch.nn.Sequential(torch.nn.Conv2d(1, 7, kernel_size=3, stride=1, padding=1))
model_2.layer4cxq0910 = torch.nn.Linear(5, 6)
model_2.attribute4cjh = nn.Parameter(torch.randn(17))
# A plain tensor attribute is not registered, so it adds no key.
model_2.attribute4cxq_xxx = torch.tensor([2.0, 3.0, 4.0, 5.0])
# Setting a submodule to None: in the output below its keys are reported
# as neither missing nor unexpected.
model_2.layer4cxq1 = None
# Add entries that model_2 does not define (unexpected keys) ...
myDict['attribute4cxq20200910'] = nn.Parameter(torch.tensor([2.0, 3.0, 4.0, 5.0]))
myDict['attribute4cxq20200910ccc'] = torch.tensor([2.0, 3.0, 4.0, 5.0])
# ... and delete two entries that it needs (missing keys).
del myDict['attribute4wyf']
del myDict['conv1.0.weight']

NamedTuple_2 = model_2.load_state_dict(myDict, strict=False)  # strict=True
# Note that strict=False is required here; with strict=True the call raises an error.
# The error message looks like this:
r'''
--------------------------------Testing mismatched model parameters---------------------------------
Traceback (most recent call last):
  File "c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\test31.py", line 74, in <module>
    NamedTuple_2 = model_2.load_state_dict(myDict, strict=True) # strict=True
  File "D:\Anaconda3\envs\ssd4pytorch1_2_0\lib\site-packages\torch\nn\modules\module.py", line 845, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Model:
        Missing key(s) in state_dict: "attribute4wyf", "attribute4cjh", "conv1.0.weight", "layer4cxq2020.0.weight", "layer4cxq2020.0.bias", "layer4cxq0910.weight", "layer4cxq0910.bias".
        Unexpected key(s) in state_dict: "attribute4cxq20200910", "attribute4cxq20200910ccc".
'''
print('Type returned by load_state_dict:', type(NamedTuple_2))
missing_keys, unexpected_keys = NamedTuple_2
print('Type of missing_keys:', type(missing_keys), '\nmissing_keys:\n', missing_keys)
print('Type of unexpected_keys:', type(unexpected_keys), '\nunexpected_keys:\n', unexpected_keys)
print('Value of model_2.layer4cxq1:', model_2.layer4cxq1)

Console output (the <class '...'> values printed by type() are missing from this listing):
CUDA (GPU) available: True
torch version: 1.2.0+cu92
--------------------------------------Testing the model (CPU)---------------------------------------
------------Testing the state_dict(destination=None, prefix='', keep_vars=False) method-------------
Type returned by state_dict:
Type returned by load_state_dict:
Type of missing_keys: missing_keys: []
Type of unexpected_keys: unexpected_keys: []
--------------------------------Testing mismatched model parameters---------------------------------
Type returned by load_state_dict:
Type of missing_keys:
missing_keys:
 ['attribute4wyf', 'attribute4cjh', 'conv1.0.weight', 'layer4cxq2020.0.weight', 'layer4cxq2020.0.bias', 'layer4cxq0910.weight', 'layer4cxq0910.bias']
Type of unexpected_keys:
unexpected_keys:
 ['attribute4cxq20200910', 'attribute4cxq20200910ccc']
Value of model_2.layer4cxq1: None
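The run above only covers mismatched keys. A related case is a key that exists on both sides but with a different tensor shape: strict=False does not skip such entries, and load_state_dict still raises RuntimeError. One possible workaround, sketched below under the assumption that dropping the incompatible entries is acceptable, is to filter the checkpoint first. The helper filter_compatible and both small models are hypothetical and not part of the experiment above.

```python
import torch
import torch.nn as nn


def filter_compatible(model, state_dict):
    """Keep only entries whose key exists in `model` with the same shape.

    Hypothetical helper for illustration; it is not a PyTorch API.
    """
    own = model.state_dict()
    return {k: v for k, v in state_dict.items()
            if k in own and own[k].shape == v.shape}


old = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 10))
new = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))  # different output size

checkpoint = old.state_dict()

# Keys match, but "2.weight" and "2.bias" differ in shape, so this raises
# RuntimeError even though strict=False.
try:
    new.load_state_dict(checkpoint, strict=False)
    shape_error = False
except RuntimeError:
    shape_error = True

# Drop the incompatible entries, then load the rest non-strictly.
compatible = filter_compatible(new, checkpoint)
result = new.load_state_dict(compatible, strict=False)
print(shape_error, sorted(compatible), result.missing_keys)
```

The layers listed in result.missing_keys keep their freshly initialised values, which is usually what is wanted when only the final layer of a pretrained network is replaced.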