1. 如何从已训练好的网络模型中提取指定层权重
import torch
# `vgg` is the model definition published by torchvision:
# https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
# It is not referenced by name below; presumably it is imported so that
# torch.load can resolve the class of the pickled model object.
import vgg

# Load the trained model object from disk.
model = torch.load('logs/vgg16.pkl')

# State-dict keys of the layers whose weights should be extracted.
wanted_keys = ['classifier.2.bias']

# Gather the tensors stored under the wanted keys (the result is a set of
# tensors, so the key names themselves are not kept).
# To skip the listed layers instead of selecting them, invert the test to
# `if name not in wanted_keys`.
state = model.state_dict()
restore_param = {state[name] for name in state if name in wanted_keys}
print(restore_param)
------>:
{tensor([-0.0048, 0.0048], device='cuda:0')}
2. 如何加载模型部分参数并更新
import torch
import vgg

# Load the previously trained model object; its state dict is the source of
# the values to restore.
model = torch.load('logs/vgg16.pkl')

# Newly constructed VGG16 on the GPU; selected entries of its state dict are
# overwritten below.
vgg16 = vgg.vgg16().cuda()
vgg16_dict = vgg16.state_dict()

# Print every tensor before anything is restored.
for v in vgg16_dict.values():
    print(v)

print()
print('##################################################################################')
print()

# Keys of the parameters to copy over from the trained model.
restore = ['classifier.2.bias']
restore_param = {k: v for k, v in model.state_dict().items() if k in restore}
vgg16_dict.update(restore_param)
# dict.update only rebinds entries of the dictionary; load the merged state
# dict back so that the network itself actually uses the restored values.
vgg16.load_state_dict(vgg16_dict)

# Print every tensor again to show which entries changed.
for v in vgg16_dict.values():
    print(v)
------>:
tensor([[[[-0.0198, 0.0425, -0.0221],
[ 0.0636, 0.0193, -0.0661],
[-0.0035, 0.0031, -0.0395]],
[[-0.0525, 0.0796, 0.0263],
[-0.0669, 0.1537, 0.1025],
[ 0.0002, -0.0456, -0.0086]],
[[-0.0344, 0.0566, -0.0090],
[ 0.0915, 0.0133, -0.0007],
[-0.0228, -0.0143, 0.0841]]],
...
tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0')
tensor([[ 2.7670e-03, -1.6860e-02, -6.6972e-03, ..., 6.7144e-03,
-7.2912e-03, 2.0684e-03],
[ 4.2978e-03, -9.8524e-03, 1.2163e-02, ..., 6.3420e-03,
-5.1077e-03, 6.4550e-03]], device='cuda:0')
tensor([0., 0.], device='cuda:0')
##################################################################################
tensor([[[[-0.0198, 0.0425, -0.0221],
[ 0.0636, 0.0193, -0.0661],
[-0.0035, 0.0031, -0.0395]],
[[-0.0525, 0.0796, 0.0263],
[-0.0669, 0.1537, 0.1025],
[ 0.0002, -0.0456, -0.0086]],
[[-0.0344, 0.0566, -0.0090],
[ 0.0915, 0.0133, -0.0007],
[-0.0228, -0.0143, 0.0841]]],
...
tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0')
tensor([[ 2.7670e-03, -1.6860e-02, -6.6972e-03, ..., 6.7144e-03,
-7.2912e-03, 2.0684e-03],
[ 4.2978e-03, -9.8524e-03, 1.2163e-02, ..., 6.3420e-03,
-5.1077e-03, 6.4550e-03]], device='cuda:0')
tensor([-0.0048, 0.0048], device='cuda:0')
可以发现,classifier.2.bias 的值由 [0., 0.] 变为了 [-0.0048, 0.0048]。
参考文章