上一篇我们简单介绍了使用卷积神经网络进行图片分类的算法,点这里,然而最后的效果并不理想,现在我们来考虑如何改进这个算法.
这次加入的功能是保存参数的功能.我们训练神经网络,希望损失值越低越好,test的时候希望精确率越高越好.然而最终的目的还是对一张或者几张图片进行分类,所以我们需要保存训练的参数,tensorflow很方便就有这样的功能.
tf.train.Saver()函数就是用来保存参数的,其中max_to_keep表示要保留的最近检查点(check point)文件的最大数量。创建新文件时,将删除旧文件。如果为None或0,则不从文件系统中删除检查点,但只保留最后一个检查点.注意,您仍然需要调用save()方法来保存模型。将这些参数传递给构造函数不会自动为您保存变量。除了检查点文件之外,保存程序还在磁盘上保留协议缓冲区以及最近的检查点列表。这用于管理编号的检查点文件,通过latest_checkpoint()可以轻松发现最近检查点的路径。该协议缓冲区存储在检查点文件旁边名为“checkpoint”的文件中。
我们通过restore()函数来恢复以前保存的变量。此方法运行构造函数添加的ops以恢复变量。它需要启动图表的会话。要恢复的变量不必初始化,因为恢复本身就是一种初始化变量的方法。
save(
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix='meta',
write_meta_graph=True,
write_state=True,
strip_default_attrs=False
)
其中 global step数字如果被提供,则会在checkpoint文件名后面 跟上数字
restore(
sess,
save_path
)
这里有一个坑点,网上很多人都遇到过,就是如果多次重复读取,tensorflow会报错,好像是因为tensor name重名的原因,解决方法为关闭重启python(最没技术含量),其次就是加上tf.reset_default_graph()这句话,这里应该要加强对tensorflow的认知程度,tensorflow与其他程序最大的不同就是tensor是以流动的形式存在于每个graph的.
下面是改动的部分,加入了参数存储和读取模块,并把精确度最大的三个值保存起来.
is_train=False
saver=tf.train.Saver(max_to_keep=3)
if is_train:
max_acc=0
f=open('E:/data/save_data/acc.txt','w')
#保存精度最高的三代,并储存为txt文本
for i in range(1000):
# batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={xs: x_train, ys: y_train, keep_prob: 0.5})
if i % 50 == 0:
val_acc=compute_accuracy(x_test,y_test)
val_loss=sess.run(cross_entropy, feed_dict={xs: x_train, ys: y_train, keep_prob: 1})
print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
if val_acc>max_acc:
max_acc=val_acc
saver.save(sess,model_path,global_step=i+1)
f.close
else:
tf.reset_default_graph()
#参数路径
model_file=tf.train.latest_checkpoint('E:/data/save_data/')
#测试文件夹路径
test_path='E:/data/muck truck v2/true'
classfication_dic={0:'合格',1:'不合格'}
saver.restore(sess,model_file)
test_arr=creat_x_database(test_path,32,64)
val_acc=sess.run(prediction,feed_dict={xs: test_arr, keep_prob: 1})
output=[]
#预测值每一行最大值的索引
output=np.argmax(val_acc,1)
#根据索引值通过字典找到分类
for i in range(len(output)):
print('第',i+1,'张图片预测:'+classfication_dic[output[i]])
# print(val_acc)
sess.close()
实际效果如下: