nearest-neighbor 最近邻算法

用tensorflow实现最近邻算法,对代码进行标注解释。

'''
最邻近算法
'''

from __future__ import print_function

import numpy as np
import tensorflow as tf

#导入数据集
from  tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data",one_hot=True)

#限制数据集数量
Xtr,Ytr = mnist.train.next_batch(5000)#训练集5000
Xte,Yte = mnist.test.next_batch(200)

#图形输入
xtr = tf.placeholder("float",[None,784])
'''
xtr不是一个特定的值,而是一个占位符placeholder,我们在TensorFlow运行计算时输入这个值。
我们希望能够输入任意数量的MNIST图像,每一张图展平成784维的向量。
我们用2维的浮点数张量来表示这些图,这个张量的形状是[None,784 ]。
(这里的None表示此张量的第一个维度可以是任何长度的。)
'''
xte = tf.placeholder("float",[784])

#用L1距离进行最近邻计算
distance = tf.reduce_sum(tf.abs(tf.add(xtr,tf.negative(xte))),reduction_indices=1)

#预测,获取最小的索引
pred = tf.arg_min(distance,0)
accuracy = 0.

#初始化变量(配置默认值)
init = tf.global_variables_initializer()

#开始训练
with tf.Session() as sess:

    #开始初始化
    sess.run(init)

    #循环测试数据
    for i in range(len(Xte)):
        #得到最近邻
        nn_index = sess.run(pred,feed_dict={xtr:Xtr,xte:Xte[i,:]})
        #获取最近邻居类别标签并将其与其真实标签进行比较
        print("Test",i,"Prediction:",np.argmax(Ytr[nn_index]),"True Class:",np.argmax(Yte[i]))
        #计算准确度
        if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):
            accuracy += 1./len(Xte)
    print("Done!")
    print("Accuracy:",accuracy)

你可能感兴趣的:(技术之路)