tensorflow实例(9)--最邻近算法实现MNIST手写数字分类算法

KNN(k-NearestNeighbor)是监督学习的分类技术中最简单的方法之一,K指k个最近的邻居的意思,
关于KNN的详细基本实现原理,可参考  机器学习(2)--邻近算法(KNN)

这篇文章是用tensorflow来实现,但由于在算出所有点与点之间距离后取出最近的K个时

如果只使用tensorflow,我只能实现取得1个,无法取到K个

因此这里我不再设置K,只取最近的一个,因此我的标题也是最邻近算法

MNIST 是Tensoflow提供的一个入门级的计算机视觉数据集,分为两部分(训练集和测试集

其中训练集共55000张,测试集共10000张,当为None时随机读取 

点击此处下载Mnist数据包

关于应用MNIST数据包的其他分类算法可参考

TensorFlow实例(4)--MNIST简介及手写数字分类算法

TensorFlow实例(5.1)--MNIST手写数字进阶算法(卷积神经网络CNN)


# -*- coding:utf-8 -*-
import tensorflow as tf 
import tensorflow.examples.tutorials.mnist.input_data as input_data
import random

#读取mnist数据,下载后的Mnist并解压后,放在项目的同级目录下,通过下面程序即可读取
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

x_train,y_train = mnist.train.next_batch(5000)
x_test,y_test = mnist.test.next_batch(500)

x_train_pl = tf.placeholder(tf.float32,[None,784])#因为每次是和所有的训练集比对,所以是二维数组,None表示不定维度
y_train_pl = tf.placeholder(tf.int16,[None,10])
x_test_pl = tf.placeholder(tf.float32,[784]) #因为每次传入的测试只有一条,所以是一维数组
y_test_pl = tf.placeholder(tf.int16,[10])
#取测试数据与x_train中最近的那个点的序号
nearestIndex = tf.argmin(tf.reduce_sum(tf.pow((x_train_pl - x_test_pl),2),axis=1))  #完整的距离公式还应该有一个tf.sqrt,但因为在比较时开方无意义,所以省略
isRight = tf.equal(tf.argmax(y_train_pl[nearestIndex]),tf.argmax(y_test_pl))#判断是否正确
sess = tf.Session()
sess.run(tf.global_variables_initializer())

rightCount = 0
for i in range(len(x_test)):
    if sess.run(isRight,feed_dict={
        x_train_pl:x_train
        ,y_train_pl:y_train
        ,x_test_pl:x_test[i,:]
        ,y_test_pl:y_test[i,:]
        }) :
        rightCount+=1
    if (i + 1) % 100 == 0   :
        print("已完成%d条记录!" % (i + 1))
print("共取得测试集%d条,测试正确%d条," % (len(x_test),rightCount) + "正确率:" + str(round(rightCount / x_test.shape[0] * 100,2)) + "%")


你可能感兴趣的:(python,机器学习)