PyTorch：修改权重字典（state_dict）的 key

想在 Winograd 改造后的模型训练中加载预训练权重。由于修改卷积层的 kernel 尺寸后使用了新的 key 名，所以需要相应地修改一下权重字典的 key（同时保留原来的 key）。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import *
import os
import argparse
from model.vggnet_4_bn import VGG

# Command-line options for this conversion script.
# An alternative checkpoint path that was used here before:
#   rename_winograd/rename_winograd.pth
parser = argparse.ArgumentParser()
parser.add_argument(
    '--pre_weights',
    type=str,
    default='ckp_bn_01_vgg4/model_5_0.9785.pth',
    help='pretrained weights',
)
opt = parser.parse_args()
print(opt)

# Build the plain VGG model and load the pretrained checkpoint (a state_dict)
# into it. map_location="cpu" lets the file load regardless of the device its
# tensors were saved from (e.g. a GPU index not present on this machine);
# load_state_dict copies the values into the model, which is then moved to the
# GPU, so the end result is unchanged.
model = VGG()
model.load_state_dict(torch.load(opt.pre_weights, map_location="cpu"))
model.cuda()
# Rename keys for the Winograd model: features.0.weight ---> features.0.inner_conv2d.weight,
# while also keeping features.0.weight (same for .bias and the other listed layers).
from collections import OrderedDict

# Indices (within model.features) of the conv layers whose parameters must also
# be exposed under the ".inner_conv2d." name expected by the Winograd model.
WINOGRAD_CONV_INDICES = (0, 4, 8, 12)


def duplicate_conv_keys(state_dict, conv_indices=WINOGRAD_CONV_INDICES):
    """Return a new OrderedDict with selected conv params duplicated under inner_conv2d.

    For every index ``i`` in *conv_indices*, an entry ``features.i.weight`` or
    ``features.i.bias`` found in *state_dict* is additionally stored as
    ``features.i.inner_conv2d.weight`` / ``features.i.inner_conv2d.bias``,
    inserted immediately before the original key. All other entries are kept
    unchanged and in their original order. Values are taken from *state_dict*
    as-is (no tensor data is copied).

    Args:
        state_dict: mapping of parameter names to tensors (e.g. ``model.state_dict()``).
        conv_indices: iterable of layer indices inside ``features`` to duplicate.

    Returns:
        OrderedDict containing the original entries plus the duplicated ones.
    """
    targets = {
        "features.{}.{}".format(idx, param)
        for idx in conv_indices
        for param in ("weight", "bias")
    }
    renamed = OrderedDict()
    for key, value in state_dict.items():
        if key in targets:
            prefix, param = key.rsplit(".", 1)
            renamed[prefix + ".inner_conv2d." + param] = value
        renamed[key] = value
    return renamed


# Take the state_dict snapshot once; calling model.state_dict() per lookup
# rebuilds the whole dict every time.
new_dict = duplicate_conv_keys(model.state_dict())

# Show the resulting key list so the duplicated inner_conv2d entries can be checked.
print(new_dict.keys())
# NOTE(review): hard-coded, user-specific output directory; consider making it
# a command-line option.
MODEL_PATH = "/home/aiden00/pytorch_classfication_person/personvscar_pytorch_pq/rename_winograd/"
# exist_ok=True replaces the check-then-create pattern (os.path.exists followed
# by os.makedirs), which can race if the directory appears in between.
os.makedirs(MODEL_PATH, exist_ok=True)
torch.save(new_dict, os.path.join(MODEL_PATH, 'model_' + 'winograd' + '.pth'))


你可能感兴趣的：AI for CV