kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 kNN方法在类别决策时,只与极少量的相邻样本有关。由于kNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,kNN方法较其他方法更为适合。
一、简介
右图中,绿色圆要被决定赋予哪个类,是红色三角形还是蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。
KNN算法的决策过程
K最近邻(k-Nearest Neighbor,KNN)分类算法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一。该方法的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。
KNN方法虽然从原理上也依赖于极限定理,但在类别决策时,只与极少量的相邻样本有关。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。
KNN算法不仅可以用于分类,还可以用于回归。通过找出一个样本的k个最近邻居,将这些邻居的属性的平均值赋给该样本,就可以得到该样本的属性。更有用的方法是将不同距离的邻居对该样本产生的影响给予不同的 权值(weight),如权值与距离成反比。
二、算法流程
1. 准备数据,对数据进行 预处理
2. 选用合适的数据结构存储训练数据和测试元组
3. 设定参数,如k
4.维护一个大小为k的的按距离由大到小的 优先级队列,用于存储最近邻训练元组。随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测试元组到这k个元组的距离,将训练元组标号和距离存入优先级队列
5. 遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离L 与优先级队列中的最大距离Lmax
6. 进行比较。若L>=Lmax,则舍弃该元组,遍历下一个元组。若L < Lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队列。
7. 遍历完毕,计算优先级队列中k 个元组的多数类,并将其作为测试元组的类别。
8. 测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k 值。
三、优点
1.简单,易于理解,易于实现,无需估计参数,无需训练;
2. 适合对稀有事件进行分类;
3.特别适合于多分类问题(multi-modal,对象具有多个类别标签), kNN比SVM的表现要好。
四、缺点
该算法在分类时有个主要的不足是,当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数。 该算法只计算“最近的”邻居样本,某一类的样本数量很大,那么或者这类样本并不接近目标样本,或者这类样本很靠近目标样本。无论怎样,数量并不能影响运行结果。
该方法的另一个不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点。
可理解性差,无法给出像决策树那样的规则。
五、改进策略
kNN算法因其提出时间较早,随着其他技术的不断更新和完善,kNN算法的诸多不足之处也逐渐显露,因此许多kNN算法的改进算法也应运而生。
针对以上算法的不足,算法的改进方向主要分成了分类效率和分类效果两方面。
分类效率:事先对样本属性进行约简,删除对分类结果影响较小的属性,快速的得出待分类样本的类别。该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分。
分类效果:采用权值的方法(和该样本距离小的邻居权值大)来改进,Han等人于2002年尝试利用贪心法,针对文件分类实做可调整权重的k最近邻居法WAkNN (weighted adjusted k nearest neighbor),以促进分类效果;而Li等人于2004年提出由于不同分类的文件本身有数量上有差异,因此也应该依照训练集合中各种分类的文件数量,选取不同数目的最近邻居,来参与分类。
六、算法实现(java)
算法实现参考地址2。
package com.datamine.knn;
//KNN节点类,用来存储最近邻的K个元组的相关信息
public class KNNNode {
private int index; //元组标号
private double distance; //与测试元组之间的距离
private String c; //所属类别
public KNNNode(int index,double distance,String c){
this.index = index;
this.distance = distance;
this.c = c;
}
public int getIndex() {
return index;
}
public void setIndex(int index) {
this.index = index;
}
public double getDistance() {
return distance;
}
public void setDistance(double distance) {
this.distance = distance;
}
public String getC() {
return c;
}
public void setC(String c) {
this.c = c;
}
}
package com.datamine.knn;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
public class KNN {
/**
* 设置优先队列的比较函数,距离越大,优先级越高
*/
private Comparator comparator = new Comparator() {
public int compare(KNNNode o1, KNNNode o2) {
if(o1.getDistance() >= o2.getDistance())
return -1;
else
return 1;
}
};
/**
* 获取K个不同的随机数
* @param k 随机数
* @param max 随机数的最大范围
* @return 随机数组
*/
public List getRandKnum(int k, int max){
List rand = new ArrayList(k);
for(int i = 0; i d1, List d2){
double distance = 0.0;
for(int i =0; i> datas,List testData,int k){
//维护一个大小为k的按距离由大到小的优先队列,用户存储最近邻训练元组
PriorityQueue pq = new PriorityQueue(k, comparator);
//随机从训练集中获取k个元组
List randNum = getRandKnum(k, datas.size());
for(int i = 0; i currData = datas.get(index); //随机得到相应的训练元组
String c = currData.get(currData.size()-1).toString(); //最后一个数为类别
KNNNode node = new KNNNode(index, calDistance(testData,currData),c);
pq.add(node);
}
/*
* 遍历训练元组集,计算训练元组和测试元组的距离
* 将所得距离distance和优先队列中的最大距离top比较
* 若top>distance,删除优先队列中最大距离top元组
* 将当前训练元组存入优先队列中
*/
for(int i = 0; i< datas.size();i++){
List t = datas.get(i);
double distance = calDistance(testData,t);
KNNNode top = pq.peek();
if(top.getDistance() > distance){
pq.remove();
pq.add(new KNNNode(i,distance,t.get(t.size()-1).toString()));
}
}
return getMostClass(pq);
}
/**
* 获取所得到的k个最近邻元组的多数类
* @param pq存储k个最近邻元组的优先级队列
* @return 多数类的名称
*/
private String getMostClass(PriorityQueue pq) {
//classCount用来存储列别名和对应的个数
Map classCount = new HashMap();
int pqsize = pq.size();
for(int i = 0;i maxCount){
maxIndex = i;
maxCount = classCount.get(classes[i]);
}
}
return classes[maxIndex].toString();
}
}
package com.datamine.knn;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class TestKNN {
/**
* 从数据文件中读取数据
* @param datas 存储数据的集合对象
* @param path 数据文件的路径
* @throws IOException
*/
public void read(List> datas,String path) throws IOException{
BufferedReader br = new BufferedReader(new FileReader(new File(path)));
String reader = br.readLine();
while(reader != null){
String t[] = reader.split(" ");
ArrayList list = new ArrayList();
for(int i = 0;i> datas = new ArrayList>();
List> testDatas = new ArrayList>();
t.read(datas, datafile);
t.read(testDatas, testfile);
KNN knn = new KNN();
for(int i = 0; i < testDatas.size() ;i++){
List test = testDatas.get(i);
System.out.print("测试元组:");
for(int j = 0; j < test.size();j++){
System.out.print(test.get(j) + " ");
}
System.out.print("类别为: ");
System.out.println(knn.knn(datas, test, 3));
}
}
}
七、实验数据与结果
训练数据:
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0
测试数据:
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5
实验结果:
测试元组:1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1.0
测试元组:1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1.0
测试元组:1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1.0
测试元组:1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0.0
测试元组:1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 类别为: 1.0
测试元组:1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 类别为: 0.0