pytorch 训练模型结构及参数保存

保存网络结构及其参数
torch.save(model,‘model.pth’) # 保存
model = torch.load(“model.pth”) # 加载
只加载模型参数,网络结构从代码中创建
torch.save(model.state_dict(),“model.pth”) # 保存参数
model = model() # 代码中创建网络结构
params = torch.load(“model.pth”) # 加载参数
model.load_state_dict(params) # 应用到网络结构中

link.

你可能感兴趣的:(pytorch,pytorch,神经网络)