pytorch的torch.load()与torch.save()

33.pytorch保存和加载模型参数 总结 参考

(1) 仅仅保存和加载模型参数,保存时保存的是一个字典,字典里面包括状态字典等。

torch.save(the_model.state_dict(), PATH)

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

(2)保存和加载整个模型,保存整个模型。

torch.save(the_model, PATH)

the_model = torch.load(PATH)

(3)load时应该注意load的对象的类型是否一致,或者load的状态字典的结构是否一致。

参考

(4)tensor在cpu与gpu上迁移。参考      参考

>>> torch.load('tensors.pt')

# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=torch.device('cpu'))
# Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
# Load tensor from io.BytesIO object
>>> with open('tensor.pt', 'rb') as f:
        buffer = io.BytesIO(f.read())
>>> torch.load(buffer)
# Load a module with 'ascii' encoding for unpickling
>>> torch.load('module.pt', encoding='ascii')

(5)torch.save(),直接将模型保存为.pth或.t7格式都可以。参考

print('===> Saving models...')
state = {
    'state': model.state_dict(),
    'epoch': epoch                   # 将epoch一并保存
}
if not os.path.isdir('checkpoint'):
    os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')

你可能感兴趣的:(学习问题,人工智能)