pytorch中模型的保存与加载:torch.save(),torch.load()

pytorch保存模型与加载:

模型的保存

torch.save(net,PATH)#保存模型的整个网络,包括网络的整个结构和参数
torch.save(net.state_dict,PATH)#只保存网络中的参数

模型的加载

分别对应上边的加载方法。

model_dict=torch.load(PATH)
model_dict=net.load_state_dict(torch.load(PATH))

在自定义的网络中的使用:

import torch
import torch.nn as nn
class neuralModel(nn.Module):
	def __init__(self,device):super(neuralModel,self).__init__()
		self.device=device#初始化函数
	
	def dump(self,filename):#保存模型参数
		torch.save(self.state_dict(),filename)

	def load(self,filename):
		state_dict=torch.load(open(filename,"rb"),map_location=self.device)
		self.load_state_dict(state_dict,strict=True)

其中map_location为改变设备(gpu0,gpu1,cpu…)

参考链接

pytorch------cpu与gpu load时相互转化 torch.load(map_location=)
[Pytorch]Pytorch 保存模型与加载模型(转)

你可能感兴趣的:(python代码有关)