In PyTorch, the learnable parameters (e.g. weights and biases) live in the model's parameters (model.parameters()), and the state_dict is simply a dictionary that maps each layer to its parameter tensors.
Since state_dict is a dictionary, it can be saved, updated, and loaded like any other dict. Note that only layers with learnable parameters, plus buffers registered via register_buffer (these are not updated during training but are saved with the model), have entries in the model's state_dict. The optimizer also has its own state_dict.
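For example, the typical round trip is to save the state_dict to disk and load it back into a model with the same architecture. A minimal sketch (the toy model and the file name model_weights.pth are just placeholders):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))

# Save only the learnable parameters and buffers, not the whole module object
torch.save(model.state_dict(), "model_weights.pth")

# Load them back into a model with the same architecture
state_dict = torch.load("model_weights.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # switch to eval mode before inference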
Following the official example code, we create a network:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F  # note: torch.nn.functional, not torch.functional

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
Now let's print the model's state_dict:
①: First we iterate over the model's state_dict, which contains each layer's weights and biases. Since state_dict is a dictionary, net.state_dict()[param_tensor].size() looks up a value by its key. (Here the conv1 weight has shape (6, 3, 5, 5): 6 output channels, 3 input channels, and a 5x5 kernel.)
②: Then we iterate over the optimizer's state_dict, which contains the optimizer's state and its hyperparameters.
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in net.state_dict():
    print(param_tensor, "\t", net.state_dict()[param_tensor].size())
print()

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for key, value in optimizer.state_dict().items():
    print(key, "\t", value)
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
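Because the optimizer keeps its own state_dict (momentum buffers and hyperparameters), a common pattern is to save it together with the model's state_dict so training can be resumed later. A minimal sketch (the file name checkpoint.pth and the epoch variable are just placeholders):

# Save a resumable checkpoint: model weights + optimizer state + bookkeeping
checkpoint = {
    "model_state_dict": net.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": epoch,  # assumed to be the current epoch counter in your training loop
}
torch.save(checkpoint, "checkpoint.pth")

# Resume later
checkpoint = torch.load("checkpoint.pth", map_location="cpu")
net.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1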
Use cases for state_dict:
Before training we often load pretrained weights, usually stored in a .pth file, which is just a dictionary. Here we load ResNet-34 weights:
import torch

pthfile = r'D:/AI预训练权重/train_r34_NBt1D.pth'  # faster_rcnn_ckpt.pth
net = torch.load(pthfile, map_location=torch.device('cpu'))
print(type(net))
print(len(net))
for k in net.keys():
    print(k)
print(net["state_dict"])
for key, value in net["state_dict"].items():
    print(key, value, sep=" ")
1
state_dict
OrderedDict([('encoder.conv1.weight', tensor([[[[ 1.5516e-02, 5.2939e-03, 2.8082e-03, ..., -6.3492e-02,
-9.9119e-03, 6.2728e-02],
[ 7.7608e-03, 4.9472e-02, 5.4932e-02, ..., -1.7819e-01,
-1.2713e-01, 5.3156e-03],
[-9.4686e-03, 4.9467e-02, 2.2223e-01, ..., -1.0941e-01]),
('encoder.bn1.weight', tensor([4.1093e-01, 4.1710e-01, 4.3806e-08, 2.7257e-01, 3.0985e-01, 4.4599e-01,
        3.2788e-01, 3.9957e-01, 3.8334e-01, 6.7823e-07, 7.3982e-01, 1.5724e-01, ...
So the state_dict inside the loaded file is an OrderedDict containing each layer's name and its corresponding weights.
How do we load only part of the weights?
Let's first print the keys in the dictionary:
In ResNet-34, conv1 is the 7x7 convolution layer; its keys include the convolution weight and the BN statistics:
conv1.weight
bn1.running_mean
bn1.running_var
bn1.weight
bn1.bias
Next, layer1 contains three blocks, each with two 3x3, 64-channel convolutions:
layer1.0.conv1.weight
layer1.0.bn1.running_mean
layer1.0.bn1.running_var
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.conv2.weight
layer1.0.bn2.running_mean
layer1.0.bn2.running_var
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.1.conv1.weight
layer1.1.bn1.running_mean
layer1.1.bn1.running_var
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.conv2.weight
layer1.1.bn2.running_mean
layer1.1.bn2.running_var
layer1.1.bn2.weight
layer1.1.bn2.bias
layer1.2.conv1.weight
layer1.2.bn1.running_mean
layer1.2.bn1.running_var
layer1.2.bn1.weight
layer1.2.bn1.bias
layer1.2.conv2.weight
layer1.2.bn2.running_mean
layer1.2.bn2.running_var
layer1.2.bn2.weight
layer1.2.bn2.bias
layer2 contains four blocks, each with two 3x3, 128-channel convolutions. Note the downsample keys here: the first block of layer2 downsamples its shortcut with a 1x1, 128-channel convolution (the BasicBlock downsample branch):
layer2.0.conv1.weight
layer2.0.bn1.running_mean
layer2.0.bn1.running_var
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.conv2.weight
layer2.0.bn2.running_mean
layer2.0.bn2.running_var
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.downsample.0.weight
layer2.0.downsample.1.running_mean
layer2.0.downsample.1.running_var
layer2.0.downsample.1.weight
layer2.0.downsample.1.bias
layer2.1.conv1.weight
layer2.1.bn1.running_mean
layer2.1.bn1.running_var
layer2.1.bn1.weight
layer2.1.bn1.bias
layer2.1.conv2.weight
layer2.1.bn2.running_mean
layer2.1.bn2.running_var
layer2.1.bn2.weight
layer2.1.bn2.bias
layer2.2.conv1.weight
layer2.2.bn1.running_mean
layer2.2.bn1.running_var
layer2.2.bn1.weight
layer2.2.bn1.bias
layer2.2.conv2.weight
layer2.2.bn2.running_mean
layer2.2.bn2.running_var
layer2.2.bn2.weight
layer2.2.bn2.bias
layer2.3.conv1.weight
layer2.3.bn1.running_mean
layer2.3.bn1.running_var
layer2.3.bn1.weight
layer2.3.bn1.bias
layer2.3.conv2.weight
layer2.3.bn2.running_mean
layer2.3.bn2.running_var
layer2.3.bn2.weight
layer2.3.bn2.bias
layer3 contains six blocks, each with two 3x3, 256-channel convolutions:
layer3.0.conv1.weight
layer3.0.bn1.running_mean
layer3.0.bn1.running_var
layer3.0.bn1.weight
layer3.0.bn1.bias
layer3.0.conv2.weight
layer3.0.bn2.running_mean
layer3.0.bn2.running_var
layer3.0.bn2.weight
layer3.0.bn2.bias
layer3.0.downsample.0.weight
layer3.0.downsample.1.running_mean
layer3.0.downsample.1.running_var
layer3.0.downsample.1.weight
layer3.0.downsample.1.bias
layer3.1.conv1.weight
layer3.1.bn1.running_mean
layer3.1.bn1.running_var
layer3.1.bn1.weight
layer3.1.bn1.bias
layer3.1.conv2.weight
layer3.1.bn2.running_mean
layer3.1.bn2.running_var
layer3.1.bn2.weight
layer3.1.bn2.bias
layer3.2.conv1.weight
layer3.2.bn1.running_mean
layer3.2.bn1.running_var
layer3.2.bn1.weight
layer3.2.bn1.bias
layer3.2.conv2.weight
layer3.2.bn2.running_mean
layer3.2.bn2.running_var
layer3.2.bn2.weight
layer3.2.bn2.bias
layer3.3.conv1.weight
layer3.3.bn1.running_mean
layer3.3.bn1.running_var
layer3.3.bn1.weight
layer3.3.bn1.bias
layer3.3.conv2.weight
layer3.3.bn2.running_mean
layer3.3.bn2.running_var
layer3.3.bn2.weight
layer3.3.bn2.bias
layer3.4.conv1.weight
layer3.4.bn1.running_mean
layer3.4.bn1.running_var
layer3.4.bn1.weight
layer3.4.bn1.bias
layer3.4.conv2.weight
layer3.4.bn2.running_mean
layer3.4.bn2.running_var
layer3.4.bn2.weight
layer3.4.bn2.bias
layer3.5.conv1.weight
layer3.5.bn1.running_mean
layer3.5.bn1.running_var
layer3.5.bn1.weight
layer3.5.bn1.bias
layer3.5.conv2.weight
layer3.5.bn2.running_mean
layer3.5.bn2.running_var
layer3.5.bn2.weight
layer3.5.bn2.bias
layer4 contains three blocks, each with two 3x3, 512-channel convolutions, followed by the final fully connected layer:
layer4.0.conv1.weight
layer4.0.bn1.running_mean
layer4.0.bn1.running_var
layer4.0.bn1.weight
layer4.0.bn1.bias
layer4.0.conv2.weight
layer4.0.bn2.running_mean
layer4.0.bn2.running_var
layer4.0.bn2.weight
layer4.0.bn2.bias
layer4.0.downsample.0.weight
layer4.0.downsample.1.running_mean
layer4.0.downsample.1.running_var
layer4.0.downsample.1.weight
layer4.0.downsample.1.bias
layer4.1.conv1.weight
layer4.1.bn1.running_mean
layer4.1.bn1.running_var
layer4.1.bn1.weight
layer4.1.bn1.bias
layer4.1.conv2.weight
layer4.1.bn2.running_mean
layer4.1.bn2.running_var
layer4.1.bn2.weight
layer4.1.bn2.bias
layer4.2.conv1.weight
layer4.2.bn1.running_mean
layer4.2.bn1.running_var
layer4.2.bn1.weight
layer4.2.bn1.bias
layer4.2.conv2.weight
layer4.2.bn2.running_mean
layer4.2.bn2.running_var
layer4.2.bn2.weight
layer4.2.bn2.bias
fc.weight
fc.bias
After loading the pretrained weights, we iterate over them and append every key containing "fc" to an (initially empty) list del_key; we then delete those entries from the pretrained weights and load the remainder into our own model with strict=False. Code adapted from the Bilibili uploader 霹雳吧啦.
# pre_weights is the pretrained state_dict loaded earlier,
# e.g. pre_weights = torch.load(pthfile, map_location='cpu')["state_dict"]
# net here is your own model instance (not the dict loaded above)
del_key = []
for key, _ in pre_weights.items():
    print(key)
    if "fc" in key:
        del_key.append(key)
print(del_key)

for key in del_key:
    del pre_weights[key]
print(pre_weights)

missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
Printing again, we see that all the fc-layer entries have been removed.
('layer4.2.bn2.bias', Parameter containing:
tensor([ 0.1216, 0.1289, 0.1926, 0.1332, 0.0978, 0.1507, 0.1391, 0.1332,
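With strict=False, load_state_dict also reports which keys did not line up, which is a quick sanity check that only the intended layers were skipped. A minimal sketch (continuing from the snippet above):

# Inspect what load_state_dict skipped when strict=False
missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)

# Keys that exist in our model but were not found in the pretrained weights
# (here we expect only the new fc layer we will train from scratch)
print("missing keys:", missing_keys)

# Keys present in the pretrained weights but absent from our model
print("unexpected keys:", unexpected_keys)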