1、算法基本原理:
# Generate a synthetic two-class data set of 25 points and save it as a
# plain-text table (one row per point: class label, x, y; no header).
# Class 0 points draw x from exp(U(0, 1)) * 1.3, class 1 points from
# exp(U(0, 1)) * 3; both classes draw y from exp(U(0, 1)) * 3.
X <- matrix(0, nrow = 25, ncol = 2)
Y <- matrix(0, nrow = 25, ncol = 1)
for (i in seq_len(25)) {
  # Pick the class of point i with probability 1/2 each.
  if (runif(1) < 0.5) {
    X[i, 1] <- exp(runif(1)) * 1.3
    X[i, 2] <- exp(runif(1)) * 3
    Y[i] <- 0
  } else {
    X[i, 1] <- exp(runif(1)) * 3
    X[i, 2] <- exp(runif(1)) * 3
    Y[i] <- 1
  }
}
data <- cbind(Y, X)
# Write without row or column names so each line is three bare numbers.
write.table(data, "C:/Users/CJH/Desktop/R程序运行/KNNtest.txt",
            row.names = FALSE, col.names = FALSE)
#数据生成
# Load the training set, the test set, and the saved classification results
# (one answer file per tested value of k) from the working data directory.
# All files are header-less whitespace-separated tables.
data_dir <- "C:/Users/CJH/Desktop/R程序运行"

# Read one header-less table from the data directory.
read_knn <- function(file_name) {
  read.table(file.path(data_dir, file_name), header = FALSE)
}

KNNtrain <- read_knn("KNNtrain.txt")
KNNtest <- read_knn("KNNtest.txt")
class1 <- read_knn("KNNanswer1.txt")
class2 <- read_knn("KNNanswer2.txt")
class3 <- read_knn("KNNanswer3.txt")
class4 <- read_knn("KNNanswer4.txt")
class5 <- read_knn("KNNanswer5.txt")
class6 <- read_knn("KNNanswer6.txt")
class10 <- read_knn("KNNanswer10.txt")
class30 <- read_knn("KNNanswer30.txt")
#输入数据及结果
# Name the columns of the training and test tables, turn the class label into
# a factor, and draw one scatter plot per data set with class mapped to both
# point shape and colour.
library(ggplot2)

KNNtrain <- data.frame(KNNtrain)
names(KNNtrain) <- c("class", "x", "y")
KNNtrain$class <- factor(KNNtrain$class)
# aes() refers to columns by bare name; the data frame is supplied via `data`.
ggplot(data = KNNtrain, aes(x = x, y = y, shape = class, color = class)) +
  geom_point(size = 3) +
  labs(title = "TrainData", x = "x", y = "y")

KNNtest <- data.frame(KNNtest)
names(KNNtest) <- c("class", "x", "y")
KNNtest$class <- factor(KNNtest$class)
ggplot(data = KNNtest, aes(x = x, y = y, shape = class, color = class)) +
  geom_point(size = 3) +
  labs(title = "TestData", x = "x", y = "y")
#生成图像
Eclipse程序:
package KNN;
import java.io.*;
import java.util.*;
/**
 * Reads a whitespace-separated numeric table from the fixed data directory
 * into caller-supplied arrays.
 */
public class InputData{
    /**
     * Loads one data file. On every non-blank line the first token is parsed
     * as the label and stored in {@code y[row]}; the remaining tokens are
     * parsed as features and stored in {@code x[row][0..]}.
     *
     * @param x         destination for the feature values; must have at least
     *                  as many rows as the file has data lines and as many
     *                  columns as each line has feature tokens
     * @param y         destination for the labels, one per data line
     * @param trainfile name of the file inside the data directory
     * @throws IOException           if the file cannot be opened or read
     * @throws NumberFormatException if a token is not a valid double
     */
    public void loadData(double [][]x,double[]y,String trainfile)throws IOException{
        File file = new File("C:\\Users\\CJH\\Desktop\\R程序运行",trainfile);
        BufferedReader reader = new BufferedReader(new FileReader(file));
        try{
            int row=0;
            String line;
            while((line=reader.readLine())!=null){
                StringTokenizer tokenizer = new StringTokenizer(line);
                if(!tokenizer.hasMoreTokens()){
                    // Skip blank lines (e.g. a trailing newline) instead of failing.
                    continue;
                }
                y[row]=Double.parseDouble(tokenizer.nextToken());
                int col=0;
                while(tokenizer.hasMoreTokens()){
                    x[row][col]=Double.parseDouble(tokenizer.nextToken());
                    col++;
                }
                row++;
            }
        }finally{
            // Release the file handle even when parsing throws.
            reader.close();
        }
    }
}
//输入数据
package KNN;
/**
 * Plain holder for the four integer settings of a KNN run. It performs no
 * validation itself; it only stores the values and hands them back.
 * In KNNmain the values are used as: k = number of neighbours, n = number of
 * training rows, m = number of feature columns, n1 = number of test rows.
 */
public class CrossValidation {
    private int k;
    private int n;
    private int m;
    private int n1;

    /**
     * Stores all four settings at once.
     *
     * @param b value returned afterwards by {@link #getK()}
     * @param a value returned afterwards by {@link #getN()}
     * @param t value returned afterwards by {@link #getM()}
     * @param p value returned afterwards by {@link #getN1()}
     */
    public void setK(int b,int a,int t,int p){
        this.k=b;
        this.n=a;
        this.m=t;
        this.n1=p;
    }

    /** @return the stored k (0 until {@link #setK} is called) */
    public int getK(){
        return this.k;
    }

