What is state_dict in PyTorch? And how do you load partial weights?

In PyTorch, the learnable parameters (e.g. weights and biases) live in the model's parameters (model.parameters()), and the state_dict is simply a dictionary that maps each layer's name to its parameter tensors.

Since state_dict is a dictionary, it can be saved, updated, and loaded like any other dict. Note that only layers with learnable parameters, plus registered buffers (register_buffer: not updated during training, but saved along with the model), have entries in a model's state_dict. The optimizer also has a state_dict of its own.
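
As a minimal sketch of these operations (CHECKPOINT_PATH is a hypothetical placeholder; Net is the model class defined just below), saving and restoring a model through its state_dict looks like this:

import torch

CHECKPOINT_PATH = "model_weights.pth"  # hypothetical path

net = Net()  # Net is the model class defined below
# Save: serialize only the parameter/buffer dictionary, not the whole module
torch.save(net.state_dict(), CHECKPOINT_PATH)

# Load: build the same architecture first, then copy the tensors in
net = Net()
net.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
net.eval()  # set eval mode so BN/dropout layers behave correctly at inference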

Following the official tutorial, let's create a network:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F  # note: torch.nn.functional, not torch.functional
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

Now let's print the model's state_dict:

① We first iterate over net.state_dict(), which holds the weights and biases of every layer. Since state_dict is a dictionary, net.state_dict()[param_tensor].size() looks a value up by its key. (The conv1 weight has shape (6, 3, 5, 5) because conv1 has 6 output channels, 3 input channels, and a 5x5 kernel.)

② We then iterate over the optimizer's state_dict, which contains the optimizer's state and its hyperparameters:

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in net.state_dict():
    print(param_tensor, "\t", net.state_dict()[param_tensor].size())

print()

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for key, value in optimizer.state_dict().items():
    print(key, "\t", value)

The output:

Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
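
Note that state is empty here because no optimization step has run yet. As a small sketch (continuing with the net and optimizer above), after one forward/backward pass and optimizer.step(), the state holds one momentum buffer per parameter:

# One dummy training step; the input shape 3x32x32 matches the network above
x = torch.randn(1, 3, 32, 32)
target = torch.randint(0, 10, (1,))
loss = nn.CrossEntropyLoss()(net(x), target)
loss.backward()
optimizer.step()

# Each parameter now has a 'momentum_buffer' entry in the optimizer state
print(optimizer.state_dict()["state"][0].keys())  # dict_keys(['momentum_buffer'])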

A typical application of state_dict:

Before training we often start from pretrained weights, usually a .pth file, which is just a saved dictionary. Let's load a set of ResNet-34 weights (we name the loaded object ckpt rather than net, since it is a plain dictionary, not a model):

import torch

pthfile = r'D:/AI预训练权重/train_r34_NBt1D.pth'  # faster_rcnn_ckpt.pth
ckpt = torch.load(pthfile, map_location=torch.device('cpu'))  # load the checkpoint onto the CPU
print(type(ckpt))
print(len(ckpt))
for k in ckpt.keys():
    print(k)
print(ckpt["state_dict"])
for key, value in ckpt["state_dict"].items():
    print(key, value, sep=" ")

<class 'dict'>
1
state_dict
OrderedDict([('encoder.conv1.weight', tensor([[[[ 1.5516e-02,  5.2939e-03,  2.8082e-03,  ..., -6.3492e-02,
           -9.9119e-03,  6.2728e-02],
          [ 7.7608e-03,  4.9472e-02,  5.4932e-02,  ..., -1.7819e-01,
           -1.2713e-01,  5.3156e-03],
          [-9.4686e-03,  4.9467e-02,  2.2223e-01,  ..., -1.0941e-01],
          ...])),
 ('encoder.bn1.weight', tensor([4.1093e-01, 4.1710e-01, 4.3806e-08, 2.7257e-01, 3.0985e-01, 4.4599e-01,
        3.2788e-01, 3.9957e-01, 3.8334e-01, 6.7823e-07, 7.3982e-01, 1.5724e-01,
        ...

So the loaded ckpt is a dictionary with a single key, "state_dict", whose value is an OrderedDict mapping each layer name to its weights.
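
One practical detail (an assumption about this particular checkpoint, but a common pattern): weights saved from a wrapper module often carry a key prefix, here "encoder.". If your own model's keys lack that prefix, you can strip it before loading, e.g.:

# Hypothetical cleanup: drop the "encoder." prefix added by the wrapper at save time
prefix = "encoder."
stripped = {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in ckpt["state_dict"].items()}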

So how do we load only part of the weights?

First, let's print the keys in the dictionary.

In resnet34, conv1 (the 7x7 convolution layer) contributes the following keys: the convolution weight plus the BN statistics and parameters.

conv1.weight
bn1.running_mean
bn1.running_var
bn1.weight
bn1.bias

Next, layer1 contains three blocks, each with two 3x3, 64-channel convolutions:

layer1.0.conv1.weight
layer1.0.bn1.running_mean
layer1.0.bn1.running_var
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.conv2.weight
layer1.0.bn2.running_mean
layer1.0.bn2.running_var
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.1.conv1.weight
layer1.1.bn1.running_mean
layer1.1.bn1.running_var
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.conv2.weight
layer1.1.bn2.running_mean
layer1.1.bn2.running_var
layer1.1.bn2.weight
layer1.1.bn2.bias
layer1.2.conv1.weight
layer1.2.bn1.running_mean
layer1.2.bn1.running_var
layer1.2.bn1.weight
layer1.2.bn1.bias
layer1.2.conv2.weight
layer1.2.bn2.running_mean
layer1.2.bn2.running_var
layer1.2.bn2.weight
layer1.2.bn2.bias

layer2 contains four blocks, each with two 3x3, 128-channel convolutions. Note the extra downsample keys here: the first BasicBlock of this stage downsamples its shortcut with a 1x1, 128-channel convolution followed by BN (see the inspection snippet after this key list).

layer2.0.conv1.weight
layer2.0.bn1.running_mean
layer2.0.bn1.running_var
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.conv2.weight
layer2.0.bn2.running_mean
layer2.0.bn2.running_var
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.downsample.0.weight
layer2.0.downsample.1.running_mean
layer2.0.downsample.1.running_var
layer2.0.downsample.1.weight
layer2.0.downsample.1.bias
layer2.1.conv1.weight
layer2.1.bn1.running_mean
layer2.1.bn1.running_var
layer2.1.bn1.weight
layer2.1.bn1.bias
layer2.1.conv2.weight
layer2.1.bn2.running_mean
layer2.1.bn2.running_var
layer2.1.bn2.weight
layer2.1.bn2.bias
layer2.2.conv1.weight
layer2.2.bn1.running_mean
layer2.2.bn1.running_var
layer2.2.bn1.weight
layer2.2.bn1.bias
layer2.2.conv2.weight
layer2.2.bn2.running_mean
layer2.2.bn2.running_var
layer2.2.bn2.weight
layer2.2.bn2.bias
layer2.3.conv1.weight
layer2.3.bn1.running_mean
layer2.3.bn1.running_var
layer2.3.bn1.weight
layer2.3.bn1.bias
layer2.3.conv2.weight
layer2.3.bn2.running_mean
layer2.3.bn2.running_var
layer2.3.bn2.weight
layer2.3.bn2.bias
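
As a quick check of where the downsample.0 / downsample.1 keys come from, we can inspect torchvision's resnet34 directly (a sketch assuming torchvision is installed):

from torchvision.models import resnet34

model = resnet34()
print(model.layer2[0].downsample)
# Sequential(
#   (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
#   (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# )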

layer3 contains six blocks, each with two 3x3, 256-channel convolutions (again, the first block carries a downsample):

layer3.0.conv1.weight
layer3.0.bn1.running_mean
layer3.0.bn1.running_var
layer3.0.bn1.weight
layer3.0.bn1.bias
layer3.0.conv2.weight
layer3.0.bn2.running_mean
layer3.0.bn2.running_var
layer3.0.bn2.weight
layer3.0.bn2.bias
layer3.0.downsample.0.weight
layer3.0.downsample.1.running_mean
layer3.0.downsample.1.running_var
layer3.0.downsample.1.weight
layer3.0.downsample.1.bias
layer3.1.conv1.weight
layer3.1.bn1.running_mean
layer3.1.bn1.running_var
layer3.1.bn1.weight
layer3.1.bn1.bias
layer3.1.conv2.weight
layer3.1.bn2.running_mean
layer3.1.bn2.running_var
layer3.1.bn2.weight
layer3.1.bn2.bias
layer3.2.conv1.weight
layer3.2.bn1.running_mean
layer3.2.bn1.running_var
layer3.2.bn1.weight
layer3.2.bn1.bias
layer3.2.conv2.weight
layer3.2.bn2.running_mean
layer3.2.bn2.running_var
layer3.2.bn2.weight
layer3.2.bn2.bias
layer3.3.conv1.weight
layer3.3.bn1.running_mean
layer3.3.bn1.running_var
layer3.3.bn1.weight
layer3.3.bn1.bias
layer3.3.conv2.weight
layer3.3.bn2.running_mean
layer3.3.bn2.running_var
layer3.3.bn2.weight
layer3.3.bn2.bias
layer3.4.conv1.weight
layer3.4.bn1.running_mean
layer3.4.bn1.running_var
layer3.4.bn1.weight
layer3.4.bn1.bias
layer3.4.conv2.weight
layer3.4.bn2.running_mean
layer3.4.bn2.running_var
layer3.4.bn2.weight
layer3.4.bn2.bias
layer3.5.conv1.weight
layer3.5.bn1.running_mean
layer3.5.bn1.running_var
layer3.5.bn1.weight
layer3.5.bn1.bias
layer3.5.conv2.weight
layer3.5.bn2.running_mean
layer3.5.bn2.running_var
layer3.5.bn2.weight
layer3.5.bn2.bias

layer4 contains three blocks, each with two 3x3, 512-channel convolutions, followed at the end by the fully connected layer:

layer4.0.conv1.weight
layer4.0.bn1.running_mean
layer4.0.bn1.running_var
layer4.0.bn1.weight
layer4.0.bn1.bias
layer4.0.conv2.weight
layer4.0.bn2.running_mean
layer4.0.bn2.running_var
layer4.0.bn2.weight
layer4.0.bn2.bias
layer4.0.downsample.0.weight
layer4.0.downsample.1.running_mean
layer4.0.downsample.1.running_var
layer4.0.downsample.1.weight
layer4.0.downsample.1.bias
layer4.1.conv1.weight
layer4.1.bn1.running_mean
layer4.1.bn1.running_var
layer4.1.bn1.weight
layer4.1.bn1.bias
layer4.1.conv2.weight
layer4.1.bn2.running_mean
layer4.1.bn2.running_var
layer4.1.bn2.weight
layer4.1.bn2.bias
layer4.2.conv1.weight
layer4.2.bn1.running_mean
layer4.2.bn1.running_var
layer4.2.bn1.weight
layer4.2.bn1.bias
layer4.2.conv2.weight
layer4.2.bn2.running_mean
layer4.2.bn2.running_var
layer4.2.bn2.weight
layer4.2.bn2.bias
fc.weight
fc.bias
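
For reference, the full key listing above matches torchvision's resnet34 and can be reproduced as below (a sketch; depending on the PyTorch version, the state_dict may also contain num_batches_tracked buffers, which were omitted above):

from torchvision.models import resnet34

for key in resnet34().state_dict().keys():
    print(key)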

After loading the pretrained weights, we iterate over their keys: whenever a key contains "fc" we append it to an (initially empty) list del_key. We then delete every collected key from the weight dictionary, and finally load the remaining weights into our own model. Code adapted from the Bilibili channel 霹雳吧啦.

# pre_weights is the pretrained state_dict loaded above (an OrderedDict),
# e.g. pre_weights = torch.load(pthfile, map_location='cpu')["state_dict"]
del_key = []
for key, _ in pre_weights.items():
    # collect every key that belongs to the fully connected layer
    if "fc" in key:
        del_key.append(key)

# a dict must not change size while being iterated, so delete in a second pass
for key in del_key:
    del pre_weights[key]
print(pre_weights)  # re-print to confirm the fc entries are gone

# net here is our own model instance (not the checkpoint dict loaded earlier);
# strict=False tolerates model/checkpoint mismatches and returns them instead of raising
missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
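
missing_keys lists the model parameters that received no value from the checkpoint (here, the fc layer we deleted), while unexpected_keys lists checkpoint entries the model has no slot for. As an alternative sketch, instead of matching on the "fc" name you can keep every entry whose name and shape agree with the model:

# Keep only entries whose key exists in the model and whose shape matches
model_sd = net.state_dict()
filtered = {k: v for k, v in pre_weights.items()
            if k in model_sd and v.shape == model_sd[k].shape}
missing_keys, unexpected_keys = net.load_state_dict(filtered, strict=False)
print("not loaded:", missing_keys)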

Printing the dictionary again, we see that all the fc entries are gone; the listing now ends at layer4.2.bn2.bias:

('layer4.2.bn2.bias', Parameter containing:
tensor([ 0.1216,  0.1289,  0.1926,  0.1332,  0.0978,  0.1507,  0.1391,  0.1332,
        ...
