基于cnn的图像二分类算法(二)

 

上一篇我们简单介绍了使用卷积神经网络进行图片分类的算法,点这里,然而最后的效果并不理想,现在我们来考虑如何改进这个算法.

这次加入的功能是保存参数的功能.我们训练神经网络,希望损失值越低越好,test的时候希望精确率越高越好.然而最终的目的还是对一张或者几张图片进行分类,所以我们需要保存训练的参数,tensorflow很方便就有这样的功能.

tf.train.Saver()函数就是用来保存参数的,其中max_to_keep表示要保留的最近检查点(check point)文件的最大数量。创建新文件时,将删除旧文件。如果为None或0,则不从文件系统中删除检查点,但只保留最后一个检查点.注意,您仍然需要调用save()方法来保存模型。将这些参数传递给构造函数不会自动为您保存变量。除了检查点文件之外,保存程序还在磁盘上保留协议缓冲区以及最近的检查点列表。这用于管理编号的检查点文件,通过latest_checkpoint()可以轻松发现最近检查点的路径。该协议缓冲区存储在检查点文件旁边名为“checkpoint”的文件中。

我们通过restore()函数来恢复以前保存的变量。此方法运行构造函数添加的ops以恢复变量。它需要启动图表的会话。要恢复的变量不必初始化,因为恢复本身就是一种初始化变量的方法。

save(
    sess,
    save_path,
    global_step=None,
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True,
    strip_default_attrs=False
)

其中 global step数字如果被提供,则会在checkpoint文件名后面 跟上数字

restore(
    sess,
    save_path
)

 这里有一个坑点,网上很多人都遇到过,就是如果多次重复读取,tensorflow会报错,好像是因为tensor name重名的原因,解决方法为关闭重启python(最没技术含量),其次就是加上tf.reset_default_graph()这句话,这里应该要加强对tensorflow的认知程度,tensorflow与其他程序最大的不同就是tensor是以流动的形式存在于每个graph的.

下面是改动的部分,加入了参数存储和读取模块,并把精确度最大的三个值保存起来.

is_train=False
saver=tf.train.Saver(max_to_keep=3)

if is_train:
    max_acc=0
    f=open('E:/data/save_data/acc.txt','w')
#保存精度最高的三代,并储存为txt文本

    for i in range(1000):
#    batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={xs: x_train, ys: y_train, keep_prob: 0.5})
        if i % 50 == 0:
            val_acc=compute_accuracy(x_test,y_test)
            val_loss=sess.run(cross_entropy, feed_dict={xs: x_train, ys: y_train, keep_prob: 1})
            print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
            f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
            if val_acc>max_acc:
                max_acc=val_acc
                saver.save(sess,model_path,global_step=i+1)
    f.close
    
else:
    tf.reset_default_graph()
    #参数路径
    model_file=tf.train.latest_checkpoint('E:/data/save_data/')
    #测试文件夹路径
    test_path='E:/data/muck truck v2/true'
    
    classfication_dic={0:'合格',1:'不合格'}
    saver.restore(sess,model_file)
    test_arr=creat_x_database(test_path,32,64)
    val_acc=sess.run(prediction,feed_dict={xs: test_arr, keep_prob: 1})
    output=[]
    #预测值每一行最大值的索引
    output=np.argmax(val_acc,1)
    #根据索引值通过字典找到分类
    for i in range(len(output)):
        print('第',i+1,'张图片预测:'+classfication_dic[output[i]])
#    print(val_acc)
    
    

sess.close()   

实际效果如下:

基于cnn的图像二分类算法(二)_第1张图片

你可能感兴趣的:(ML)