卷积神经网络CIFAR-10 + TensorFlow-gpu 1.8.0训练与评估

参考:

中文网站极客学院也有该部分的汉译版:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/deep_cnn.html

网友学习经验帖:https://blog.csdn.net/yhl_leo/article/details/50738311

版本报错修改:https://blog.csdn.net/zeuseign/article/details/72771598

目录

1.介绍

2.准备

3.训练

4.评估


1.介绍

对CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题,其任务是对一组大小为32x32的RGB图像进行分类,这些图像涵盖了10个类别:飞机,汽车,鸟,猫,鹿,狗,青蛙,马,船以及卡车。

卷积神经网络CIFAR-10 + TensorFlow-gpu 1.8.0训练与评估_第1张图片

2.准备

代码:https://github.com/sjl3110/Cifar-10-TensorFlow1.8.0

文件 作用
cifar10_input.py 读取本地CIFAR-10的二进制文件格式的内容。
cifar10.py 建立CIFAR-10的模型。
cifar10_train.py 在CPU或GPU上训练CIFAR-10的模型。
cifar10_multi_gpu_train.py 在多GPU上训练CIFAR-10的模型。
cifar10_eval.py 评估CIFAR-10模型的预测性能。

本代码适用于1.8版本的tensorflow,因为API较老版本变化很大,因此对于官网给的代码做出了一定的修改。除此之外,迭代次数修改为20000次。由于笔记本训练,实际上1060的GPU下20000次迭代就花费了十几分钟的时间。

3.训练

直接运行,开始会下载CIFAR的数据包(官网链接http://www.cs.toronto.edu/~kriz/cifar.html,下载第三个),大概国内得3小时...然后没报错的话基本上就可以开始迭代学习了

python3 cifar10_train.py

卷积神经网络CIFAR-10 + TensorFlow-gpu 1.8.0训练与评估_第2张图片

4.评估

脚本文件cifar10_eval.py对模型进行了评估,利用 inference()函数重构模型,并使用了在评估数据集所有10,000张CIFAR-10图片进行测试。最终计算出的精度为1:N,N=预测值中置信度最高的一项与图片真实label匹配的频次。

为了监控模型在训练过程中的改进情况,评估用的脚本文件会周期性的在最新的检查点文件上运行,这些检查点文件是由cifar10_train.py产生。

python3 cifar10_eval.py

命令运行后会显示当前的精度,一般同时运行训练和评估会消耗大量的系统资源。

你可能感兴趣的:(TensorFlow)