#多gpu
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4' #choose
model = TheModelClass(*args, **kwargs)
model = torch.nn.DataParallel(model).cuda()
#加载与训练模型 file.pth.tar
checkpoint = torch.load('file.pth.tar')
model.load_state_dict[checkpoint]
#保存训练模型,保存路径为model_file
torch.save(model.state_dict(), model_file)
1.Missing key(s) in state_dict: Unexpected key(s) in state_dict:
如果加载的预训练模型之前使用了torch.nn.DataParallel(),而此时的训练并没有使用,则会出现这样的错误。【该问题常出现在"训练模型与当前模型参数不完全一致时,需要update,而update又必须在DataParallel之前" 的这种情况下】
我们需要去掉参数中的前缀 module.
#方法1:
model.load_state_dict({k.replace('module.',''):v for k,v in torch.load('myfile.pth').items()})
#方法2:
state_dict = torch.load('file.pth.tar')
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
2.Pytorch会由于使用版本不同,加载模型可能会出现“num_batches_tracked”空的键值。导致加载模型不匹配问题。
同理,还有些出现“running"等问题,处理方法相同
#去掉“num_batches_tracked”空的键值
model_dict = net.state_dict()
new_model_dict = {}
for i in model_dict.items():
if "num_batches_tracked" in i[0]:
print (i[0])
else:
new_model_dict[i[0]] = i[1]
pretraine_dict = model.load_state_dict(torch.load('file.pth.tar'))
load_dict={}
for kv1,kv2 in zip(new_model_dict.items(),pretrained_dict.items()):
load_dict[kv1[0]] = kv2[1]
model_dict.update(load_dict)
net.load_state_dict(model_dict)
3.常见的比较简单的是,除了最后层其他层均加载
#如:加载除了 ‘fc’ 层之外的层!
net = vgg().......
model_dict = net.state_dict()
#filter
pretrained_dict = {k: v for k, v in new_state_dict.items() if k.find('fc')==-1}
#update
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)
4.加载的模型的参数 多余 当前模型本身的参数
model_dict = model.state_dict()
pretrained_dict = torch.load('file.pth.tar')
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, vin pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
5.加载的模型参数 少于 当前模型本身的参数
常用于自己改变了网络结构,给网络添加了新的层
model.load_state_dict(checkpoint['state_dict'],strict=False)
#load_state_dict严格匹配参数的键名称
#strict=False表示只加载与键值匹配的参数,并忽略其他参数键。
6.设置部分层不参与训练
for name,param in model.base_model.named_parameters():
if name 满足某些条件:
param.requires_grad = False
#同时,optimizer也需要做相应调整,只优化相应层
params = filter(lambda p: p.requires_grad, model.parameters()
optimizer = torch.optim.SGD(params,
args.lr,
momentum = args.momentum,
weight_decay = args.weight_decay)
7.pytorch版本问题:AttributeError: ‘module’ object has no attribute '_rebuild_tensor_v2’
这是因为训练模型时使用的是新版本的pytorch,而加载时使用的是旧版本的pytorch
解决办法:在代码开头加上:
参考链接:https://discuss.pytorch.org/t/question-about-rebuild-tensor-v2/14560
import torch._utils
try:
torch._utils._rebuild_tensor_v2
except AttributeError:
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
tensor.requires_grad = requires_grad
tensor._backward_hooks = backward_hooks
return tensor
torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2