卷积神经网络简单的应用(三):模型测试

  1. 模型测试
    模型训练好之后通过重新加载模型的方式进行模型测试,使用Tensorflow中的Saver对象。相关代码如下:
    def test_cnn(x_data):
        output = create_cnn(4)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            #加载训练好的模型
            saver.restore(sess,"./model/cnn.model-2100")
            preject = tf.argmax(output,1)
            x_in = np.array(x_data)
            #keep_prob设置为1
            label = sess.run(preject,feed_dict={X:[x_in],keep_prob:1})
        return label
    主函数为:
    if __name__ == '__main__':
        isTrain = 2    
        if 1 == isTrain:        
            X = tf.placeholder(tf.float32,[None,200,150,3])
            Y = tf.placeholder(tf.float32,[None,4])
            
            keep_prob = tf.placeholder(tf.float32)
            train_cnn(xdata,ydata)
        if 2 == isTrain:
            #将测试数据放在相应的文件中    
            path_list = ['./0','./1','./2','./3']
            for p in path_list:
                file_info = os.listdir(p)
                for file_name in file_info:
                    x_data = read_test_data(p+'/'+file_name)
                    if type(x_data) == type(None):
                        print('==>',p)
                        continue
                    #没有这句,会出现问题
                    tf.reset_default_graph()  
                    X = tf.placeholder(tf.float32,[None,200,150,3])
                    keep_prob = tf.placeholder(tf.float32)
                    l = test_cnn(x_data)
                    label = ['girl','beauty girl','boy','handsome boy']               
                    print(p,':',file_name,'====>',label[l[0]])



你可能感兴趣的:(卷积神经网络简单的应用)