JAVA实现KNN分类

转载请注明出处:http://blog.csdn.net/xiaojimanman/article/details/51064307

http://www.llwjy.com/blogdetail/f74b497c2ad6261b0ea651454b97a390.html

个人博客站已经上线了,网址 www.llwjy.com ~欢迎各位吐槽~

-------------------------------------------------------------------------------------------------

      在开始之前先打一个小小的广告,自己创建一个QQ群:321903218,点击链接加入群【Lucene案例开发】,主要用于交流如何使用Lucene来创建站内搜索后台,同时还会不定期的在群内开相关的公开课,感兴趣的童鞋可以加入交流。


      KNN算法又叫近邻算法,是数据挖掘中一种常用的分类算法,接单的介绍KNN算法的核心思想就是:寻找与目标最近的K个个体,这些样本属于类别最多的那个类别就是目标的类别。比如K为7,那么我们就从数据中找到和目标最近(或者相似度最高)的7个样本,加入这7个样本对应的类别分别为A、B、C、A、A、A、B,那么目标属于的分类就是A(因为这7个样本中属于A类别的样本个数最多)。


算法实现

一、训练数据格式定义

      下面就简单的介绍下如何用JAVA来实现KNN分类,首先我们需要存储训练集(包括属性以及对应的类别),这里我们对未知的属性使用泛型,类别我们使用字符串存储。

 /**  
 *@Description:  KNN分类模型中一条记录的存储格式
 */ 
package com.lulei.datamining.knn.bean;  
  
public class KnnValueBean<T>{
	private T value;//记录值
	private String typeId;//分类ID
	
	public KnnValueBean(T value, String typeId) {
		this.value = value;
		this.typeId = typeId;
	}

	public T getValue() {
		return value;
	}

	public void setValue(T value) {
		this.value = value;
	}

	public String getTypeId() {
		return typeId;
	}

	public void setTypeId(String typeId) {
		this.typeId = typeId;
	}
}

二、K个最近邻类别数据格式定义

      在统计得到K个最近邻中,我们需要记录前K个样本的分类以及对应的相似度,我们这里使用如下数据格式:

 /**  
 *@Description: K个最近邻的类别得分
 */ 
package com.lulei.datamining.knn.bean;  
  
public class KnnValueSort {
	private String typeId;//分类ID
	private double score;//该分类得分
	
	public KnnValueSort(String typeId, double score) {
		this.typeId = typeId;
		this.score = score;
	}
	public String getTypeId() {
		return typeId;
	}
	public void setTypeId(String typeId) {
		this.typeId = typeId;
	}
	public double getScore() {
		return score;
	}
	public void setScore(double score) {
		this.score = score;
	}
}

三、KNN算法基本属性

      在KNN算法中,最重要的一个指标就是K的取值,因此我们在基类中需要设置一个属性K以及设置一个数组用于存储已知分类的数据。

private List<KnnValueBean> dataArray;
private int K = 3;

四、添加已知分类数据

      在使用KNN分类之前,我们需要先向其中添加我们已知分类的数据,我们后面就是使用这些数据来预测未知数据的分类。

/**
 * @param value
 * @param typeId
 * @Author:lulei  
 * @Description: 向模型中添加记录
 */
public void addRecord(T value, String typeId) {
	if (dataArray == null) {
		dataArray = new ArrayList<KnnValueBean>();
	}
	dataArray.add(new KnnValueBean<T>(value, typeId));
}

五、两个样本之间的相似度(或者距离)

      在KNN算法中,最重要的一个方法就是如何确定两个样本之间的相似度(或者距离),由于这里我们使用的是泛型,并没有办法确定两个对象之间的相似度,一次这里我们把它设置为抽象方法,让子类来实现。这里我们方法定义为相似度,也就是返回值越大,两者越相似,之间的距离越短

/**
 * @param o1
 * @param o2
 * @return
 * @Author:lulei  
 * @Description: o1 o2之间的相似度
 */
public abstract double similarScore(T o1, T o2);

