pytorch模型保存与加载以及常见问题

一. 模型保存与加载

#多gpu
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4' #choose
model = TheModelClass(*args, **kwargs)
model = torch.nn.DataParallel(model).cuda()

#加载与训练模型 file.pth.tar
checkpoint = torch.load('file.pth.tar')
model.load_state_dict[checkpoint]

#保存训练模型,保存路径为model_file
torch.save(model.state_dict(), model_file)

二.常见问题

1.Missing key(s) in state_dict: Unexpected key(s) in state_dict:
如果加载的预训练模型之前使用了torch.nn.DataParallel(),而此时的训练并没有使用,则会出现这样的错误。【该问题常出现在"训练模型与当前模型参数不完全一致时,需要update,而update又必须在DataParallel之前" 的这种情况下】
我们需要去掉参数中的前缀 module.

#方法1:
model.load_state_dict({k.replace('module.',''):v for k,v in torch.load('myfile.pth').items()})
#方法2:
state_dict = torch.load('file.pth.tar')
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)

2.Pytorch会由于使用版本不同,加载模型可能会出现“num_batches_tracked”空的键值。导致加载模型不匹配问题。
同理,还有些出现“running"等问题,处理方法相同

#去掉“num_batches_tracked”空的键值
model_dict = net.state_dict()
new_model_dict = {}
for i in model_dict.items():
    if "num_batches_tracked" in i[0]:
        print (i[0])
    else:
        new_model_dict[i[0]] = i[1]
        
pretraine_dict = model.load_state_dict(torch.load('file.pth.tar'))
load_dict={}

for kv1,kv2 in zip(new_model_dict.items(),pretrained_dict.items()):
    load_dict[kv1[0]] = kv2[1]
model_dict.update(load_dict)
net.load_state_dict(model_dict)

3.常见的比较简单的是,除了最后层其他层均加载

#如:加载除了 ‘fc’ 层之外的层!
net = vgg().......
model_dict = net.state_dict()
#filter
pretrained_dict = {k: v for k, v in new_state_dict.items() if k.find('fc')==-1}
#update
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

4.加载的模型的参数 多余 当前模型本身的参数

model_dict = model.state_dict()
pretrained_dict = torch.load('file.pth.tar')
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, vin pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

5.加载的模型参数 少于 当前模型本身的参数
常用于自己改变了网络结构,给网络添加了新的层

model.load_state_dict(checkpoint['state_dict'],strict=False)
#load_state_dict严格匹配参数的键名称
#strict=False表示只加载与键值匹配的参数,并忽略其他参数键。

6.设置部分层不参与训练

for name,param in model.base_model.named_parameters():
    if name 满足某些条件:
        param.requires_grad = False

#同时,optimizer也需要做相应调整,只优化相应层
params = filter(lambda p: p.requires_grad, model.parameters()
optimizer = torch.optim.SGD(params,
                            args.lr,
                            momentum = args.momentum,
                            weight_decay = args.weight_decay)

7.pytorch版本问题:AttributeError: ‘module’ object has no attribute '_rebuild_tensor_v2’
这是因为训练模型时使用的是新版本的pytorch,而加载时使用的是旧版本的pytorch
解决办法:在代码开头加上:
参考链接:https://discuss.pytorch.org/t/question-about-rebuild-tensor-v2/14560

import torch._utils
try:
    torch._utils._rebuild_tensor_v2
except AttributeError:
    def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
        tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
        tensor.requires_grad = requires_grad
        tensor._backward_hooks = backward_hooks
        return tensor
    torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2

你可能感兴趣的:(pytorch学习)