转载请注明出处:http://blog.csdn.net/xiaojimanman/article/details/51064307
http://www.llwjy.com/blogdetail/f74b497c2ad6261b0ea651454b97a390.html
个人博客站已经上线了,网址 www.llwjy.com ~欢迎各位吐槽~
-------------------------------------------------------------------------------------------------
在开始之前先打一个小小的广告,自己创建一个QQ群:321903218,点击链接加入群【Lucene案例开发】,主要用于交流如何使用Lucene来创建站内搜索后台,同时还会不定期的在群内开相关的公开课,感兴趣的童鞋可以加入交流。
KNN算法又叫近邻算法,是数据挖掘中一种常用的分类算法,接单的介绍KNN算法的核心思想就是:寻找与目标最近的K个个体,这些样本属于类别最多的那个类别就是目标的类别。比如K为7,那么我们就从数据中找到和目标最近(或者相似度最高)的7个样本,加入这7个样本对应的类别分别为A、B、C、A、A、A、B,那么目标属于的分类就是A(因为这7个样本中属于A类别的样本个数最多)。
算法实现
一、训练数据格式定义
下面就简单的介绍下如何用JAVA来实现KNN分类,首先我们需要存储训练集(包括属性以及对应的类别),这里我们对未知的属性使用泛型,类别我们使用字符串存储。
/** *@Description: KNN分类模型中一条记录的存储格式 */ package com.lulei.datamining.knn.bean; public class KnnValueBean<T>{ private T value;//记录值 private String typeId;//分类ID public KnnValueBean(T value, String typeId) { this.value = value; this.typeId = typeId; } public T getValue() { return value; } public void setValue(T value) { this.value = value; } public String getTypeId() { return typeId; } public void setTypeId(String typeId) { this.typeId = typeId; } }
在统计得到K个最近邻中,我们需要记录前K个样本的分类以及对应的相似度,我们这里使用如下数据格式:
/** *@Description: K个最近邻的类别得分 */ package com.lulei.datamining.knn.bean; public class KnnValueSort { private String typeId;//分类ID private double score;//该分类得分 public KnnValueSort(String typeId, double score) { this.typeId = typeId; this.score = score; } public String getTypeId() { return typeId; } public void setTypeId(String typeId) { this.typeId = typeId; } public double getScore() { return score; } public void setScore(double score) { this.score = score; } }
在KNN算法中,最重要的一个指标就是K的取值,因此我们在基类中需要设置一个属性K以及设置一个数组用于存储已知分类的数据。
private List<KnnValueBean> dataArray; private int K = 3;
在使用KNN分类之前,我们需要先向其中添加我们已知分类的数据,我们后面就是使用这些数据来预测未知数据的分类。
/** * @param value * @param typeId * @Author:lulei * @Description: 向模型中添加记录 */ public void addRecord(T value, String typeId) { if (dataArray == null) { dataArray = new ArrayList<KnnValueBean>(); } dataArray.add(new KnnValueBean<T>(value, typeId)); }
在KNN算法中,最重要的一个方法就是如何确定两个样本之间的相似度(或者距离),由于这里我们使用的是泛型,并没有办法确定两个对象之间的相似度,一次这里我们把它设置为抽象方法,让子类来实现。这里我们方法定义为相似度,也就是返回值越大,两者越相似,之间的距离越短。
/** * @param o1 * @param o2 * @return * @Author:lulei * @Description: o1 o2之间的相似度 */ public abstract double similarScore(T o1, T o2);
KNN算法的核心思想就是找到最近的K个近邻,因此这一步也是整个算法的核心部分。这里我们使用数组来保存相似度最大的K个样本的分类和相似度,在计算的过程中通过循环遍历所有的样本,数组保存截至当前计算点最相似的K个样本对应的类别和相似度,具体实现如下:
/** * @param value * @return * @Author:lulei * @Description: 获取距离最近的K个分类 */ private KnnValueSort[] getKType(T value) { int k = 0; KnnValueSort[] topK = new KnnValueSort[K]; for (KnnValueBean<T> bean : dataArray) { double score = similarScore(bean.getValue(), value); if (k == 0) { //数组中的记录个数为0是直接添加 topK[k] = new KnnValueSort(bean.getTypeId(), score); k++; } else { if (!(k == K && score < topK[k -1].getScore())){ int i = 0; //找到要插入的点 for (; i < k && score < topK[i].getScore(); i++); int j = k - 1; if (k < K) { j = k; k++; } for (; j > i; j--) { topK[j] = topK[j - 1]; } topK[i] = new KnnValueSort(bean.getTypeId(), score); } } } return topK; }
这一步就是一个简单的计数,统计K个样本中出现次数最多的分类,该分类就是我们要预测的目标数据的分类。
/** * @param value * @return * @Author:lulei * @Description: KNN分类判断value的类别 */ public String getTypeId(T value) { KnnValueSort[] array = getKType(value); HashMap<String, Integer> map = new HashMap<String, Integer>(K); for (KnnValueSort bean : array) { if (bean != null) { if (map.containsKey(bean.getTypeId())) { map.put(bean.getTypeId(), map.get(bean.getTypeId()) + 1); } else { map.put(bean.getTypeId(), 1); } } } String maxTypeId = null; int maxCount = 0; Iterator<Entry<String, Integer>> iter = map.entrySet().iterator(); while (iter.hasNext()) { Entry<String, Integer> entry = iter.next(); if (maxCount < entry.getValue()) { maxCount = entry.getValue(); maxTypeId = entry.getKey(); } } return maxTypeId; }
基类源码
/** *@Description: KNN分类 */ package com.lulei.datamining.knn; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map.Entry; import com.lulei.datamining.knn.bean.KnnValueBean; import com.lulei.datamining.knn.bean.KnnValueSort; import com.lulei.util.JsonUtil; @SuppressWarnings({"rawtypes"}) public abstract class KnnClassification<T> { private List<KnnValueBean> dataArray; private int K = 3; public int getK() { return K; } public void setK(int K) { if (K < 1) { throw new IllegalArgumentException("K must greater than 0"); } this.K = K; } /** * @param value * @param typeId * @Author:lulei * @Description: 向模型中添加记录 */ public void addRecord(T value, String typeId) { if (dataArray == null) { dataArray = new ArrayList<KnnValueBean>(); } dataArray.add(new KnnValueBean<T>(value, typeId)); } /** * @param value * @return * @Author:lulei * @Description: KNN分类判断value的类别 */ public String getTypeId(T value) { KnnValueSort[] array = getKType(value); System.out.println(JsonUtil.parseJson(array)); HashMap<String, Integer> map = new HashMap<String, Integer>(K); for (KnnValueSort bean : array) { if (bean != null) { if (map.containsKey(bean.getTypeId())) { map.put(bean.getTypeId(), map.get(bean.getTypeId()) + 1); } else { map.put(bean.getTypeId(), 1); } } } String maxTypeId = null; int maxCount = 0; Iterator<Entry<String, Integer>> iter = map.entrySet().iterator(); while (iter.hasNext()) { Entry<String, Integer> entry = iter.next(); if (maxCount < entry.getValue()) { maxCount = entry.getValue(); maxTypeId = entry.getKey(); } } return maxTypeId; } /** * @param value * @return * @Author:lulei * @Description: 获取距离最近的K个分类 */ private KnnValueSort[] getKType(T value) { int k = 0; KnnValueSort[] topK = new KnnValueSort[K]; for (KnnValueBean<T> bean : dataArray) { double score = similarScore(bean.getValue(), value); if (k == 0) { //数组中的记录个数为0是直接添加 topK[k] = new KnnValueSort(bean.getTypeId(), score); k++; } else { if (!(k == K && score < topK[k -1].getScore())){ int i = 0; //找到要插入的点 for (; i < k && score < topK[i].getScore(); i++); int j = k - 1; if (k < K) { j = k; k++; } for (; j > i; j--) { topK[j] = topK[j - 1]; } topK[i] = new KnnValueSort(bean.getTypeId(), score); } } } return topK; } /** * @param o1 * @param o2 * @return * @Author:lulei * @Description: o1 o2之间的相似度 */ public abstract double similarScore(T o1, T o2); }
具体子类实现
对于上面介绍的都在KNN分类的抽象基类中,对于实际的问题我们需要继承基类并实现基类中的相似度抽象方法,这里我们做一个简单的实现。
/** *@Description: */ package com.lulei.datamining.knn.test; import com.lulei.datamining.knn.KnnClassification; import com.lulei.util.JsonUtil; public class Test extends KnnClassification<Integer>{ @Override public double similarScore(Integer o1, Integer o2) { return -1 * Math.abs(o1 - o2); } /** * @param args * @Author:lulei * @Description: */ public static void main(String[] args) { Test test = new Test(); for (int i = 1; i < 10; i++) { test.addRecord(i, i > 5 ? "0" : "1"); } System.out.println(JsonUtil.parseJson(test.getTypeId(0))); } }
-------------------------------------------------------------------------------------------------
小福利
-------------------------------------------------------------------------------------------------
个人在极客学院上《Lucene案例开发》课程已经上线了,欢迎大家吐槽~
第一课:Lucene概述
第二课:Lucene 常用功能介绍
第三课:网络爬虫
第四课:数据库连接池
第五课:小说网站的采集
第六课:小说网站数据库操作
第七课:小说网站分布式爬虫的实现
第八课:Lucene实时搜索