pytorch 保存&读取模型特定参数的两种简单方法

仅做笔记使用


保存

Method1 :只保存模型各层参数

torch.save(net.state_dict(), 'net1.pkl')

其中 net 为模型,将各层参数保存为字典形式。

Method2 :保存整个模型

torch.save(net, 'net1.pkl')

读取

Method1只读取模型各层参数(对应只保存参数情况)

state_dict = torch.load(Path)    #Path 为模型文件的保存路径
net.load_state_dict(state_dict)  #恢复net模型的参数,net为自定义的和载入模型相同结构的网络

所得返回值为字典类型,可以查看各层名称以进行特定参数恢复/导出。

以resnet18为例,查看网络结构名称以及导出最后一个卷积层的权重示例如下

pytorch 保存&读取模型特定参数的两种简单方法_第1张图片pytorch 保存&读取模型特定参数的两种简单方法_第2张图片

 可知所得卷积核尺寸512*512*3*3,导出成功。

Method2:整个模型读取

model = torch.load(Path)        #Path 为模型文件的保存路径
state_dict = model.state_dict() #导出参数字典,剩余操作同上

备注:可直接在torch.load()后添加to(device)使模型导入GPU/CPU中

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load('net.pkl').to(device)  

进阶内容可转:PyTorch模型保存与加载_LXYTSOS的博客-CSDN博客_pytorch保存模型

你可能感兴趣的:(python,pytorch,机器学习)