KNN分类器-Java实现

KNN,即K近邻算法。其基本思想或者说是实现步骤如下: 

(1)计算样本数据点到每个已知类别的数据集中点的距离 

(2)将(1)中得到的距离按递增顺序排列 

(3)选取(2)中前K个点(即与当前样本距离最小的K个已知类别的数据点) 

(4)统计(3)中得到的K个点所在类别的出现频率 

(5)返回(4)中出现频率最高的类别作为样本点的预测类别 

在给出具体实现代码之前,说明一点:Java下的矩阵操作类基于开源jama包,我自己基于它的源码,做了部分必要的扩充和修改。 

具体实现代码如下: 

 /** 

 * Created by Song on 2016/9/30. 

*/

 public class KnnHandler implements DMHandler { 

 //训练集中,每个特征的最小值 

 private Matrix minVals; 

 //训练集中,每个特征的最大值 

 private Matrix maxVals; 

 //训练集中,每个特征的取值范围 

 private Matrix ranges; 


public KnnHandler(Matrix dataSet){ 

      double [][] minMax = dataSet.getMinMax(); 

      this.minVals = new Matrix(minMax[0],1); 

      this.maxVals = new Matrix(minMax[1],1); 

      this.ranges = maxVals.minus(minVals); 

 } 

 /**

 * 归一化特征值 

 * @param dataSet 特征集 

 */ 

  public Matrix autoNorm(Matrix dataSet){ 

       double[][] norm = dataSet.getArray(); 

       for(int j=0;j

            for(int i=0;i

                  norm[i][j] = (norm[i][j]-minVals.get(0,j))/ranges.get(0,j); 

            } 

       } 

       return new Matrix(norm); 

 } 

 /** 

 * K近邻算法 

 * @param sample 待评估样本 

 * @param dataSet 数据集 

 * @param labels 数据集中,每行数据对应的类别 

 * @param rate 将距离按由小至大排列,按比例选择固定数量的类别 

 */ 

 public double classify(Matrix sample,Matrix dataSet,Matrix labels,double rate){ 

       //统计样本频率 

      HashMap levels = new HashMap(); 

      //遍历类别,得出一共有几类 

     for(int i=0;i

           if(!levels.containsKey(labels.get(i,0))) levels.put(labels.get(i,0),0); 

     } 

     //获得距离,并递增排序 

    Matrix sortedDistance = sample.distance(dataSet).expand(labels,true).sort(); 

    //取前num个数据 

    int num = (int)Math.ceil(sortedDistance.getRowDimension()*rate); 

    for(int i=0;i

 } 

 //按频率排序 

 double targetLevel = 0; 

 int count = 0; 

 for(double key:levels.keySet()){ 

       if(levels.get(key)>count) { 

              count = levels.get(key); 

              targetLevel = key; 

           } 

 } 

 return targetLevel; 

 } 

 //测试

public static void main(String [] args){ 

//随机生成训练集(已知类别) 

Random random = new Random(); 

 double [][] dataSet = new double[100][4]; 

 for(int i=0;i<100;i++){ 

       for(int j=0;j<4;j++){ 

             dataSet[i][j]=random.nextInt(10); 

        } 

 } 

 //训练集中100组数据对应的类别 

 double [] lables = new double[100]; 

 for(int i=0;i<100;i++){ 

           lables[i]=i/10; 

 } 

 //生成待分类样本 

 double [] sample = {1,2,3,4}; 

 //KNN操作类实例化 

 KnnHandler handler = new KnnHandler(new Matrix(dataSet)); //handler.autoNorm(new Matrix(dataSet)).print(4,3); 

 //输出分类结果 

 System.out.println(handler.classify(new Matrix(sample,1),new Matrix(dataSet),new Matrix(lables,1).transpose(),0.3)); 

    } 

其中部分函数,例如构造器中获得数据集中每个特征的最小最大取值(即一个二维数组中每列值的最小最大值)方法getMinMax()等,都是自己基于jama源码扩充得到的,原理很简单,此处就不列出来了。 可以看出,KNN分类是一种非常基础的分类算法,适用于数值型数据。通过计算未知数据点到已知数据点的距离,来判断其具体分类。 


转载于:https://juejin.im/post/5cef4f326fb9a07ee30dfdad

你可能感兴趣的:(KNN分类器-Java实现)