knn算法java实现_KNN算法java实现代码注释

K近邻算法思想非常简单,总结起来就是根据某种距离度量检测未知数据与已知数据的距离,统计其中距离最近的k个已知数据的类别,以多数投票的形式确定未知数据的类别。

一直想自己实现knn的java实现,但限于自己的编程水平,java刚刚入门,所以就广泛搜索网上以实现的java代码来研习。下面这个简单的knn算法的java实现是在这篇博客中找到的:http://blog.csdn.net/luowen3405/article/details/6278764

下面给出我对代码的注释,如果有错误请指正。

源程序共定义了三个class文件,分别是:public class KNNNode;public class KNN;public class TestKNN。

Description:

KNNNode: KNN结点类,用来存储最近邻的k个元组相关的信息

KNN:      KNN算法主体类

TestKNN: KNN算法测试类

首先,按照程序执行顺序依次解释class的思想。

1、 TestKNN

Method: public void read()

读取文件中的数据,存储为数组的形式(以嵌套链表的形式实现)List> datas

程序主体执行:main

首先读入训练数据文件和测试数据文件的数据,然后输出测试数据的类别。此程序中K=3,根据对这个数据集的了解,k=3时效果是最好的。Knn算法k的确定一直是一个值得研究的problem。

2、 算法主体:KNN

此程序中比较一个难点是作者定义了一个大小为k优先级队列来存储k个最近邻节点。优先级队列初始默认是距离越远越优先,然后根据算法中的实现,将与测试集最近的k个节点保存下来。

3、 定义了一个数据节点数据结构:KNNNode

源码如下:

packageKNN;importjava.util.ArrayList;importjava.util.Comparator;importjava.util.HashMap;importjava.util.List;importjava.util.Map;importjava.util.PriorityQueue;/*** KNN算法主体类

*@authorRowen

* @qq 443773264

* @mail [email protected]

* @blog blog.csdn.net/luowen3405

* @data 2011.03.25*/

public classKNN {/*** 设置优先级队列的比较函数,距离越大,优先级越高*/

private Comparator comparator = new Comparator() {public intcompare(KNNNode o1, KNNNode o2) {if (o1.getDistance() >=o2.getDistance()) {return 1;

}else{return 0;

}

}

};/*** 获取K个不同的随机数

*@paramk 随机数的个数

*@parammax 随机数最大的范围

*@return生成的随机数数组*/

public List getRandKNum(int k, intmax) {

List rand = new ArrayList(k);for (int i = 0; i < k; i++) {int temp = (int) (Math.random() *max);if (!rand.contains(temp)) {

rand.add(temp);

}else{

i--;

}

}returnrand;

}/*** 计算测试元组与训练元组之前的距离

*@paramd1 测试元组

*@paramd2 训练元组

*@return距离值*/

public double calDistance(List d1, Listd2) {double distance = 0.00;for (int i = 0; i < d1.size(); i++) {

distance+= (d1.get(i) - d2.get(i)) * (d1.get(i) -d2.get(i));

}returndistance;

}/*** 执行KNN算法,获取测试元组的类别

*@paramdatas 训练数据集

*@paramtestData 测试元组

*@paramk 设定的K值

*@return测试元组的类别*/

public String knn(List> datas, List testData, intk) {

PriorityQueue pq = new PriorityQueue(k, comparator);//按照自然顺序存储容量为k的优先级队列

List randNum = getRandKNum(k, datas.size()); //建立一个列表,列表中保存的是训练数据集中实例的个数//计算当前一个测试数据实例与训练数据集的距离,并按照距离来排序

for (int i = 0; i < k; i++) {int index =randNum.get(i);

List currData =datas.get(index);

String c= currData.get(currData.size() - 1).toString();

KNNNode node= newKNNNode(index, calDistance(testData, currData), c);

pq.add(node);//System.out.println("距离"+node.getDistance()+"测试样例"+index+"k值"+k);

}//统计与测试实例距离最近的数据,然后将

for (int i = 0; i < datas.size(); i++) {

List t =datas.get(i);double distance =calDistance(testData, t);

KNNNode top=pq.peek();if (top.getDistance() >distance) {

pq.remove();

pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));

}

}returngetMostClass(pq);

}/*** 获取所得到的k个最近邻元组的多数类

*@parampq 存储k个最近近邻元组的优先级队列

*@return多数类的名称*/

