package org.lw.fenlei;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;
public class Test {
@SuppressWarnings("resource")
public static void main(String[] args) throws Exception{
// 定义训练集点a{10.0, 10.0} 和 点b{-10.0, -10.0},对应label为{1.0, -1.0}
// svm_node pa0 = new svm_node();
// pa0.index = 0;
// pa0.value = 10.0;
// svm_node pa1 = new svm_node();
// pa1.index = -1;
// pa1.value = 10.0;
// svm_node pb0 = new svm_node();
// pb0.index = 0;
// pb0.value = -10.0;
// svm_node pb1 = new svm_node();
// pb1.index = 0;
// pb1.value = -10.0;
// svm_node[] pa = {pa0,pa1};//点a
// svm_node[] pb = {pb0,pb1};//点b
// svm_node[][] datas = {pa,pb};//训练集的向量表
// double[] lables = {1.0,-1.0};//a,b对应的lable
//找到字典的长度
int ZDlength=0;
File ZD = new File ("c:" + File.separator + "dm" + File.separator + "ZD.txt");
BufferedReader ZDbuf = null;
InputStream ZDinput = new FileInputStream(ZD);
ZDbuf = new BufferedReader(new InputStreamReader(ZDinput));
while(ZDbuf.readLine()!=null){
ZDlength++;
}
System.out.println("字典长度为:"+ZDlength);
//找到xunlian文件夹下有几个文件,有几个文件就是分几个类
File file = new File ("c:" + File.separator + "dm" + File.separator + "xunlian1");
String path[] = file.list();
int lengths = path.length ;
svm_node[][] datas = new svm_node[10000][ZDlength];
for(int j=0; j<10000; j ++)
{
for(int i = 0; i < ZDlength; i ++) {
datas[j][i] =new svm_node();
datas[j][i].value = 0.0;
if(i!=(datas[j].length-1))
datas[j][i].index = i+1;
else
datas[j][ZDlength-1].index=-1;
}
}
double[] lables = new double[1000];
int vectors=0;
int lb = 0;
double bq = 1.0;
//分别读入每个文件中的向量
for (int i = 0;i
System.out.println(path[i]+" "+bq);
BufferedReader buf = null;
File f = new File ("c:" + File.separator + "dm" + File.separator + "xunlian1"+ File.separator + path[i]);
InputStream input = new FileInputStream(f);
buf = new BufferedReader(new InputStreamReader(input));
String b;
while((b = buf.readLine())!=null){
vectors++;
lables[lb] = bq;
//System.out.println(b);
int a;
String[] temp = b.split(" ");
if(temp.length==0)
a =1;
double[] result = new double[temp.length];
int k;
for(k=0; k
String vector[] = temp[k].split(",");
int index = Integer.parseInt(vector[0]);
Double value = Double.parseDouble(vector[1]);
result[k] = value;
if(index!=-1){
datas[lb][index-1].value = result[k];
// System.out.println("第"+lb+"篇文档第"+index+"个值为:"+datas[lb][index-1].value);
}
else{
datas[lb][ZDlength-1].value = result[k];
//System.out.println("第"+lb+"篇文档第"+index+"个值为:"+datas[lb][ZDlength-1].value);
}
}
catch(Exception e){
}
}
lb++;
}
bq++;
buf.close();
}
//定义svm_problem对象
svm_problem problem = new svm_problem();
problem.l = vectors;//向量个数
problem.x = datas;//训练集向量表
problem.y = lables;//对应的lable数组
//定义svm_parameter对象
svm_parameter param = new svm_parameter();
param.svm_type = svm_parameter.C_SVC;
param.kernel_type = svm_parameter.LINEAR;
param.cache_size = 10000;
param.eps = 0.00001;
param.C = 1;
//训练SVM分类模型
System.out.println(svm.svm_check_parameter(problem, param));//如果参数
//没有问题,则该函数返回null,否则返回error描述。
svm_model model = svm.svm_train(problem,param);//svm.svm_train()训练出SVM分类模型;
//定义测试数据点c
//svm_node pc0 = new svm_node();
//pc0.index = 0;
//pc0.value = -0.1;
//svm_node pc1 = new svm_node();
//pc1.index = 1;
//pc1.value = -0.0;
//svm_node pc2 = new svm_node();
//pc2.index = 2;
//pc2.value = -6.0;
//svm_node pc3 = new svm_node();
//pc3.index = 3;
//pc3.value = -4.0;
//svm_node pc4 = new svm_node();
//pc4.index = 4;
//pc4.value = -3.0;
//svm_node pc5 = new svm_node();
//pc1.index = -1;
//pc1.value = 0.0;
//svm_node[] pc = {pc0,pc1,pc2,pc3,pc4,pc5};
//输入要分类的文本数据
//File file1 = new File ("c:" + File.separator + "dm" + File.separator + "input.txt");
//InputStream input1 = new FileInputStream(file1);
//svm_node[] test = new svm_node[1000];
//for(int n = 0;n<1000;n++)
//test[n] = new svm_node();
//BufferedReader buf1 = null;
//buf1 = new BufferedReader(new InputStreamReader(input1));
//String b1;
//while((b1 = buf1.readLine())!=null){
//String[] temp1 = b1.split(" ");
//@SuppressWarnings("unused")
//double[] result1 = new double[temp1.length];
//for(int k1=0; k1
//if(k1!=(temp1.length-1))
//test[k1].index = k1;
//else
//test[k1].index = -1;
//test[k1].value = result1[k1];
//}
//}
//预测测试数据的lable
//System.out.println(svm.svm_predict(model, test));
//输入测试集,并计算准确率和召回率
File file1 = new File ("c:" + File.separator + "dm" + File.separator + "ceshi");
String path1[] = file1.list();
int lengths1 = path1.length ;
svm_node[][] test = new svm_node[10000][ZDlength];
for(int j=0; j<1000; j ++)
{
for(int i = 0; i < ZDlength; i ++) {
test[j][i] =new svm_node();
test[j][i].value = 0.0;
if(i!=(ZDlength-1))
test[j][i].index = i+1;
else
test[j][ZDlength-1].index = -1;
}
}
int lb1=0;
double cc=0.0,bq1=1.0,sumzql=0.0,sumzhl=0.0,zzql=0.0, zzhl=0.0;
double[] zq = new double[lengths1]; //正确分得该类的篇数
double[] zp = new double[lengths1]; //总分得该类的篇数
double[] bp = new double[lengths1]; //测试集中该类本来有多少篇
double[] zql = new double[lengths1]; //每个类的准确率
double[] zhl = new double[lengths1]; //每个类的召回率
for(int i=0;i
}
for (int i = 0;i
bp[i]=0.0;
zq[i]=0.0;
System.out.println(path1[i]);
BufferedReader buf1 = null;
File f1 = new File ("c:" + File.separator + "dm" + File.separator + "ceshi"+ File.separator + path1[i]);
InputStream input1 = new FileInputStream(f1);
buf1 = new BufferedReader(new InputStreamReader(input1));
String b1;
while((b1 = buf1.readLine())!=null){
String[] temp1 = b1.split(" ");
@SuppressWarnings("unused")
double[] result1 = new double[temp1.length];
for(int k1=0; k1
String vector[] = temp1[k1].split(",");
int index = Integer.parseInt(vector[0]);
Double value = Double.parseDouble(vector[1]);
result1[k1] = value;
if(index!=-1)
test[lb1][index-1].value = result1[k1];
else
test[lb1][ZDlength-1].value = result1[k1];
}
cc= svm.svm_predict(model, test[lb1]);
System.out.println("第"+(lb1+1)+"篇文档所分得的的类别是"+cc);
lb1++;
if(cc==bq1){
zq[i]++;
zp[i]++;
}
else{
zp[((int) cc)-1]++;
}
bp[i]++;
}
bq1++;
}
double j =1.0;
for(int i=0;i
System.out.println("第"+j+"类文档的准确率是"+zql[i]);
zhl[i] = zq[i]/bp[i];
System.out.println("第"+j+"类文档的召回率是"+zhl[i]);
sumzql+= zql[i];
sumzhl+=zhl[i];
j=j+1.0;
}
zzql=sumzql/lengths1;
System.out.println("总准确率是"+zzql);
zzhl=sumzhl/lengths1;
System.out.println("总召回率是"+zzhl);
}
}