机器学习入门算法及其java实现-KNN算法

1、算法基本原理:

  • 对于一个新点 X0(x0,y0) ,它的分类 y0 由离它最近的k个点的类别决定;
  • 其中训练集为 T{(x1,y1),(x2,y2),...,(xn,yn)} ,离 X0(x0,y0) 最近的K个点根据分类决策规则(如多数表决)决定 X0(x0,y0) 的类别 y0 :
    y0=argmaxξjxjNk(x)I(yi=cj),i=1,2,...,N;j=1,2,...,K

    2、距离度量:
    特征空间中两个实例点的距离是两个实例点相似程度的反映。K近邻模型的特征空间一般是n维实数向量空间 Rn 。使用的距离是欧氏距离或其他距离,如更一般的 Lp 距离和Minkowski距离。
    设特征空间 χ 是n维实数向量空间 Rn , xi , xj χ , xi=(x(1)i,x(2)i,...,x(n)i)T , xj=(x(1)j,x(2)j,...,x(n)j)T , xi , xj Lp 的距离定义为:
    Lp(xi,xj)=(l=1n|xlixlj|p)1p
    这里 p1 ,当 p=2 时,称为欧式距离,即
    L2(xi,xj)=(l=1n|xlixlj|2)
    p=1 ,称为曼哈顿距离,即:
    L1(xi,xj)=l=1n|xlixlj|2

    p= 时,它时各个坐标距离的最大值,即
    L(xi,xj)=maxl|xlixli|

    3、K值的选择:
    K值的选择会对K近邻法产生重大的影响。
    如果选择较小的K值,相当于用较小的领域中的训练实例进行预测,“学习”的近似误差会减小,但学习的估计误差会增大,预测结果会对近邻的实例点非常敏感。
    如果选择较大的K值,可以减少估计误差,但是学习的近似误差会增大。
    在应用中,一般取一个较小的值,采用交叉验证的办法取最优的k值。
    4、分类决策规则:
    K近邻法中分类决策往往是多数表决,即由输入实例的K个近邻的训练实例中的多数类决定输入实例的类。
    使用平台:eclipse,R
    实验数据:人工数据
    相关程序:
    使用R生成数据,使用Java处理数据:
X<-matrix(1:50,nrow=25,ncol=2)
Y<-matrix(0,nrow=25,ncol=1)
for (i in 1:25){
   if (runif(1)<0.5){
   X[i,1]=exp(runif(1))*1.3  
   X[i,2]=exp(runif(1))*3
   Y[i]=0
}
   else{
   X[i,1]=exp(runif(1))*3
   X[i,2]=exp(runif(1))*3
   Y[i]=1
}
}
data<-cbind(Y,X)
write.table(data,"C:/Users/CJH/Desktop/R程序运行/KNNtest.txt",row.names=FALSEcol.names
=,FALSE)
#数据生成

KNNtrain<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNtrain.txt",head=FALSE)
KNNtest<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNtest.txt",head=FALSE)
class1<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer1.txt",head=FALSE)
class2<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer2.txt",head=FALSE)
class3<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer3.txt",head=FALSE)
class4<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer4.txt",head=FALSE)
class5<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer5.txt",head=FALSE)
class6<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer6.txt",head=FALSE)
class10<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer10.txt",head=FALSE)
class30<-read.table("C:/Users/CJH/Desktop/R程序运行/KNNanswer30.txt",head=FALSE)
#输入数据及结果

KNNtrain<-data.frame(KNNtrain)
names(KNNtrain)<-c("class","x","y")
KNNtrain$class<-factor(KNNtrain$class)
library(ggplot2)
ggplot(data=KNNtrain,aes(x=KNNtrain$x,y=KNNtrain$y,shape=KNNtrain$class,color=KNNtrain$class))+
geom_point(size=3)+labs(title="TrainData",x="x",y="y")

KNNtest<-data.frame(KNNtest)
names(KNNtest)<-c("class","x","y")
KNNtest$class<-factor(KNNtest$class)
ggplot(data=KNNtest,aes(x=KNNtest$x,y=KNNtest$y,shape=KNNtest$class,color=KNNtest$class))+
geom_point(size=3)+labs(title="TestData",x="x",y="y")
#生成图像

Eclipse程序:
package KNN;

import java.io.*;
import java.util.*;

