pytorch保存和导入模型

Pytorch官方的加载和保存模型的方式有两种:

1、保存和加载整个模型

这种方式再重新加载的时候不需要自定义网络结构,保存时已经把网络结构保存了下来,比较死板不能调整网络结构。

注:torch.load 返回的是一个 OrderedDict

torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl') 

2、仅保存和加载模型参数(推荐使用) state_dict

这种方式再重新加载的时候需要自己定义网络,并且其中的参数名称与结构要与保存的模型中的一致(可以是部分网络,比如只使用VGG的前几层),相对灵活,便于对网络进行修改。
 

torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))

特殊情况:

一、当保存的模型和导入的模型不在一个位置时

假设我们只保存了模型的参数(model.state_dict())到文件名为modelparameters.pth, model = Net()

1. cpu -> cpu或者gpu -> gpu:

checkpoint = torch.load('modelparameters.pth')

model.load_state_dict(checkpoint)

2. cpu -> gpu 1

torch.load('modelparameters.pth', map_location=lambda storage, loc: storage.cuda(1))

3. gpu 1 -> gpu 0

torch.load('modelparameters.pth', map_location={'cuda:1':'cuda:0'})

4. gpu -> cpu

torch.load('modelparameters.pth', map_location=lambda storage, loc: storage)

二、多GPU训练

例如我们创建了一个多GPU训练模型:

model = MyVggNet()
#多GPU并行
model = nn.DataParallel(model).cuda()

此时相当于在原来的模型外面加了一层支持GPU运行的外壳,真正的模型对象为:real_model = model.module。

所以在保存模型的时候注意,如果保存的时候是否带有这层加的外壳,加载的时候也是带有的,如果保存的是真实的模型,加载的也是真是的模型。因为加了module壳的模型在CPU上是不能运行的,因此建议保存真实模型。

1. 第一种方式

模型保存(保存带外壳模型的参数):

real_model = model.module
torch.save(real_model.state_dict(),os.path.join(args.save_path,"cos_mnist_"+str(epoch+1)+"_weight.pth"))

cpu上加载使用:

args.weight=checkpoint/cos_mnist_10_weight.pth
map_location = lambda storage, loc: storage
model.load_state_dict(torch.load(args.weight,map_location=map_location))


2. 第二种方式

模型保存(保存真正模型的参数):

real_model = model.module
save_model(real_model, os.path.join(args.save_path,"cos_mnist_"+str(epoch+1)+"_weight_cpu.pth"))
# 自定义的函数
def save_model(model,filename):
    state = model.state_dict()
    for key in state: state[key] = state[key].clone().cpu()
    torch.save(state, filename)

cpu上加载使用:

args.weight=checkpoint/cos_mnist_10_weight_cpu.pth
model.load_state_dict(torch.load(args.weight))


3. 第三种方式


模型保存(保存整个带壳的模型)

real_model = model.module
torch.save(real_model, os.path.join(args.save_path,"cos_mnist_"+str(epoch+1)+"_whole.pth"))

cpu上加载使用:

args.weight=checkpoint/cos_mnist_10_whole.pth
map_location = lambda storage, loc: storage
model = torch.load(args.weight,map_location=map_location)

 

你可能感兴趣的:(Pytorch)