从 vgg16-397923af.pth 里读取的权重数值，应该与调用 model.load_state_dict 加载预训练权重后模型中的参数完全一致。
但我的结果却不一致！
原因：将参数载入模型时，checkpoint 中的键名与模型的键名不匹配，而我为了不报错使用了 strict=False。strict=False 会静默跳过所有不匹配的键，这些参数实际上根本没有被载入，模型仍保留随机初始化的值。
解决办法:
# Remove a leading 'module.' from every checkpoint key so the names match the
# model's own state_dict keys (presumably the prefix was added by wrapping the
# model in nn.DataParallel / DistributedDataParallel when it was saved — confirm).
# Only a *leading* prefix is stripped: str.replace would also delete 'module.'
# occurring in the middle of a key and corrupt that name.
params = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in checkpoint['state_dict'].items()
}
# 进行参数名的映射
import numpy as np
import torch
import torchvision
from torchvision import models
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
from src import fcn_resnet50, resnet50
from model_fcn8s import VGG, fcn_vgg16
# Load torchvision's pretrained VGG16 checkpoint into the custom VGG backbone.
# The checkpoint and the backbone use different parameter names, so the keys are
# renamed explicitly (b[i] -> a[i]) before loading. Loading the raw checkpoint
# with strict=False would silently skip every mismatched key and leave those
# parameters at their random initial values.
pretrain_backbone = True
# Location of the pretrained checkpoint (machine-specific path).
pretrained_path = "/home/hyq/hyq/projects/fcn/vgg16-397923af.pth"
# Parameter names expected by the custom VGG backbone; must be in the same order as `b`.
a = ["layer1.0.weight", "layer1.0.bias", "layer1.2.weight", "layer1.2.bias",
     "layer2.0.weight", "layer2.0.bias", "layer2.2.weight", "layer2.2.bias",
     "layer3.0.weight", "layer3.0.bias", "layer3.2.weight", "layer3.2.bias",
     "layer3.4.weight", "layer3.4.bias",
     "layer4.0.weight", "layer4.0.bias", "layer4.2.weight", "layer4.2.bias",
     "layer4.4.weight", "layer4.4.bias",
     "layer5.0.weight", "layer5.0.bias", "layer5.2.weight", "layer5.2.bias",
     "layer5.4.weight", "layer5.4.bias"]
# Corresponding parameter names in the torchvision checkpoint.
b = ["features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias",
     "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias",
     "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias",
     "features.14.weight", "features.14.bias",
     "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias",
     "features.21.weight", "features.21.bias",
     "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias",
     "features.28.weight", "features.28.bias"]
# Per-stage configuration passed to VGG (presumably (num_convs, out_channels) — confirm).
struct = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]
backbone = VGG(num_classes=21, struct=struct)
if pretrain_backbone:
    # zip() would silently truncate if the two name lists drifted apart.
    if len(a) != len(b):
        raise ValueError(
            f"name lists differ in length: {len(a)} backbone names vs {len(b)} checkpoint names"
        )
    # checkpoint name -> backbone name
    param_mapping = dict(zip(b, a))
    # map_location="cpu" lets the checkpoint load on machines without a GPU.
    weights_dict = torch.load(pretrained_path, map_location="cpu")
    # Keep only the entries listed in `b`, renamed to the backbone's names;
    # every other checkpoint entry is dropped.
    model_dict = {param_mapping[k]: v for k, v in weights_dict.items() if k in param_mapping}
    # Default strict=True: any missing or unexpected key raises instead of being
    # silently ignored (which is what strict=False would do).
    backbone.load_state_dict(model_dict)
    # Confirm every loaded parameter is bit-identical to the checkpoint tensor.
    loaded_state = backbone.state_dict()
    for src_name, dst_name in param_mapping.items():
        if not torch.equal(loaded_state[dst_name], weights_dict[src_name]):
            raise RuntimeError(
                f"parameter {dst_name} does not match checkpoint entry {src_name}"
            )
for name, param in backbone.state_dict().items():
    print(f"{name}: {param}")
从 vgg16-397923af.pth 里读取的权重数值，应该与调用 model.load_state_dict 加载预训练权重后模型中的参数完全一致。
【PyTorch 载入模型参数报错及解决办法：小心使用 strict=False】