public class InputData{
    public void loadData(double [][]x,double[]y,String trainfile)throws IOException{
       File file = new File("C:\\Users\\CJH\\Desktop\\R程序运行",trainfile);
       RandomAccessFile raf= new RandomAccessFile(file,"r");
       StringTokenizer tokenizer;   
       int i=0,j=0;
       while(true){
           String line = raf.readLine();
           if(line==null)break;
           tokenizer= new StringTokenizer(line);
           y[i]=Double.parseDouble(tokenizer.nextToken());
           while(tokenizer.hasMoreTokens()){
           x[i][j]=Double.parseDouble(tokenizer.nextToken());
           j++;
           }
           j=0;i++;
       }
       raf.close();
    }

}
//输入数据

package KNN;

public class CrossValidation {
    private int k;
    private int n;
    private int m;
    private int n1;
    public int getK(){
        return k;
        }
    public int getN(){
        return n;
        }
    public int getM(){
        return m;
        }
    public int getN1(){
        return n1;
        }
    public void setK(int b,int a,int t,int p){
         k=b;
         n=a;
         m=t;
         n1=p;
        }
}
//原始参数设置

package KNN;

public class KNN {

public double[] y(double[][] X,double[] Y,double[][] newpoints,int k,double[] c){
    double[][] distance=new double[newpoints.length][X.length];
    double[] y=new double[newpoints.length];
    int[][] rank=new int[newpoints.length][k];
    for (int i=0;ifor(int j=0;jfor(int i=0;i0]];
        int t1=count(rank[i],Y,Y[rank[i][0]]);
        for(int j=0;jint t=count(rank[i],Y,c[j]);
            if(t>t1){
                y[i]=c[j];
            }
        }
    }
    return y;
}

private int count(int[] rank, double[] y, double d) {
    int count=0;
    for (int i=0;iif(y[rank[i]]==d){
            count=count+1;
        }
    }
    return count;
}

private int[] Rank(double[] distance,int k) {
    int[] Rank=new int[k];
    int[] temp1= new int[distance.length];
    for(int i=0;ifor (int i=0;idouble temp=distance[i];
               int temp2=temp1[i];
               for(int j=i+1;jif(temp>distance[j]){
                       temp2=temp1[j];
                       temp1[j]=temp1[i];
                       temp1[i]=temp2;
                       temp=distance[j];
                   }
               if(temp1[i]!=i){
                   distance[j]=distance[i];
                   distance[i]=temp;
               }
        }
    }
               for (int i=0;ireturn Rank;
}

private double euclid(double[] ds, double[] ds2) {
    double distance=0;
    if (ds.length==ds2.length){
        for(int i=0;i2);
        }
        distance=Math.sqrt(distance);
    }
    else{
        distance=0;
    }
    return distance;
}
}
//KNN算法


package KNN;

import java.io.IOException;

public class KNNmain{
    public static void main(String[] args) throws IOException {
        CrossValidation kvalue=new CrossValidation();
        kvalue.setK(30,100,2,25);
        int k=kvalue.getK();
        int n=kvalue.getN();
        int m=kvalue.getM();
        int n1=kvalue.getN1();
        double[] y=new double[n1];
        double[] Y1=new double[n1];
        double[][] newpoints=new double[n1][m];
        InputData ori=new InputData();
        InputData op=new InputData();
        double[][] X=new double[n][m];
        double[] Y=new double[n];
        ori.loadData(X, Y, "KNNtrain.txt");
        op.loadData(newpoints,Y1,"KNNtest.txt");
        KNN Kdata=new KNN();
        double[] c=new double[2];
        c[0]=0;
        c[1]=1;
        y=Kdata.y(X,Y,newpoints,k,c);
        double temp=0;
        for (int i=0;iif(y[i]!=Y1[i]){
                temp=temp+1;
            }
            System.out.println(y[i]+" ");
        }
        System.out.println(temp/(double)newpoints.length);
    }
}
//主程序,输出分类结果及交叉验证结果

仿真数据:
机器学习入门算法及其java实现-KNN算法_第1张图片

训练数据

机器学习入门算法及其java实现-KNN算法_第2张图片

测试数据
k值 k=1 k=2 k=3 k=4 k=5 k=6 k=10 k=30
训练集错误率 0.04 0.04 0.6 0.4 0.56 0.24 0.24 0.30

由上表,该数据中K值取1或者2是最好的。

你可能感兴趣的:(机器学习十大算法,分类算法)