搭建BP网络对实际图片进行预测

在mnist数据集上搭建BP神经网络,完成在测试集上的训练,详情参看:基于TensorFlow的mnist数据集BP网络搭建

那么我们思考:1.能不能对实际图片进行预测? 2.能不能用自定义的图片数据集进行预测?

首先看问题1,为了方便训练在mnist_backward.py中加入断点续训,这样在恢复训练后能继续上次的训练轮数,不必再重新开始:

搭建BP网络对实际图片进行预测_第1张图片
断点续训

其中,tf.train.get_checkpoint_state(checkpoint_dir,latestfile=None)表示如果文件夹包含有效断点文件则返回该文件,saver.restore(sess,ckpt.model_checkpoint_path)是恢复当前会话,将ckpt中的最新的w,b付给当前会话中。

在完成断点续训后,我们继续解决问题1,对实际图片进行预测,总共分2部:1)输入实际图片,对图片进行预处理使其符合NN要求;2)复现NN喂入图片。

先看步骤1),我们对图片进行简单的预处理,灰度化和二值化,使其大小为28*28,像素值在0-1之间黑底白字,大小是1行784列的数组。

搭建BP网络对实际图片进行预测_第2张图片
预处理

步骤2)复现神经网络,

搭建BP网络对实际图片进行预测_第3张图片
复现NN

还是,首先创建默认图,输入x占位,调用forward输出得到概率最大的索引值即预测y,实现滑动平均,传入更新速度moving_average_decay,并把每次更新得到的w,b和影子值保存。开启会话,根据chekpoint文件找到最新模型,喂入图片并计算返回预测值。

这两个步骤实现后我们只需在应用模块中调用就可:

搭建BP网络对实际图片进行预测_第4张图片
应用

在训练模型达到一定准确度后,运行application,就可以完成对实际图片的预测:

搭建BP网络对实际图片进行预测_第5张图片

在下一篇,继续分析问题2,对自定义的图片数据集进行训练。



新手学习,欢迎指教!!

你可能感兴趣的:(搭建BP网络对实际图片进行预测)