机器学习实战学习笔记-KNN算法

1.KNN算法介绍。

KNN算法即k~近邻算法,通过计算测试数据与已知分类的样本数据集的相似度,选择相似度最高的前k条数据。统计k个数据中分类出现最高的分类,做为测试数据的分类。


2.算法特点

优点:精度高、对异常值不敏感 缺点:时间复杂度和空间复杂度高   适用数据:数据型和标称型


下面的相似度计算采用欧式距离:

两个n维向量x(x1,x2,...,xn),y(y1,y2,...,yn)的欧式距离公式 



用Java实现KNN算法

首先定义用于存储数据的结点Node类,包含三个属性,1.样本数据属性 2.距离 3.分类类型。Node实现了Comparable接口,方便后续的求最相似的k条数据记录;给出了calDistance方法,用于计算结点间的距离。

package KNN;

import java.util.List; 
public class Node implements Comparable { 
	private List data;
	private double distance;
	private int type;
	public Node() {
		super();
		// TODO Auto-generated constructor stub
	}
	public Node(List data, double distance, int type) {
		super();
		this.data = data;
		this.distance = distance;
		this.type = type;
	}
	public List getData() {
		return data;
	}
	public void setData(List data) {
		this.data = data;
	}
	public double getDistance() {
		return distance;
	}
	public void setDistance(double distance) {
		this.distance = distance;
	}
	public int getType() {
		return type;
	}
	public void setType(int type) {
		this.type = type;
	}
	public int compareTo(Node arg0) {
		if(this.distance>=arg0.getDistance())
			return 1;
		else
			return -1; 
	}
	
	public double calDistance(Node arg0){ 
		distance=0;
		for(int i=0;i


KNN类的实现:

 

package KNN;
  
import java.io.BufferedReader;
import java.io.FileInputStream; 
import java.io.FileNotFoundException;
import java.io.IOException; 
import java.io.InputStreamReader; 
import java.util.ArrayList;
import java.util.Collections; 
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
 
public class KNN {
	/*
	 * 参考数据集
	 */
	private List dataSet=null;
	
	/*
	 * 特征列最小值
	 */
	private double[] min; 
	
	/*
	 * 特征列取值范围
	 */
	private double[] range;
	
	/*
	 * 特征属性个数
	 */
	private int featNum;
	
	/*
	 * 样本个数
	 */
	private int sampNum;
	
	/*
	 * 获取前K个
	 */
	private int k;
	
	//通过给定的样本数据集路径初始化KNN分类器
	public KNN(String fileName,int k){
		dataSet=getDataSet(fileName);
		searchMaxMin(); 
		autoNorm(dataSet);
		this.k=k;
		if(k<=0 || k>sampNum)
			k=(int) (sampNum*0.5);

	}
	
	/*
	 * 获取参考数据集
	 */
	@SuppressWarnings("resource")
	public List getDataSet(String fileName){
		List list=new ArrayList();
		BufferedReader br;
		try {
			br = new BufferedReader(new InputStreamReader(new FileInputStream(fileName)));
			String line=null;
			Node node=null;
			List data=null;
			while((line=br.readLine())!=null){
				node=new Node();
				data=new ArrayList();
				String[] s=line.split("\t");  
				for(int i=0;idataSet.get(i).getData().get(j))
					min[j]=dataSet.get(i).getData().get(j);
				if(max[j] autoNorm(List l){ 
		for(int i=0;i testData){
		for(int i=0;i map=new HashMap(); 
		ValueComparetor vc=new ValueComparetor(map);
		TreeMap tm=new TreeMap(vc); 
		for(int i=0;i{
		Map map;
		public ValueComparetor(Map map ){
			this.map=map;
		}

		public int compare(Integer arg0, Integer arg1) {
			if(map.get(arg0)>=map.get(arg1))
				return -1;
			else
				return 1;
		} 
	}
	
 
	/*
	 * 对给定的测试文件进行分类测试
	 */
	public void classify(String testFile){
		List test=getDataSet(testFile); 
		test=autoNorm(test);
		calDistance(test);
		List oldData=getDataSet(testFile);
		int errorCount=0;
		for(int i=0;i

    KNN类属性包含有数据集dataSet,数据集属性列最小值min,数据集属性列的取值范围range,数据集属性个数featNum,数据集数据条数sanpNum,k值。函数getDataSet()读取指定路径的文件,录入数据。对于分类数据,不同属性的取值范围会不同,通过原数据直接求欧式距离,属性取值范围大的对计算结果的影响大。在权重未知的情况下,需要对数据做归一化处理,是各属性权重相同。函数autoNorm,通过:

                       newValue = {oldValue-min)/ (max-min) 

公式,将属性映射到(0,1)。函数searchMaxMin()将遍历数据集,找出每一属性的最大值和最小值,并求出各属性值的取值范围。函数calDistance(),计算一条测试数据与样本数据集的每条数据的距离,递增排序。通过getMost()函数,查找出前k条数据集中出现次数最多的分类。classify()函数读取一个文件的数据作为测试数据进行测试,输出分类结果,并统计错误分类的个数,输出错误率。


编写测试类。

 

package KNN;

public class TestKNN {
	public static void main(String[] args) {
		String sampFile="E:\\Java_Project\\DeepLearning\\src\\KNN\\datingTestSet2.txt";
		String testFile="E:\\Java_Project\\DeepLearning\\src\\KNN\\test.txt";
		KNN knn=new KNN(sampFile,4);
		knn.classify(testFile);
		
	}

}

数据格式如下,最后一列为分类:

 

40920	8.326976	0.953952	3
14488	7.153469	1.673904	2
26052	1.441871	0.805124	1
75136	13.147394	0.428964	1
38344	1.669788	0.134296	1
72993	10.141740	1.032955	1
35948	6.830792	1.213192	3
42666	13.276369	0.543880	3
67497	8.631577	0.749278	1

总结:K近邻算法是最简单有效的分类算法,但是分类时需要计算每条测试数据与所有样本数据的相似度,对于样本数据量大的情况,需要大量存储空间并且耗时。


你可能感兴趣的:(机器学习笔记-Java实现)