1.KNN算法原理:
(1)基于类比原理,通过比较训练元组和测试元组的相似度来学习
(2)将训练元组和测试元组看作是n维空间内的点,给定一条测试元组,搜索n维空间,找出与测试元组最相近的k个点(即训练元组),最后取这k个点中的多数类做作为测试元组 的类别。
(3)相近的度量方法:用空间内两个点的距离来度量,距离越大,表示两个点越不相似。
(4)距离的选择:可采用欧几里得距离,曼哈顿距离,等其他度量方法,一般采用欧几里得距离,比较简单。
2.KNN算法中的细节处理
(1)数值属性规范化:将数值属性规范到0-1区间以便于计算,也可防止大数值型属性对分类的主导。
(2)可选的方法有:v‘=(v-vmin)/(vmax-vmin),也有其他方法。
(3)比较的数据是分类类型而不是数值类型,同则差为0,异则差为1,有时候可以做更精确的处理,比如黑色和白色的差肯定大于灰色和白色的差。
(4)缺失值的处理(不理解,以后理解):取最大的可能差,对于分类属性,如果属性A的一个或两个对应值丢失,则取差值为1;如果A是数值属性,若两个比较的元组A属性值均缺失,则取差值为1,若只有一个缺失,另一个值为v,则取差值为|1-v|和|0-v|中的最大值。
(5)确定K的值:通过实验确定。进行若干次实验,取分类误差率最小的k值。
(7)对噪声数据或不相关属性的处理:对属性赋予相关性权重w,w越大说明属性对分类的影响越相关。对噪声数据可以将所在的元组直接cut掉。
3.KNN算法流程
1)准备数据,对数据进行预处理
2)选用合适的数据结构存储训练数据和测试元组
3)设定参数,如k
4)维护一个大小为k的的按距离由大到小的优先级队列,用于存储最近邻训练元组。随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测试元组到这k个元 组的距离,将训练元组标号和距离存入优先级队列
5)遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离L与优先级队列中的最大距离Lmax进行比较。若L>=Lmax,则舍弃该元组,遍历下一个元组。若L < Lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队列。
6)遍历完毕,计算优先级队列中k个元组的多数类,并将其作为测试元组的类别。
7)测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k值。
4.Java实现
KNN结点类,用来存储最近邻的k个元组相关的信息
/**
* KNN结点类,用来存储最近邻的k个元组相关的信息
*/
public class KNNNode {
private int index; // 元组标号
private double distance; // 与测试元组的距离
private String c; // 所属类别
public KNNNode(int index, double distance, String c) {
super();
this.index = index;
this.distance = distance;
this.c = c;
}
public int getIndex() {
return index;
}
public void setIndex(int index) {
this.index = index;
}
public double getDistance() {
return distance;
}
public void setDistance(double distance) {
this.distance = distance;
}
public String getC() {
return c;
}
public void setC(String c) {
this.c = c;
}
}
KNN算法主体类
/**
* KNN算法主体类
*/
public class KNN {
/**
* 设置优先级队列的比较函数,距离越大,优先级越高
*/
private Comparator comparator = new Comparator() {
public int compare(KNNNode o1, KNNNode o2) {
if (o1.getDistance() >= o2.getDistance()) {
return 1;
} else {
return 0;
}
}
};
/**
* 获取K个不同的随机数
* @param k 随机数的个数
* @param max 随机数最大的范围
* @return 生成的随机数数组
*/
public List getRandKNum(int k, int max) {
List rand = new ArrayList(k);
for (int i = 0; i < k; i++) {
int temp = (int) (Math.random() * max);
if (!rand.contains(temp)) {
rand.add(temp);
} else {
i--;
}
}
return rand;
}
/**
* 计算测试元组与训练元组之前的距离
* @param d1 测试元组
* @param d2 训练元组
* @return 距离值
*/
public double calDistance(List d1, List d2) {
double distance = 0.00;
for (int i = 0; i < d1.size(); i++) {
distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));
}
return distance;
}
/**
* 执行KNN算法,获取测试元组的类别
* @param datas 训练数据集
* @param testData 测试元组
* @param k 设定的K值
* @return 测试元组的类别
*/
public String knn(List> datas, List testData, int k) {
PriorityQueue pq = new PriorityQueue(k, comparator);
List randNum = getRandKNum(k, datas.size());
for (int i = 0; i < k; i++) {
int index = randNum.get(i);
List currData = datas.get(index);
String c = currData.get(currData.size() - 1).toString();
KNNNode node = new KNNNode(index, calDistance(testData, currData), c);
pq.add(node);
}
for (int i = 0; i < datas.size(); i++) {
List t = datas.get(i);
double distance = calDistance(testData, t);
KNNNode top = pq.peek();
if (top.getDistance() > distance) {
pq.remove();
pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));
}
}
return getMostClass(pq);
}
/**
* 获取所得到的k个最近邻元组的多数类
* @param pq 存储k个最近近邻元组的优先级队列
* @return 多数类的名称
*/
private String getMostClass(PriorityQueue pq) {
Map classCount = new HashMap();
for (int i = 0; i < pq.size(); i++) {
KNNNode node = pq.remove();
String c = node.getC();
if (classCount.containsKey(c)) {
classCount.put(c, classCount.get(c) + 1);
} else {
classCount.put(c, 1);
}
}
int maxIndex = -1;
int maxCount = 0;
Object[] classes = classCount.keySet().toArray();
for (int i = 0; i < classes.length; i++) {
if (classCount.get(classes[i]) > maxCount) {
maxIndex = i;
maxCount = classCount.get(classes[i]);
}
}
return classes[maxIndex].toString();
}
}
KNN算法测试类
package feature;
import java.io.BufferedReader;
import java.io.*;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;
// KNN算法测试类
public class TestKNN {
/**
* 从数据文件中读取数据
*
* @param datas
* 存储数据的集合对象
* @param path
* 数据文件的路径
*/
public void read(List> datas, String path) {
try {
BufferedReader br = new BufferedReader(new FileReader(
new File(path)));
String reader = br.readLine();
while (reader != null) {
String t[] = reader.split("\t");
ArrayList list = new ArrayList();
for (int i = 0; i < t.length; i++) {
list.add(Double.parseDouble(t[i]));
}
datas.add(list);
reader = br.readLine();
}
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* 程序执行入口
*
* @param args
*/
public static void main(String[] args) {
TestKNN t = new TestKNN();
String datafile="e:\\work\\交接数据\\testData\\KNN\\datafile.txt";
String testfile="e:\\work\\交接数据\\testData\\KNN\\testfile.txt";
try {
List> datas = new ArrayList>();
List> testDatas = new ArrayList>();
t.read(datas, datafile);
t.read(testDatas, testfile);
KNN knn = new KNN();
FileWriter fw=new FileWriter("e:\\work\\交接数据\\testData\\KNN\\result.txt");
for (int i = 0; i < testDatas.size(); i++) {
List test = testDatas.get(i);
System.out.print("测试元组: ");
for (int j = 0; j < test.size(); j++) {
fw.write(test.get(j) + "\t");
fw.flush();
System.out.print(test.get(j) + "\t");
}
System.out.print("类别为: ");
fw.write(Math.round(Float.parseFloat((knn.knn(datas,
test, 3))))+"\n");
fw.flush();
System.out.println(Math.round(Float.parseFloat((knn.knn(datas,
test, 3)))));
}
} catch (Exception e) {
e.printStackTrace();
}
}
}
训练数据:
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0
测试数据:
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5
结果是:
测试元组: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1
测试元组: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1
测试元组: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1
测试元组: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0
测试元组: 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 类别为: 1
测试元组: 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 类别为: 0
5.错误及解决办法
在测试KNN算法时出现错误
java.lang.NumberFormatException: For input string: "1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0"
at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:1224)
at java.lang.Double.parseDouble(Double.java:510)
at feature.TestKNN.read(TestKNN.java:29)
at feature.TestKNN.main(TestKNN.java:53)
java.lang.NumberFormatException: For input string: "1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5"
at sun.misc.FloatingDecimal.readJavaFormatString(FloatingDecimal.java:1224)
at java.lang.Double.parseDouble(Double.java:510)
at feature.TestKNN.read(TestKNN.java:29)
at feature.TestKNN.main(TestKNN.java:54)
错误原因是发现read方法中出错了
public void read(List> datas, String path) {
try {
BufferedReader br = new BufferedReader(new FileReader(
new File(path)));
String reader = br.readLine();
while (reader != null) {
String t[] = reader.split("\t");
ArrayList list = new ArrayList();
for (int i = 0; i < t.length; i++) {
list.add(Double.parseDouble(t[i]));
}
datas.add(list);
reader = br.readLine();
}
} catch (Exception e) {
e.printStackTrace();
}
}
出错行是list.add(Double.parseDouble(t[i]));
取最大的可能差,对于分类属性,如果属性
A
的一个或两个对应值丢失,则取差值为
1
;
如果
A
是数值属性,若两个比较的元组
A
属性值均缺失,则取差值为
1
,若只有一个缺失,另一个值为
v
,
则取差值为|
1-v
|和|
0-v
|中的最大值