private String getMostClass(PriorityQueuepq) {

Map classCount = new HashMap();for (int i = 0; i < pq.size(); i++) {

KNNNode node=pq.remove();

String c=node.getC();if(classCount.containsKey(c)) {

classCount.put(c, classCount.get(c)+ 1);

}else{

classCount.put(c,1);

}

}int maxIndex = -1;int maxCount = 0;

Object[] classes=classCount.keySet().toArray();for (int i = 0; i < classes.length; i++) {if (classCount.get(classes[i]) >maxCount) {

maxIndex=i;

maxCount=classCount.get(classes[i]);

}

}returnclasses[maxIndex].toString();

}

}

packageKNN;/*** KNN结点类,用来存储最近邻的k个元组相关的信息

*@authorRowen

* @qq 443773264

* @mail [email protected]

* @blog blog.csdn.net/luowen3405

* @data 2011.03.25*/

public classKNNNode {private int index; //元组标号

private double distance; //与测试元组的距离

private String c; //所属类别

public KNNNode(int index, doubledistance, String c) {super();this.index =index;this.distance =distance;this.c =c;

}public intgetIndex() {returnindex;

}public void setIndex(intindex) {this.index =index;

}public doublegetDistance() {returndistance;

}public void setDistance(doubledistance) {this.distance =distance;

}publicString getC() {returnc;

}public voidsetC(String c) {this.c =c;

}

}

packageKNN;importjava.io.BufferedReader;importjava.io.File;importjava.io.FileReader;importjava.util.ArrayList;importjava.util.List;/*** KNN算法测试类

*@authorRowen

* @qq 443773264

* @mail [email protected]

* @blog blog.csdn.net/luowen3405

* @data 2011.03.25*/

public classTestKNN {/*** 从数据文件中读取数据

*@paramdatas 存储数据的集合对象

*@parampath 数据文件的路径*/

public void read(List>datas, String path){try{

BufferedReader br= new BufferedReader(new FileReader(newFile(path)));

String data=br.readLine();

List l = null;while (data != null) {

String t[]= data.split(" ");

l= new ArrayList();for (int i = 0; i < t.length; i++) {

l.add(Double.parseDouble(t[i]));//System.out.println(l);

}

datas.add(l);

data=br.readLine();

}

br.close();

}catch(Exception e) {

e.printStackTrace();

}

}/*** 程序执行入口

*@paramargs*/

public static voidmain(String[] args) {

TestKNN t= newTestKNN();

String datafile= new File("").getAbsolutePath() + File.separator + "datafile";

String testfile= new File("").getAbsolutePath() + File.separator + "testfile";//System.out.println(datafile);

try{

List> datas = new ArrayList>();

List> testDatas = new ArrayList>();

t.read(datas, datafile);

t.read(testDatas, testfile);//System.out.println(datas);

KNN knn = newKNN();for (int i = 0; i < testDatas.size(); i++) {

List test =testDatas.get(i);

System.out.print("测试元组: ");for (int j = 0; j < test.size(); j++) {

System.out.print(test.get(j)+ " ");

}

System.out.print("类别为: ");

System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test,3)))));

}

}catch(Exception e) {

e.printStackTrace();

}

}

}

附上待分类数据:

文件名字:datafile

0.1887 0.3276 -1

0.8178 0.7703 1

0.6761 0.4849 -1

0.6022 0.6878 -1

0.1759 0.8217 -1

0.2607 0.3502 1

0.2875 0.6713 -1

0.916 0.7363 -1

0.1615 0.2564 1

0.2653 0.9452 1

0.0911 0.4386 -1

0.0012 0.3947 -1

0.4253 0.8419 1

0.0067 0.4424 -1

0.8244 0.2089 1

