tensorflow实现knn算法

knn算法介绍

knn算法是机器学习中最简单的算法。其原理类似于古语“近朱者赤近墨者黑”,即同类物体的差异性小,异类差异性大,而这种差异往往是用“距离”表示。“距离”的度量一般采用欧氏距离。

这里写图片描述

算法思路

tensorflow实现knn算法_第1张图片

1.计算待分类的样本和样本空间中已标记的样本的欧氏距离。(如图中绿点为待分类样本,要计算绿点与图中所有点的距离)

2.取距离最短的k个点,k个点进行投票,票数最多的类为待测样本的类。(若k为3,则图中实线圆中的点是距离绿点最短的点,其中三角形有两个,正方形1个,所以绿点为三角形;若k为5,则图中虚线中的点为最近邻点,其中正方形有3个,三角形2个,所以绿点为正方形。由此可知,k的取值会影响分类的结果)

算法的优缺点

1.优点

算法简单有效

2.缺点

一方面计算量大。当训练集比较大的时候,每一个样本分类都要计算与所有的已标记样本的距离。目前常用的解决方法是事先对已知样本点进行剪辑,事先去除对分类作用不大的样本(例如在样本空间进行划分区域)。另一方面是当已标记样本是不平衡,分类会向占样本多数的类倾斜。解决方案是引进权重。

tensorflow简单实现knn

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials import mnist
mnist_image=mnist.input_data.read_data_sets("C:\\Users\\Administrator\\.keras\\datasets\\",one_hot=True)
pixels,real_values=mnist_image.train.next_batch(10)
# n=5
# image=pixels[n,:]
# image=np.reshape(image, [28,28])
# plt.imshow(image)
# plt.show()
traindata,trainlabel=mnist_image.train.next_batch(100)

testdata,testlabel=mnist_image.test.next_batch(10)
traindata_tensor=tf.placeholder('float',[None,784])
testdata_tensor=tf.placeholder('float',[784])

distance=tf.reduce_sum(tf.abs(tf.add(traindata_tensor,tf.neg(testdata_tensor))),reduction_indices=1)
pred = tf.arg_min(distance,0)
test_num=10
accuracy=0
init=tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    for i in range(test_num):
        idx=sess.run(pred,feed_dict={traindata_tensor:traindata,testdata_tensor:testdata[i]})
        print('test No.%d,the real label %d, the predict label %d'%(i,np.argmax(testlabel[i]),np.argmax(trainlabel[idx])))
        if np.argmax(testlabel[i])==np.argmax(trainlabel[idx]):
            accuracy+=1
    print("result:%f"%(1.0*accuracy/test_num))

输出

Extracting C:\Users\Administrator\.keras\datasets\train-images-idx3-ubyte.gz
Extracting C:\Users\Administrator\.keras\datasets\train-labels-idx1-ubyte.gz
Extracting C:\Users\Administrator\.keras\datasets\t10k-images-idx3-ubyte.gz
Extracting C:\Users\Administrator\.keras\datasets\t10k-labels-idx1-ubyte.gz
test No.0,the real label 7, the predict label 7
test No.1,the real label 2, the predict label 2
test No.2,the real label 1, the predict label 1
test No.3,the real label 0, the predict label 0
test No.4,the real label 4, the predict label 4
test No.5,the real label 1, the predict label 1
test No.6,the real label 4, the predict label 4
test No.7,the real label 9, the predict label 9
test No.8,the real label 5, the predict label 9
test No.9,the real label 9, the predict label 9
result:0.900000

你可能感兴趣的:(机器学习实战笔记,tensorflow)