六、获取最近的K个样本的分类

      KNN算法的核心思想就是找到最近的K个近邻,因此这一步也是整个算法的核心部分。这里我们使用数组来保存相似度最大的K个样本的分类和相似度,在计算的过程中通过循环遍历所有的样本,数组保存截至当前计算点最相似的K个样本对应的类别和相似度,具体实现如下:

/**
 * @param value
 * @return
 * @Author:lulei  
 * @Description: 获取距离最近的K个分类
 */
private KnnValueSort[] getKType(T value) {
	int k = 0;
	KnnValueSort[] topK = new KnnValueSort[K];
	for (KnnValueBean<T> bean : dataArray) {
		double score = similarScore(bean.getValue(), value);
		if (k == 0) {
			//数组中的记录个数为0是直接添加
			topK[k] = new KnnValueSort(bean.getTypeId(), score);
			k++;
		} else {
			if (!(k == K && score < topK[k -1].getScore())){
				int i = 0;
				//找到要插入的点
				for (; i < k && score < topK[i].getScore(); i++);
				int j = k - 1;
				if (k < K) {
					j = k;
					k++;
				}
				for (; j > i; j--) {
					topK[j] = topK[j - 1];
				}
				topK[i] = new KnnValueSort(bean.getTypeId(), score);
			}
		}
	}
	return topK;
}

七、统计K个样本出现次数最多的类别

      这一步就是一个简单的计数,统计K个样本中出现次数最多的分类,该分类就是我们要预测的目标数据的分类。

/**
 * @param value
 * @return
 * @Author:lulei  
 * @Description: KNN分类判断value的类别
 */
public String getTypeId(T value) {
	KnnValueSort[] array = getKType(value);
	HashMap<String, Integer> map = new HashMap<String, Integer>(K);
	for (KnnValueSort bean : array) {
		if (bean != null) {
			if (map.containsKey(bean.getTypeId())) {
				map.put(bean.getTypeId(), map.get(bean.getTypeId()) + 1);
			} else {
				map.put(bean.getTypeId(), 1);
			}
		}
	}
	String maxTypeId = null;
	int maxCount = 0;
	Iterator<Entry<String, Integer>> iter = map.entrySet().iterator();
	while (iter.hasNext()) {
		Entry<String, Integer> entry = iter.next();
		if (maxCount < entry.getValue()) {
			maxCount = entry.getValue();
			maxTypeId = entry.getKey();
		}
	}
	return maxTypeId;
}

      到现在为止KNN分类的抽象基类已经编写完成,在测试之前我们先多说几句,KNN分类是统计K个样本中出现次数最多的分类,这种在有些情况下并不是特别合理,比如K=5,前5个样本对应的分类分别为A、A、B、B、B,对应的相似度得分分别为10、9、2、2、1,如果使用上面的方法,那预测的分类就是B,但是看这些数据,预测的分类是A感觉更合理。基于这种情况,自己对KNN算法提出如下优化(这里并不提供代码,只提供简单的思路):在获取最相似K个样本和相似度后,可以对相似度和出现次数K做一种函数运算,比如加权,得到的函数值最大的分类就是目标的预测分类。

基类源码

 /**  
 *@Description: KNN分类
 */ 
package com.lulei.datamining.knn;  

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;

import com.lulei.datamining.knn.bean.KnnValueBean;
import com.lulei.datamining.knn.bean.KnnValueSort;
import com.lulei.util.JsonUtil;
  
@SuppressWarnings({"rawtypes"})
public abstract class KnnClassification<T> {
	private List<KnnValueBean> dataArray;
	private int K = 3;
	
	public int getK() {
		return K;
	}
	public void setK(int K) {
		if (K < 1) {
			throw new IllegalArgumentException("K must greater than 0");
		}
		this.K = K;
	}

	/**
	 * @param value
	 * @param typeId
	 * @Author:lulei  
	 * @Description: 向模型中添加记录
	 */
	public void addRecord(T value, String typeId) {
		if (dataArray == null) {
			dataArray = new ArrayList<KnnValueBean>();
		}
		dataArray.add(new KnnValueBean<T>(value, typeId));
	}
	
