torch.nn.Module.load_state_dict(state_dict, strict=True)

参考链接: torch.nn.Module.load_state_dict(state_dict, strict=True)

更多关于保存和加载模型的信息请参考这个链接: 讨论PyTorch中模型加载时参数不一致的情况
torch.nn.Module.load_state_dict(state_dict, strict=True)_第1张图片

原文及翻译:

load_state_dict(state_dict, strict=True)
方法: load_state_dict(state_dict, strict=True)
    Copies parameters and buffers from state_dict into this module 
    and its descendants. If strict is True, then the keys of
    state_dict must exactly match the keys returned by this 
    module’s state_dict() function.
    从函数接收的参数state_dict中将参数和缓冲拷贝到当前这个模块及其子模块中.
    如果函数接受的参数strict是True,那么state_dict的关键字必须确切地严格地和
    该模块的state_dict()函数返回的关键字相匹配.

    Parameters  参数

            state_dict (dict) – a dict containing parameters and persistent buffers.
            state_dict (字典类型) – 一个包含参数和持续性缓冲的字典.
            strict (bool, optional) – whether to strictly enforce that
            the keys in state_dict match the keys returned by this 
            module’s state_dict() function. Default: True
            strict (布尔类型, 可选) – 该参数用来指明是否需要强制严格匹配,:state_dict中的关键字是否需要和该模块的state_dict()方法返回
            的关键字强制严格匹配.默认值是True.

    Returns  返回
            missing_keys is a list of str containing the missing keys
            missing_keys是一个字符串的列表,该列表包含了所有缺失的关键字.
            unexpected_keys is a list of str containing the unexpected keys
            unexpected_keys是一个字符串的列表,该列表包含了意料之外的关键字,:多余的关键字(译者注).
   
    Return type 返回类型
        NamedTuple with missing_keys and unexpected_keys fields
        具名元组,该具名元组包含了两个字段,分别是missing_keys 和unexpected_keys.

实验代码展示:

import torch 
import torch.nn as nn
torch.manual_seed(seed=20200910)
class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.conv1=torch.nn.Sequential(  # 输入torch.Size([64, 1, 28, 28])
                torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1),
                torch.nn.ReLU(),  # 输出torch.Size([64, 64, 28, 28])
                torch.nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),  # 输出torch.Size([64, 128, 28, 28])
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(stride=2,kernel_size=2)  # 输出torch.Size([64, 128, 14, 14])
        )

        self.dense=torch.nn.Sequential(  # 输入torch.Size([64, 14*14*128])
                    torch.nn.Linear(14*14*128,1024),  # 输出torch.Size([64, 1024])
                    torch.nn.ReLU(),
                    torch.nn.Dropout(p=0.5),
                    torch.nn.Linear(1024,10)  # 输出torch.Size([64, 10])        
        )
        self.layer4cxq1 = torch.nn.Conv2d(2,33,4,4)
        self.layer4cxq2 = torch.nn.ReLU()
        self.layer4cxq3 = torch.nn.MaxPool2d(stride=2,kernel_size=2)
        self.layer4cxq4 = torch.nn.Linear(14*14*128,1024)
        self.layer4cxq5 = torch.nn.Dropout(p=0.8)
        self.attribute4cxq = nn.Parameter(torch.tensor(20200910.0))
        self.attribute4lzq = nn.Parameter(torch.tensor([2.0,3.0,4.0,5.0]))    
        self.attribute4hh = nn.Parameter(torch.randn(3,4,5,6))
        self.attribute4wyf = nn.Parameter(torch.randn(7,8,9,10))

    def forward(self,x):  # torch.Size([64, 1, 28, 28])
        x = self.conv1(x)  # 输出torch.Size([64, 128, 14, 14])
        x = x.view(-1,14*14*128)  # torch.Size([64, 14*14*128])
        x = self.dense(x)  # 输出torch.Size([64, 10])
        return x

print('cuda(GPU)是否可用:',torch.cuda.is_available())
print('torch的版本:',torch.__version__)

model = Model() #.cuda()

print("测试模型(CPU)".center(100,"-"))
print(type(model))

print("测试模型state_dict(destination=None, prefix='', keep_vars=False)方法".center(100,"-"))