0.3868 0.3592 -1

0.9174 0.216 -1

0.6074 0.3968 -1

0.068 0.5201 -1

0.9686 0.9937 1

0.0908 0.3658 1

0.3411 0.7691 -1

0.4609 0.4423 -1

0.1078 0.4501 1

0.3445 0.0445 -1

0.9827 0.7093 1

0.2428 0.3774 -1

0.0358 0.1971 -1

0.82 0.721 1

0.6718 0.6714 -1

0.6753 0.2428 -1

0.7218 0.4299 -1

0.3127 0.8329 1

0.0225 0.4162 1

0.5313 0.2187 1

0.7847 0.4243 -1

0.2518 0.6476 1

0.4076 0.5439 1

0.9063 0.4587 1

0.4714 0.2703 -1

0.7702 0.0196 -1

0.2548 0.3477 -1

0.0942 0.5407 1

0.1917 0.8085 -1

0.6834 0.7689 -1

0.1056 0.1097 1

0.9577 0.5303 -1

0.9436 0.0938 -1

0.6959 0.3181 1

0.4235 0.4484 1

0.6171 0.6358 1

0.5309 0.5447 1

0.8444 0.2621 -1

0.5762 0.8335 -1

0.281 0.772 1

0.224 0.15 -1

0.4243 0.704 -1

0.7384 0.7551 -1

0.4401 0.9329 1

0.2665 0.7635 1

0.5944 0.662 1

0.3225 0.3309 -1

0.4709 0.2648 1

0.6444 0.9899 -1

0.5271 0.9727 1

0.7788 0.4046 1

0.7302 0.2362 1

0.5181 0.6963 -1

0.5841 0.6073 1

0.7184 0.5225 1

0.6999 0.1192 1

0.3439 0.1194 1

0.6951 0.7413 -1

0.611 0.0636 1

0.4229 0.5822 1

0.4735 0.8878 -1

0.2891 0.3935 -1

0.3196 0.6393 1

0.1527 0.3912 -1

0.6385 0.9398 1

0.2904 0.679 1

0.4574 0.192 1

0.3251 0.1058 1

0.6377 0.5254 -1

0.5985 0.8699 1

0.4257 0.862 -1

0.2691 0.7904 -1

0.8754 0.1389 1

0.0336 0.6456 1

0.6544 0.6473 1

文件名称:testfile

0.9516 0.0326

0.9203 0.5612

0.0527 0.8819

0.7379 0.6692

0.2691 0.1904

0.4228 0.3689

0.5479 0.4607

0.9427 0.9816

0.4177 0.1564

0.9831 0.8555

0.3015 0.6448

0.7011 0.3763

0.6663 0.1909

0.5391 0.4283

0.6981 0.4820

0.6665 0.1206

0.1781 0.5895

0.1280 0.2262

0.9991 0.3846

0.1711 0.5830

通过KNN算法对未知数据集分类,设置k=3,分类结果如下:

测试元组: 0.9516 0.0326 类别为: -1测试元组:0.9203 0.5612 类别为: -1测试元组:0.0527 0.8819 类别为: -1测试元组:0.7379 0.6692 类别为: -1测试元组:0.2691 0.1904 类别为: -1测试元组:0.4228 0.3689 类别为: -1测试元组:0.5479 0.4607 类别为: -1测试元组:0.9427 0.9816 类别为: 1测试元组:0.4177 0.1564 类别为: 1测试元组:0.9831 0.8555 类别为: -1测试元组:0.3015 0.6448 类别为: -1测试元组:0.7011 0.3763 类别为: -1测试元组:0.6663 0.1909 类别为: -1测试元组:0.5391 0.4283 类别为: -1测试元组:0.6981 0.482 类别为: -1测试元组:0.6665 0.1206 类别为: 1测试元组:0.1781 0.5895 类别为: 1测试元组:0.128 0.2262 类别为: 1测试元组:0.9991 0.3846 类别为: -1测试元组:0.1711 0.583 类别为: 1

你可能感兴趣的:(knn算法java实现)