基于opencv和Tensorflow的实时手势识别(3)

在第二部分,我们没有进行参数的优化,后来进行了优化,loss进过几千次迭代,已经下降到0.0002左右,还是比较好的。
基于opencv和Tensorflow的实时手势识别(3)_第1张图片
本次这里介绍测试的代码,并将两个代码融合。

测试代码和训练代码基本是一样的,就不多赘述,直接上代码

def test(X_test, y_test):
    # EVAL_INTERVAL_SECS = 10 # 每10秒加载一次模型,并在测试数据上测试准确率
    with tf.Graph().as_default() as g: # 设置默认graph
        # 定义输入输出格式
        #
        x = tf.placeholder(tf.float32, [None, img_rows, img_cols, img_channels], name='x-input')
        y = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')

        y_ = inference(x, train=None, regularizer=None) # 测试时 不关注正则化损失的值

        # 开始计算正确率
        correct_prediction = tf.equal(tf.arg_max(y, 1), tf.arg_max(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # 加载模型
        saver = tf.train.Saver()
        with tf.Session() as sess:
            # tf.train.get_checkpoint_state会自动找到目录中的最新模型文件名
            ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                # 得到迭代轮数
                # global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] # model.ckpt-3000
                for _ in range(X_test.shape[0]):
                    xs, ys = get_batch(X_test, y_test, batch_size=1) # 测试用
                    # print(ys.shape)
                    label, accuracy_score = sess.run([y_, accuracy], feed_dict={x: xs, y: ys})
                    print("实际手势: %s,  预测手势: %s" % (output[np.argmax(ys)], output[np.argmax(label)]))
                    # print("After %s training steps(s), test accuracy = %f" % (global_step, accuracy_score))

            else:
                print("No checkpoint, Training Firstly.")
                return

注意的是:为了实时显示出预测的结果,我们将测试样本的输入 batch设置为1。

基于opencv和Tensorflow的实时手势识别(3)_第2张图片

代码融合

在代码将两个代码融合时,首先我们不需要测试样本 ,所以不用将所有样本分为测试和训练样本

 X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0, random_state=4) # 将test_size 设置为0

其次,需要再 第一个代码(trackGesture.py)中对采集到的手势进行识别,需要一个函数(类似与上面的test函数)

def Gussgesture(X_test):
    # EVAL_INTERVAL_SECS = 10 # 每10秒加载一次模型,并在测试数据上测试准确率
    with tf.Graph().as_default() as g: # 设置默认graph
        # 定义输入输出格式
        #
        x = tf.placeholder(tf.float32, [None, img_rows, img_cols, img_channels], name='x-input')
        y_ = inference(x, train=None, regularizer=None) # 测试时 不关注正则化损失的值

        # 开始计算正确率
        # correct_prediction = tf.equal(tf.arg_max(y, 1), tf.arg_max(y_, 1))
        # accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # 加载模型
        saver = tf.train.Saver()
        with tf.Session() as sess:
            # tf.train.get_checkpoint_state会自动找到目录中的最新模型文件名
            ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                # 得到迭代轮数
                # global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] # model.ckpt-3000
                # print(ys.shape)
                label = sess.run(y_, feed_dict={x: X_test.reshape(1, X_test.shape[0], X_test.shape[1], X_test.shape[2])})
                print("预测手势: %s" % (output[np.argmax(label)]))
                # PLOT(label)
                # print("After %s training steps(s), test accuracy = %f" % (global_step, accuracy_score))
                return output[np.argmax(label)]
            else:
                print("No checkpoint, Training Firstly.")
                return

注: 在feed数据时,要按照网络接受的格式输入[batch_size, w, h, channel]

模型调用代码:


        if key == ord('p'):
            """调用模型开始预测, 对二值图像预测,所以要在二值函数里面调用,预测新采集的手势"""
            # print("Prediction Mode - {}".format(guessGesture))
            # Prediction(roi)
            Roi = np.reshape(roi, [width, height, 1])
            # print(Roi.shape)
            gesture = myCNN.Gussgesture(Roi)
            gesture_copy = gesture
            cv2.putText(frame, gesture_copy, (480, 440), font, 1, (0, 0, 255))  # 标注字体

**存在问题:
1 每一次都需要按p才能进行预测。
2 预测的结果在画面上的显示一闪而过,不能保留。**

你可能感兴趣的:(tensorflow)