pytorch学习10-网络模型的保存和加载

系列文章目录

  1. pytorch学习1-数据加载以及Tensorboard可视化工具
  2. pytorch学习2-Transforms主要方法使用
  3. pytorch学习3-torchvisin和Dataloader的使用
  4. pytorch学习4-简易卷积实现
  5. pytorch学习5-最大池化层的使用
  6. pytorch学习6-非线性变换(ReLU和sigmoid)
  7. pytorch学习7-序列模型搭建
  8. pytorch学习8-损失函数与反向传播
  9. pytorch学习9-优化器学习
  10. pytorch学习10-网络模型的保存和加载
  11. pytorch学习11-完整的模型训练过程

文章目录

  • 系列文章目录
  • 一、保存方式1,结构+参数
  • 二、保存方式二,只保存参数(官方推荐),空间占用小
  • 总结


一、保存方式1,结构+参数

torch.save(vgg16,"vgg16_method1.pth")#保存了网络模型结构和参数,后面是路径
model=torch.load("vgg16_method1.pth")#加载模型
print(model)

二、保存方式二,只保存参数(官方推荐),空间占用小

torch.save(vgg16.state_dict(),"vgg16_method2.pth")#把模型的参数保存成了字典,并不保存结构
vgg16=torchvision.models.vgg16(pretrained=False)#加载模型。因为这样保存的只有模型的参数,所以需要先恢复模型的结构,
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))#加载模型。用前面这个函数来加载模型的参数
print(vgg16)

总结

以上就是今天要讲的内容,网络模型的保存和加载

你可能感兴趣的:(吴恩达机器学习,pytorch,学习,人工智能)