【十大算法实现之KNN】KNN算法实例(含测试数据和源码)

KNN算法基本的思路是比较好理解的,今天根据它的特点写了一个实例,我会把所有的数据和代码都写在下面供大家参考,不足之处,请指正。谢谢!

 

工程代码和测试数据下载:http://pan.baidu.com/s/1kThzcwR

 

几点说明:

1.KNN中的K=5;

2.在计算权重时,采用的是减去函数{1,0.8,0.6,0.4,0.2},当然你也可以采用反函数或高斯函数;

3.5%作为测试集(decision.txt),95%作为训练集(training.txt);

4.在计算costfun之前,对所有的属性进行了归一化,由于这里不知道数据集每个属性代表的含义,所以就一视同仁,实际情况下,应该具体问题具体分析;

 

image

 

 

XBWKNN.java

package XBWKNN;



import java.io.IOException;

import java.util.ArrayList;

import java.util.Collections;

import java.util.Comparator;

import java.util.List;



/**

 * KNN算法

 * @author XBW

 * @date 2014年8月16日

 */





public class XBWKNN{

    public final static int KofKNN=5;

    public final static double weight[]={1,0.9,0.7,0.4,0.1};                //减法函数y=1-0.2*x

    



    /**

     * knn

     * @param data

     * @param ds

     * @return ans

     */

    public static int knn(Data data,DataSet ds){

        int ans = 0;

        List<Data> dis=calcDis(data,ds);

        ans=calcKDis(data,dis);

        return ans;

    }

    

    /**

     * 计算训练集中所有向量的距离,排序之后取前K个

     * @param data

     * @param ds

     * @return

     */

    @SuppressWarnings("null")

    public static List<Data>calcDis(Data data,DataSet ds){

        List<Data> anslist =new ArrayList<Data>();

        double dx1=data.x1;

        double dx2=data.x2;

        double dx3=data.x3;

        for(int i=0;i<ds.ds.size();i++){

            double x1=ds.ds.get(i).x1;

            double x2=ds.ds.get(i).x2;

            double x3=ds.ds.get(i).x3;

            ds.ds.get(i).costfun=Math.sqrt((dx1-x1)*(dx1-x1)+(dx2-x2)*(dx2-x2)+(dx3-x3)*(dx3-x3));

            anslist.add(ds.ds.get(i));

        }

        Collections.sort(anslist,new Comparator<Data>(){

               public int compare(Data o1, Data o2) {

                   Double s=o1.costfun-o2.costfun;

                   if(s<0)

                       return -1;

                   else

                       return 1; 

                }

        });

        return anslist;

    }

    

    

    /**

     * 按一定的权重计算出前K个

     * @param data

     * @param ds

     * @return

     */

    public static int calcKDis(Data data,List<Data> anslist){

        Double[] anstype={0.0,0.0,0.0,0.0};

        for(int i=0;i<KofKNN;i++){

            if(anslist.get(i).type==1){

                anstype[1]+=weight[i];

            }

            else if(anslist.get(i).type==2){

                anstype[2]+=weight[i];

            }

            if(anslist.get(i).type==3){

                anstype[3]+=weight[i];

            }

        }

        Double maxt=-1.0;

        int tag=1;

        for(int i=1;i<=3;i++){

            if(maxt<anstype[i]){

                tag=i;

                maxt=anstype[i];

            }

        }

        return tag;

    }

    

    public static void main(String[] args) throws IOException{

        DataSet ds=new DataSet();

        DataTest dt=new DataTest();

        

        int correct=0;

        for(int i=0;i<dt.dt.size();i++){

            Data data=dt.dt.get(i);

            int result=knn(data,ds);

            if(result==data.type){

                correct++;

            }

        }

        System.out.println("total test num :"+dt.dt.size());

        System.out.println("correct test num :"+correct);

        System.out.println("ratio :"+correct/(double)dt.dt.size());

    }

}

 

Datatest.java

package XBWKNN;



import java.io.BufferedReader;

import java.io.File;

import java.io.FileReader;

import java.io.IOException;

import java.util.ArrayList;

import java.util.List;









/**

 * 测试数据

 * @author XBW

 * @date 2014年8月16日

 */



public class DataTest{

    String defaultpath="D:\\MachineLearning\\十大算法\\KNN\\knncode\\decision.txt";

    List<Data> dt;

    

    @SuppressWarnings("null")

    public DataTest() throws IOException{

        List<Data> dset = new ArrayList<Data>();

        File ds=new File(defaultpath);

        @SuppressWarnings("resource")

        BufferedReader br = new BufferedReader(new FileReader(ds));

        String tsing;

        double max1=-1;

        double max2=-1;

        double max3=-1;

        while((tsing=br.readLine())!=null){

            String[] dlist=tsing.split("    ");

            Data data=new Data();

            data.x1=Double.parseDouble(dlist[0]);

            data.x2=Double.parseDouble(dlist[1]);

            data.x3=Double.parseDouble(dlist[2]);

            data.type=Integer.parseInt(dlist[3]);

            dset.add(data);

            

            if(data.x1>max1){

                max1=data.x1;

            }

            if(data.x2>max2){

                max2=data.x2;

            }

            if(data.x3>max3){

                max3=data.x3;

            }

        }

        dset=normalization(dset,max1,max2,max3);

        this.dt=dset;

    }

    

    public List<Data> normalization(List<Data> dset,double m1,double m2,double m3){

        for(int i=0;i<dset.size();i++){

            dset.get(i).x1/=m1;

            dset.get(i).x2/=m2;

            dset.get(i).x3/=m3;

        }

        return dset;

    }

}

 

DataSet.java

package XBWKNN;



import java.io.BufferedReader;

import java.io.File;

import java.io.FileReader;

import java.io.IOException;

import java.util.ArrayList;

import java.util.List;









/**

 * 训练数据

 * @author XBW

 * @date 2014年8月16日

 */



public class DataSet{

    String defaultpath="D:\\MachineLearning\\十大算法\\KNN\\knncode\\training.txt";

    List<Data> ds;

    

    @SuppressWarnings("null")

    public DataSet() throws IOException{

        List<Data> dset =new ArrayList<Data>();

        File ds=new File(defaultpath);

        @SuppressWarnings("resource")

        BufferedReader br = new BufferedReader(new FileReader(ds));

        String tsing;

        double max1=-1;

        double max2=-1;

        double max3=-1;

        while((tsing=br.readLine())!=null){

            String[] dlist=tsing.split("    ");

            Data data=new Data();

            data.x1=Double.parseDouble(dlist[0]);

            data.x2=Double.parseDouble(dlist[1]);

            data.x3=Double.parseDouble(dlist[2]);

            data.type=Integer.parseInt(dlist[3]);

            dset.add(data);

            

            if(data.x1>max1){

                max1=data.x1;

            }

            if(data.x2>max2){

                max2=data.x2;

            }

            if(data.x3>max3){

                max3=data.x3;

            }

        }

        dset=normalization(dset,max1,max2,max3);

        this.ds=dset;

    }

    

    public List<Data> normalization(List<Data> dset,double m1,double m2,double m3){

        for(int i=0;i<dset.size();i++){

            dset.get(i).x1/=m1;

            dset.get(i).x2/=m2;

            dset.get(i).x3/=m3;

        }

        return dset;

    }

}

 

Data.java

package XBWKNN;



/**

 * 一条数据

 * @author XBW

 * @date 2014年8月16日

 */



public class Data{

    Double x1;

    Double x2;

    Double x3;

    Double costfun;

    int type;

}

 

output:

image

你可能感兴趣的:(算法)