DenseNet121 pretrained weights: key mismatch when loading

Loading the pretrained weights manually fails with an error saying the weight keys do not match.

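import torch
from torchvision import models

model = models.densenet121(pretrained=False)

model_dict = model.state_dict()
pretrained_state = torch.load("./premodel/pre-densenet121.pth")
# Passing pretrained_state straight to model.load_state_dict() fails here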

Cause:

Keys in the local weight file (pretrained_state) look like: "features.denseblock1.denselayer1.norm.1.weight"
Keys in the model's own state dict (model_dict) look like: "features.denseblock1.denselayer1.norm1.weight"
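A quick way to confirm this diagnosis is to compare the two sets of key names directly. The sketch below works on plain lists of strings so that it runs without the checkpoint file. `diff_keys` is a hypothetical helper name, not part of PyTorch, and in the real script you would pass `model_dict.keys()` and `pretrained_state.keys()` instead of the toy lists.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, both sorted.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not know
    """
    model_set = set(model_keys)
    checkpoint_set = set(checkpoint_keys)
    missing = sorted(model_set - checkpoint_set)
    unexpected = sorted(checkpoint_set - model_set)
    return missing, unexpected


# Toy key lists that mimic the mismatch described above (assumed, for illustration).
model_keys = [
    "features.conv0.weight",
    "features.denseblock1.denselayer1.norm1.weight",
]
checkpoint_keys = [
    "features.conv0.weight",
    "features.denseblock1.denselayer1.norm.1.weight",
]

missing, unexpected = diff_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If every entry in `missing` has a counterpart in `unexpected` that differs only by an extra dot, the problem is purely one of naming and can be fixed by renaming keys.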

Fixed code

from collections import OrderedDict

import torch
from torchvision import models

model = models.densenet121(pretrained=False)
model_dict = model.state_dict()
pretrained_state = torch.load("./premodel/pre-densenet121.pth")

# Rename the old-style keys so they match the model's keys
new_state_dict = OrderedDict()
for k, v in pretrained_state.items():
    if 'denseblock' in k:
        # e.g. ['features', 'denseblock1', 'denselayer1', 'norm', '1', 'weight']
        param = k.split(".")
        # Merge the module name and its index: 'norm' + '1' -> 'norm1'
        k1 = ".".join(param[:-3] + [param[-3] + param[-2]] + [param[-1]])
        new_state_dict[k1] = v
    else:
        # Keys outside the dense blocks already match
        new_state_dict[k] = v

model.load_state_dict(new_state_dict)
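The same renaming can also be written with a regular expression, which has the side benefit of leaving keys untouched when they are already in the new format. This is only a sketch: the pattern encodes my assumption about the old key layout (`denselayerN.norm.1.…`, `denselayerN.conv.2.…`), and `rename_key` and `rename_state_dict` are hypothetical names. The toy dictionary uses integers in place of tensors so the example runs without the checkpoint.

```python
import re
from collections import OrderedDict

# Assumed old-style layout: "<prefix>.denselayerN.<norm|relu|conv>.<1|2>.<rest>"
# group 1: everything up to and including norm/relu/conv
# group 2: the index digit, group 3: the parameter name
_OLD_KEY = re.compile(r"^(.*denselayer\d+\.(?:norm|relu|conv))\.([12])\.(.+)$")


def rename_key(key):
    """Drop the dot between the module name and its index, if present."""
    match = _OLD_KEY.match(key)
    if match is None:
        return key  # not an old-style dense-layer key: keep as-is
    prefix, index, rest = match.groups()
    return f"{prefix}{index}.{rest}"


def rename_state_dict(state):
    """Return a new OrderedDict with every key passed through rename_key."""
    return OrderedDict((rename_key(k), v) for k, v in state.items())


# Toy stand-in for pretrained_state (values would normally be tensors).
toy_state = OrderedDict([
    ("features.conv0.weight", 0),
    ("features.denseblock1.denselayer1.norm.1.weight", 1),
    ("features.denseblock2.denselayer3.conv.2.weight", 2),
    ("features.transition1.norm.weight", 3),
])
renamed = rename_state_dict(toy_state)
for key in renamed:
    print(key)
```

In the real script, `rename_state_dict(pretrained_state)` would take the place of the loop, and the result would be passed to `model.load_state_dict`.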

Why this renames the weight keys correctly

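Comparing pretrained_state with model_dict, the keys of the transition layers and of everything before the first denseblock are identical. Only the keys inside the denseblocks differ. Here k is a key name, which is a plain string.

The loop first checks whether the key belongs to a denseblock. If it does not, the key is kept unchanged. If it does, the if branch splits the key on dots and joins the pieces back together with dots, except that 'norm' and '1' are concatenated directly with no dot between them.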

param[:-3]               -> ['features', 'denseblock1', 'denselayer1']
param[-3]                -> 'norm'
param[-2]                -> '1'
param[-1]                -> 'weight'
[param[-3] + param[-2]]  -> ['norm1']

k = 'features.denseblock1.denselayer1.norm.1.weight'
param = k.split(".")
k1 = ".".join(param[:-3] + [param[-3] + param[-2]] + [param[-1]])
print(param)
print(k1)

Output:

['features', 'denseblock1', 'denselayer1', 'norm', '1', 'weight']
features.denseblock1.denselayer1.norm1.weight
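To see how the split-and-join rule behaves on other kinds of keys, it can be wrapped in a small function and run over a few samples. This is a sketch: `rename_by_split` is a hypothetical name, and the sample keys are assumed to be representative of the checkpoint rather than copied from it.

```python
def rename_by_split(key):
    """Same split/join logic as the loop above, wrapped in a function."""
    if "denseblock" not in key:
        return key
    parts = key.split(".")
    # Glue the third- and second-to-last pieces together: 'norm' + '1' -> 'norm1'
    return ".".join(parts[:-3] + [parts[-3] + parts[-2]] + [parts[-1]])


samples = [
    "features.conv0.weight",                                  # before the first denseblock
    "features.transition1.norm.weight",                       # transition layer
    "features.denseblock1.denselayer1.conv.2.weight",         # old-style conv key
    "features.denseblock4.denselayer16.norm.2.running_mean",  # old-style norm statistic
]
for key in samples:
    print(key, "->", rename_by_split(key))

# Caveat: a denseblock key that is ALREADY in the new format gets mangled,
# because the rule merges whatever sits in positions -3 and -2.
already_new = "features.denseblock1.denselayer1.norm1.weight"
print(already_new, "->", rename_by_split(already_new))
```

The last line shows that the rule is meant to be applied once, to an old-style checkpoint. Running it on a checkpoint whose keys already match the model would fuse the layer name with the module name.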
