TensorFlow系列(2)——KNN算法实现

本文的目的是进一步熟悉tensorflow的使用,在MNIST数据集(应用非常广泛的一个入门级计算机视觉数据集)上实现KNN算法,KNN算法的介绍在之前的文章中有写到过(http://blog.csdn.net/flysky1991/article/details/51944482),这里就不详细介绍了。实现代码如下所示:

# -*- coding: utf-8 -*-
"""
Created on Sun Jul  9 21:21:20 2017

@author: Administrator
"""

import numpy as np
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data  #导入MNIST数据集


mnist = input_data.read_data_sets("/tmp/data",one_hot=True)
print(mnist)


#从MNIST数据集中筛选出5000条数据用作测试
train_X,train_Y = mnist.train.next_batch(5000)
#从MNIST数据集中筛选出200条数据用作测试
test_X,test_Y = mnist.test.next_batch(100)

#图输入
train2_X = tf.placeholder("float",[None,784])
test2_X = tf.placeholder("float",[784])

#使用L1距离计算KNN距离计算
distance = tf.reduce_sum(tf.abs(tf.add(train2_X,tf.negative(test2_X))),reduction_indices=1)

#预测:取得最近的邻居节点
pred = tf.arg_min(distance,0)

accuracy = 0

#变量初始化
init = tf.global_variables_initializer()

#启动图
with tf.Session() as sess:
    sess.run(init)
    #遍历测试数据集
    for i in range(len(test_X)):
        #获取最近的邻居节点
        nn_index = sess.run(pred,feed_dict={train2_X:train_X,test2_X:test_X[i,:]})
        #获取最近的邻居节点的类别标签,并将其与该节点的真实类别标签进行比较
        print("测试数据",i,"预测分类:",np.argmax(train_Y[nn_index]),"真实类别:",np.argmax(test_Y[i]))
        #计算准确率
        if np.argmax(train_Y[nn_index]) == np.argmax(test_Y[i]):
            accuracy += 1./len(test_X)
    print("分类准确率为:",accuracy)

运行结果如下图所示:

TensorFlow系列(2)——KNN算法实现_第1张图片

TensorFlow系列(2)——KNN算法实现_第2张图片
上述图片展示的是当训练数据集为5000条时的结果,此时分类准确率为0.94。在保持测试数据集数量不变的情况下,将训练数据集规模改为50000,分类准确率就提升到了0.99.由此可见,训练数据集的规模对算法的性能也有非常明显的影响。

你可能感兴趣的:(tensorflow)