myDict = model.state_dict()
print('函数返回的类型是:',type(myDict))

model_1 = Model()



NamedTuple_1 = model_1.load_state_dict(myDict, strict=True)
print('函数load_state_dict返回的类型是:',type(NamedTuple_1))
missing_keys, unexpected_keys = NamedTuple_1

print('missing_keys的类型:', type(missing_keys),'missing_keys:', missing_keys)
print('unexpected_keys的类型:', type(unexpected_keys),'unexpected_keys:', unexpected_keys)


print('测试当模型参数不匹配时的情况'.center(100,"-"))
model_2 = Model()
model_2.layer4cxq2020 =torch.nn.Sequential(torch.nn.Conv2d(1,7,kernel_size=3,stride=1,padding=1))
model_2.layer4cxq0910 = torch.nn.Linear(5,6)
model_2.attribute4cjh = nn.Parameter(torch.randn(17))
model_2.attribute4cxq_xxx = torch.tensor([2.0,3.0,4.0,5.0])
model_2.layer4cxq1 = None
myDict['attribute4cxq20200910'] = nn.Parameter(torch.tensor([2.0,3.0,4.0,5.0]))
myDict['attribute4cxq20200910ccc'] = torch.tensor([2.0,3.0,4.0,5.0])
del myDict['attribute4wyf']
del myDict['conv1.0.weight']

NamedTuple_2 = model_2.load_state_dict(myDict, strict=False)  # strict=True
# 注意这里要strict=False,如果strict=True,那么就会报错
# 报错信息如下:
r'''
-------------------------------------------测试当模型参数不匹配时的情况-------------------------------------------
Traceback (most recent call last):
  File "c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\test31.py", line 74, in 
    NamedTuple_2 = model_2.load_state_dict(myDict, strict=True)  # strict=True
  File "D:\Anaconda3\envs\ssd4pytorch1_2_0\lib\site-packages\torch\nn\modules\module.py", line 845, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Model:
        Missing key(s) in state_dict: "attribute4wyf", "attribute4cjh", "conv1.0.weight", "layer4cxq2020.0.weight", "layer4cxq2020.0.bias", "layer4cxq0910.weight", "layer4cxq0910.bias".
        Unexpected key(s) in state_dict: "attribute4cxq20200910", "attribute4cxq20200910ccc".
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> 
'''

print('函数load_state_dict返回的类型是:',type(NamedTuple_2))
missing_keys, unexpected_keys = NamedTuple_2

print('missing_keys的类型:', type(missing_keys),'\nmissing_keys:\n', missing_keys)
print('unexpected_keys的类型:', type(unexpected_keys),'\nunexpected_keys:\n', unexpected_keys)

print('model_2.layer4cxq1的取值是:',model_2.layer4cxq1)

控制台输出结果:

Windows PowerShell
版权所有 (C) Microsoft Corporation。保留所有权利。

尝试新的跨平台 PowerShell https://aka.ms/pscore6

加载个人及系统配置文件用了 972 毫秒。
(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>  & 'D:\Anaconda3\envs\ssd4pytorch1_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2020.12.424452561\pythonFiles\lib\python\debugpy\launcher' '55348' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\test31.py'
cuda(GPU)是否可用: True
torch的版本: 1.2.0+cu92
---------------------------------------------测试模型(CPU)----------------------------------------------

-------------------测试模型state_dict(destination=None, prefix='', keep_vars=False)方法-------------------
函数返回的类型是: 
函数load_state_dict返回的类型是: 
missing_keys的类型:  missing_keys: []
unexpected_keys的类型:  unexpected_keys: []
-------------------------------------------测试当模型参数不匹配时的情况-------------------------------------------
函数load_state_dict返回的类型是: 
missing_keys的类型:  
missing_keys:
 ['attribute4wyf', 'attribute4cjh', 'conv1.0.weight', 'layer4cxq2020.0.weight', 'layer4cxq2020.0.bias', 'layer4cxq0910.weight', 'layer4cxq0910.bias']
unexpected_keys的类型:  
unexpected_keys:
 ['attribute4cxq20200910', 'attribute4cxq20200910ccc']
model_2.layer4cxq1的取值是: None
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>

你可能感兴趣的:(torch.nn.Module.load_state_dict(state_dict, strict=True))