    /** @return the stored n (0 until {@link #setK} is called) */
    public int getN(){
        return this.n;
    }

    /** @return the stored m (0 until {@link #setK} is called) */
    public int getM(){
        return this.m;
    }

    /** @return the stored n1 (0 until {@link #setK} is called) */
    public int getN1(){
        return this.n1;
    }
}
//原始参数设置
package KNN;
/**
 * k-nearest-neighbour classifier using Euclidean distance and a majority vote
 * among the k closest training points.
 */
public class KNN {
    /**
     * Classifies every row of {@code newpoints}.
     *
     * For each query point the k nearest training rows are found; the
     * predicted label is the candidate in {@code c} that occurs most often
     * among them. The label of the single nearest neighbour is the starting
     * choice, so it wins ties, and among tied candidates the one listed
     * first in {@code c} wins.
     *
     * @param X         training features, one row per training point
     * @param Y         training labels, {@code Y[j]} belongs to {@code X[j]}
     * @param newpoints query points, same number of columns as {@code X}
     * @param k         number of neighbours, {@code 1 <= k <= X.length}
     * @param c         candidate class labels
     * @return predicted label for each row of {@code newpoints}
     * @throws IllegalArgumentException if {@code k} is out of range or a query
     *                                  point and a training point differ in
     *                                  dimension
     */
    public double[] y(double[][] X,double[] Y,double[][] newpoints,int k,double[] c){
        if(k<1||k>X.length){
            throw new IllegalArgumentException(
                "k must be between 1 and the number of training points ("
                +X.length+"), got "+k);
        }
        double[] y=new double[newpoints.length];
        for(int i=0;i<newpoints.length;i++){
            double[] distance=new double[X.length];
            for(int j=0;j<X.length;j++){
                distance[j]=euclid(newpoints[i],X[j]);
            }
            int[] rank=Rank(distance,k);
            // Start from the nearest neighbour's label so it wins ties.
            y[i]=Y[rank[0]];
            int t1=count(rank,Y,y[i]);
            for(int j=0;j<c.length;j++){
                int t=count(rank,Y,c[j]);
                if(t>t1){
                    y[i]=c[j];
                    // Track the best count so a later, weaker class cannot
                    // replace a stronger one.
                    t1=t;
                }
            }
        }
        return y;
    }

    /** Counts how many of the training indices in {@code rank} carry label {@code d}. */
    private int count(int[] rank, double[] y, double d) {
        int count=0;
        for (int i=0;i<rank.length;i++){
            if(y[rank[i]]==d){
                count=count+1;
            }
        }
        return count;
    }

    /**
     * Returns the indices of the {@code k} smallest entries of
     * {@code distance}, nearest first. Equal distances are ordered by
     * ascending index. The input array is not modified.
     */
    private int[] Rank(double[] distance,int k) {
        int n=distance.length;
        int[] order=new int[n];
        for(int i=0;i<n;i++){
            order[i]=i;
        }
        // Partial selection sort on the index array: only the first k
        // positions need to be settled.
        for(int i=0;i<k;i++){
            int best=i;
            for(int j=i+1;j<n;j++){
                double dj=distance[order[j]];
                double db=distance[order[best]];
                if(dj<db||(dj==db&&order[j]<order[best])){
                    best=j;
                }
            }
            int swap=order[i];
            order[i]=order[best];
            order[best]=swap;
        }
        int[] Rank=new int[k];
        System.arraycopy(order,0,Rank,0,k);
        return Rank;
    }

    /**
     * Euclidean distance between two points of equal dimension.
     *
     * @throws IllegalArgumentException if the dimensions differ; returning 0
     *         here would make a malformed point look like the closest one
     */
    private double euclid(double[] ds, double[] ds2) {
        if (ds.length!=ds2.length){
            throw new IllegalArgumentException(
                "dimension mismatch: "+ds.length+" vs "+ds2.length);
        }
        double sum=0;
        for(int i=0;i<ds.length;i++){
            double diff=ds[i]-ds2[i];
            sum=sum+diff*diff;
        }
        return Math.sqrt(sum);
    }
}
//KNN算法
package KNN;
import java.io.IOException;
public class KNNmain{
public static void main(String[] args) throws IOException {
CrossValidation kvalue=new CrossValidation();
kvalue.setK(30,100,2,25);
int k=kvalue.getK();
int n=kvalue.getN();
int m=kvalue.getM();
int n1=kvalue.getN1();
double[] y=new double[n1];
double[] Y1=new double[n1];
double[][] newpoints=new double[n1][m];
InputData ori=new InputData();
InputData op=new InputData();
double[][] X=new double[n][m];
double[] Y=new double[n];
ori.loadData(X, Y, "KNNtrain.txt");
op.loadData(newpoints,Y1,"KNNtest.txt");
KNN Kdata=new KNN();
double[] c=new double[2];
c[0]=0;
c[1]=1;
y=Kdata.y(X,Y,newpoints,k,c);
double temp=0;
for (int i=0;iif(y[i]!=Y1[i]){
temp=temp+1;
}
System.out.println(y[i]+" ");
}
System.out.println(temp/(double)newpoints.length);
}
}
//主程序,输出分类结果及交叉验证结果
k值 | k=1 | k=2 | k=3 | k=4 | k=5 | k=6 | k=10 | k=30 |
---|---|---|---|---|---|---|---|---|
训练集错误率 | 0.04 | 0.04 | 0.6 | 0.4 | 0.56 | 0.24 | 0.24 | 0.30 |
由上表,该数据中K值取1或者2是最好的。