pytorch: 保存和读取参数和模型

一、保存和读取参数

1、当训练完后,把当前的参数保存下来

import torch
torch.save(net.state_dict(), path)

保存参数只需用到torch.save(),其中net为自定义的模型名称,其子参数state_dict()为模型的参数,path为保存的路径加名称,其后缀为 ptpth ,如: ‘pth/net_parameters.pth’。

2、加载参数

import torch
net.load_state_dict(torch.load(path))

二、保存和读取模型

保存和读取模型是把模型的网络架构以及其参数都进行保存和读取

import torch
torch.save(net, path) # 保存模型
net_ = torch.load(pth) # 读取模型

同样地, net为自定义的模型, 而net_为新加载的模型,path为路径和保存模型的名称,后缀为 ptpth

你可能感兴趣的:(pytorch,pytorch,神经网络,深度学习)