libsvm库的作者主页 :http://www.csie.ntu.edu.tw/~cjlin
好吧,都是英文的......
libsvm的分链接:
http://www.csie.ntu.edu.tw/~cjlin/libsvm/index.html
到通过分连接里面找到download,下载下libsvm的压缩包,其目录如下图:
好多东西的感脚,其实如果是java的话只要java那个文件夹里面的东西就口以啦
下面是java文件夹下面的动动,也很多的感脚,其实也只要libsvm.jar这个包就可以啦,超喜欢,一个包导进到工程项目中就可以直接用了,SO 方便.
对于其它的java文件要干嘛用呢,你可以当作例子,那应该都是作者写的实例吧,里面都是调用libsvm.jar包的类,看那些代码
有助于清晰明了的学会libsvm.jar的各种类的使用方法,文章后部分的代码参考svm_train,额,也不算参考吧,应为基本都是
那个文件里面的代码哈。
下面说下主要要用到的类有3个:
1.svm_parameter:用来保存svm的一些设置参数
2.svm_problem:用来保存样本的
3.svm:主要用来做svm分类的
其实还有svm_model类也是很重要的,只是暂时没用到,因为一开始我的目的只是为了能够用这个让lissvm成功运行起来,我就哈皮啦,所以这个还没研究。
不过后面要更灵活地使用svm就需要这个东西咯, mark先:svm_model ,保存分类模型的,具体用法还咩研究。
下面直接贴代码了,里面有关于参数设置以及如何将文件里的数据存到svm_problem的实例中,以及交叉验证的代码,摘抄自svm_train.java
package loma; import java.io.BufferedReader; import java.io.FileReader; import java.io.IOException; import java.util.StringTokenizer; import java.util.Vector; import libsvm.svm; import libsvm.svm_node; import libsvm.svm_parameter; import libsvm.svm_problem; public class test{ private static double atof(String s) { double d = Double.valueOf(s).doubleValue(); if (Double.isNaN(d) || Double.isInfinite(d)) { System.err.print("NaN or Infinity in input\n"); System.exit(1); } return(d); } //获取参数 private svm_parameter getParameter(){ svm_parameter param = new svm_parameter(); // default values param.svm_type = svm_parameter.C_SVC; param.kernel_type = svm_parameter.RBF; param.degree = 3; param.gamma = 0; // 1/num_features param.coef0 = 0; param.nu = 0.5; param.cache_size = 100; param.C = 1; param.eps = 1e-3; param.p = 0.1; param.shrinking = 1; param.probability = 0; param.nr_weight = 0; param.weight_label = new int[0]; param.weight = new double[0]; return param; } private static int atoi(String s) { return Integer.parseInt(s); } //获取问题描述 private svm_problem read_problem(String input_file_name,svm_parameter param) throws IOException { BufferedReader fp = new BufferedReader(new FileReader(input_file_name)); Vector<Double> vy = new Vector<Double>(); Vector<svm_node[]> vx = new Vector<svm_node[]>(); int max_index = 0; while(true) { String line = fp.readLine(); if(line == null) break; StringTokenizer st = new StringTokenizer(line," \t\n\r\f:"); vy.addElement(atof(st.nextToken())); int m = st.countTokens()/2; svm_node[] x = new svm_node[m]; for(int j=0;j<m;j++) { x[j] = new svm_node(); x[j].index = atoi(st.nextToken()); x[j].value = atof(st.nextToken()); } if(m>0) max_index = Math.max(max_index, x[m-1].index); vx.addElement(x); } svm_problem prob = new svm_problem(); prob.l = vy.size(); prob.x = new svm_node[prob.l][]; for(int i=0;i<prob.l;i++) prob.x[i] = vx.elementAt(i); prob.y = new double[prob.l]; for(int i=0;i<prob.l;i++) prob.y[i] = vy.elementAt(i); if(param.gamma == 0 && max_index > 0) param.gamma = 1.0/max_index; if(param.kernel_type == svm_parameter.PRECOMPUTED) for(int i=0;i<prob.l;i++) { if (prob.x[i][0].index != 0) { System.err.print("Wrong kernel matrix: first column must be 0:sample_serial_number\n"); System.exit(1); } if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index) { System.err.print("Wrong input format: sample_serial_number out of range\n"); System.exit(1); } } fp.close(); return prob; } //交叉验证 private void do_cross_validation(svm_problem prob,svm_parameter param,int nr_fold) { int i; int total_correct = 0; double total_error = 0; double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; double[] target = new double[prob.l]; svm.svm_cross_validation(prob,param,nr_fold,target); if(param.svm_type == svm_parameter.EPSILON_SVR || param.svm_type == svm_parameter.NU_SVR) { for(i=0;i<prob.l;i++) { double y = prob.y[i]; double v = target[i]; total_error += (v-y)*(v-y); sumv += v; sumy += y; sumvv += v*v; sumyy += y*y; sumvy += v*y; } System.out.print("Cross Validation Mean squared error = "+total_error/prob.l+"\n"); System.out.print("Cross Validation Squared correlation coefficient = "+ ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/ ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))+"\n" ); } else { for(i=0;i<prob.l;i++) if(target[i] == prob.y[i]) ++total_correct; System.out.print("Cross Validation Accuracy = "+100.0*total_correct/prob.l+"%\n"); } } public static void main(String[] args){ test t = new test(); svm_parameter param = t.getParameter(); svm_problem prob = null; try{ prob = t.read_problem("d:\\iris.scale",param); }catch(IOException e){ //todo } t.do_cross_validation(prob, param, 10); } }
整体步骤如下:
1.到官网下载,libsvm的压缩包。
2.从压缩包里面获取libsvm.jar并添加到java工程中。
3.创建一个简单的测试类,可直接将上面的代码复制粘贴过去体验下。
上述代码的基本思路:
首先要实例化svm_parameter,然后设置相应的参数;
其次读取数据文件,将里面的内容按要求存放到svm_problem对象中;
最后,通过调用svm.svm_cross_validation(prob,param,nr_fold,target);进行交叉验证
4.从http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ 下载数据集进行测试,具体的数据格式可以模仿这边下载下的文件
大概格式就是这样的 label index1:value1 index2:value2....
5.over