	/**
	 * @param value
	 * @return
	 * @Author:lulei  
	 * @Description: KNN分类判断value的类别
	 */
	public String getTypeId(T value) {
		KnnValueSort[] array = getKType(value);
		System.out.println(JsonUtil.parseJson(array));
		HashMap<String, Integer> map = new HashMap<String, Integer>(K);
		for (KnnValueSort bean : array) {
			if (bean != null) {
				if (map.containsKey(bean.getTypeId())) {
					map.put(bean.getTypeId(), map.get(bean.getTypeId()) + 1);
				} else {
					map.put(bean.getTypeId(), 1);
				}
			}
		}
		String maxTypeId = null;
		int maxCount = 0;
		Iterator<Entry<String, Integer>> iter = map.entrySet().iterator();
		while (iter.hasNext()) {
			Entry<String, Integer> entry = iter.next();
			if (maxCount < entry.getValue()) {
				maxCount = entry.getValue();
				maxTypeId = entry.getKey();
			}
		}
		return maxTypeId;
	}
	
	/**
	 * @param value
	 * @return
	 * @Author:lulei  
	 * @Description: 获取距离最近的K个分类
	 */
	private KnnValueSort[] getKType(T value) {
		int k = 0;
		KnnValueSort[] topK = new KnnValueSort[K];
		for (KnnValueBean<T> bean : dataArray) {
			double score = similarScore(bean.getValue(), value);
			if (k == 0) {
				//数组中的记录个数为0是直接添加
				topK[k] = new KnnValueSort(bean.getTypeId(), score);
				k++;
			} else {
				if (!(k == K && score < topK[k -1].getScore())){
					int i = 0;
					//找到要插入的点
					for (; i < k && score < topK[i].getScore(); i++);
					int j = k - 1;
					if (k < K) {
						j = k;
						k++;
					}
					for (; j > i; j--) {
						topK[j] = topK[j - 1];
					}
					topK[i] = new KnnValueSort(bean.getTypeId(), score);
				}
			}
		}
		return topK;
	}
	
	/**
	 * @param o1
	 * @param o2
	 * @return
	 * @Author:lulei  
	 * @Description: o1 o2之间的相似度
	 */
	public abstract double similarScore(T o1, T o2);
}

具体子类实现

      对于上面介绍的都在KNN分类的抽象基类中,对于实际的问题我们需要继承基类并实现基类中的相似度抽象方法,这里我们做一个简单的实现。

 /**  
 *@Description:     
 */ 
package com.lulei.datamining.knn.test;  

import com.lulei.datamining.knn.KnnClassification;
import com.lulei.util.JsonUtil;
  
public class Test extends KnnClassification<Integer>{
	
	@Override
	public double similarScore(Integer o1, Integer o2) {
		return -1 * Math.abs(o1 - o2);
	}
	
	/**  
	 * @param args
	 * @Author:lulei  
	 * @Description:  
	 */
	public static void main(String[] args) {
		Test test = new Test();
		for (int i = 1; i < 10; i++) {
			test.addRecord(i, i > 5 ? "0" : "1");
		}
		System.out.println(JsonUtil.parseJson(test.getTypeId(0)));
		
	}
}


      这里我们一共添加了1、2、3、4、5、6、7、8、9这9组数据,前5组的类别为1,后4组的类别为0,两个数据之间的相似度为两者之间的差值的绝对值的相反数,下面预测0应该属于的分类,这里K的默认值为3,因此最近的K个样本分别为1、2、3,对应的分类分别为"1"、"1"、"1",因为最后预测的分类为"1"。

-------------------------------------------------------------------------------------------------
小福利
-------------------------------------------------------------------------------------------------
      个人在极客学院上《Lucene案例开发》课程已经上线了,欢迎大家吐槽~

第一课:Lucene概述

第二课:Lucene 常用功能介绍

第三课:网络爬虫

第四课:数据库连接池

第五课:小说网站的采集

第六课:小说网站数据库操作

第七课:小说网站分布式爬虫的实现

第八课:Lucene实时搜索


你可能感兴趣的:(java,算法,分类,knn)