Java实现简单版SVM

Java实现简单版SVM

最近的图像分类工作要用到latent svm,为了更加深入了解svm,自己动手实现一个简单版的。

之所以说是简单版,是因为没有用到拉格朗日乘子、对偶问题、核函数等,而是直接用最简单的梯度下降法求解。其中的数学原理参考了 http://blog.csdn.net/lifeitengup/article/details/10951655 ,该文是用matlab实现的svm。


源代码和数据集下载:https://github.com/linger2012/simpleSvm

其中数据集来自于libsvm,我找了其中一个数据集http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/breast-cancer_scale。
将它分成两部分:训练集和测试集,分别对应于train_bc和test_bc。

其中测试结果如下:
Java实现简单版SVM_第1张图片


package com.linger.svm;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.util.StringTokenizer;

/**
 * A minimal linear SVM trained with batch (sub)gradient descent on the
 * L2-regularized hinge loss:
 *
 * <pre>
 *   cost(w) = sum_m max(0, 1 - y[m] * (w . X[m]))  +  0.5 * lambda * |w|^2
 * </pre>
 *
 * No Lagrangian, dual formulation or kernel is used. Labels are expected to be
 * +1 / -1, and {@link #loadData} stores a constant 1 in feature column 0 so
 * that {@code w[0]} acts as the bias term.
 */
public class SimpleSvm
{
	/** Training file used by {@link #main} when no first argument is given. */
	private static final String DEFAULT_TRAIN_FILE = "E:\\project\\workspace\\Algorithms\\bin\\train_bc";
	/** Test file used by {@link #main} when no second argument is given. */
	private static final String DEFAULT_TEST_FILE = "E:\\project\\workspace\\Algorithms\\bin\\test_bc";
	/** Number of rows allocated by {@link #main} for the training split. */
	private static final int TRAIN_SIZE = 400;
	/** Number of rows allocated by {@link #main} for the test split. */
	private static final int TEST_SIZE = 283;
	/** Columns per example in {@link #main}: 10 features plus the bias column 0. */
	private static final int FEATURE_DIM = 11;

	/** Number of training examples (rows of X), set by {@link #Train}. */
	private int exampleNum;
	/** Number of columns per example, taken from the first row of X. */
	private int exampleDim;
	/** Weight vector; {@code null} until {@link #Train} has allocated it. */
	private double[] w;
	/** Regularization strength. */
	private double lambda;
	/** Gradient-descent step size. */
	private double lr = 0.001;//0.00001
	/** Training stops early once the cost drops below this value. */
	private double threshold = 0.001;
	/** Cost computed by the most recent {@link #CostAndGrad} call. */
	private double cost;
	/** Gradient computed by the most recent {@link #CostAndGrad} call. */
	private double[] grad;
	/** Raw scores {@code w . X[m]} from the most recent {@link #CostAndGrad} call. */
	private double[] yp;

	/**
	 * Creates an untrained classifier.
	 *
	 * @param paramLambda weight of the L2 regularization term
	 */
	public SimpleSvm(double paramLambda)
	{
		lambda = paramLambda;
	}

	/**
	 * Computes the regularized hinge-loss cost and its (sub)gradient at the
	 * current weights, storing them in {@code cost}, {@code grad} and {@code yp}.
	 *
	 * @param X training examples, one row per example
	 * @param y labels aligned with the rows of X
	 */
	private void CostAndGrad(double[][] X,double[] y)
	{
		// The derivative of 0.5*lambda*w^2 is lambda*w (signed). Taking the
		// absolute value here would push negative weights away from zero.
		for(int d=0;d<exampleDim;d++)
		{
			grad[d] = lambda*w[d];
		}

		cost = 0;
		for(int m=0;m<exampleNum;m++)
		{
			double score = 0;
			for(int d=0;d<exampleDim;d++)
			{
				score += X[m][d]*w[d];
			}
			yp[m] = score;

			// Only examples inside the margin contribute to the hinge loss
			// and to its subgradient.
			if(y[m]*score-1<0)
			{
				cost += (1-y[m]*score);
				for(int d=0;d<exampleDim;d++)
				{
					grad[d] -= y[m]*X[m][d];
				}
			}
		}

		for(int d=0;d<exampleDim;d++)
		{
			cost += 0.5*lambda*w[d]*w[d];
		}
	}

	/** Takes one gradient-descent step using the last computed gradient. */
	private void update()
	{
		for(int d=0;d<exampleDim;d++)
		{
			w[d] -= lr*grad[d];
		}
	}

	/**
	 * Trains the weights from scratch. The cost is printed on every iteration
	 * and training stops early once it falls below the internal threshold.
	 * If X has no rows a message is printed and nothing is trained.
	 *
	 * @param X        training examples, one row per example; every row is
	 *                 read up to the length of the first row
	 * @param y        labels (+1 / -1), at least one per row of X
	 * @param maxIters maximum number of gradient-descent iterations
	 * @throws IllegalArgumentException if y has fewer entries than X has rows
	 */
	public void Train(double[][] X,double[] y,int maxIters)
	{
		exampleNum = X.length;
		if(exampleNum <=0)
		{
			System.out.println("num of example <=0!");
			return;
		}
		if(y.length < exampleNum)
		{
			throw new IllegalArgumentException(
					"y has " + y.length + " labels but X has " + exampleNum + " rows");
		}
		exampleDim = X[0].length;
		w = new double[exampleDim];
		grad = new double[exampleDim];
		yp = new double[exampleNum];

		for(int iter=0;iter<maxIters;iter++)
		{
			CostAndGrad(X,y);
			System.out.println("cost:"+cost);
			if(cost< threshold)
			{
				break;
			}
			update();
		}
	}

	/**
	 * Classifies a single example by the sign of {@code w . x}.
	 *
	 * @param x feature vector, laid out like the training rows
	 * @return 1 if the score is non-negative, otherwise -1
	 * @throws IllegalStateException if no weights have been trained yet
	 */
	private int predict(double[] x)
	{
		if(w == null)
		{
			throw new IllegalStateException("Train must succeed before predicting");
		}
		double pre=0;
		for(int j=0;j<x.length;j++)
		{
			pre+=x[j]*w[j];
		}
		// This decision threshold generally lies between -1 and 1; 0 is used here.
		if(pre >=0)
			return 1;
		else return -1;
	}

	/**
	 * Classifies every test example and prints the total count, the error
	 * count, the error rate and the accuracy.
	 *
	 * @param testX test examples, one row per example
	 * @param testY expected labels aligned with the rows of testX
	 */
	public void Test(double[][] testX,double[] testY)
	{
		int error=0;
		for(int i=0;i<testX.length;i++)
		{
			if(predict(testX[i]) != testY[i])
			{
				error++;
			}
		}
		System.out.println("total:"+testX.length);
		System.out.println("error:"+error);
		System.out.println("error rate:"+((double)error/testX.length));
		System.out.println("acc rate:"+((double)(testX.length-error)/testX.length));
	}

	/**
	 * Reads a LIBSVM-style text file ({@code label index:value index:value ...},
	 * one example per line) into pre-allocated arrays. Column 0 of every
	 * loaded row is set to 1 so it can serve as the bias input. Lines that
	 * contain no tokens are skipped.
	 *
	 * @param X         destination for features; must have enough rows for the
	 *                  file and enough columns for the largest feature index
	 * @param y         destination for labels, one per loaded row
	 * @param trainFile path of the file to read
	 * @throws IOException if the file cannot be opened or read
	 */
	public static void loadData(double[][]X,double[] y,String trainFile) throws IOException
	{
		File file = new File(trainFile);
		RandomAccessFile raf = new RandomAccessFile(file,"r");
		try
		{
			int index=0;
			String line;
			while((line = raf.readLine()) != null)
			{
				StringTokenizer tokenizer = new StringTokenizer(line," ");
				if(!tokenizer.hasMoreTokens())
				{
					// Blank line (e.g. a trailing newline): nothing to parse.
					continue;
				}
				y[index] = Double.parseDouble(tokenizer.nextToken());
				while(tokenizer.hasMoreTokens())
				{
					StringTokenizer pair = new StringTokenizer(tokenizer.nextToken(),":");
					int k = Integer.parseInt(pair.nextToken());
					double v = Double.parseDouble(pair.nextToken());
					X[index][k] = v;
				}
				// Written last so the bias column is 1 whatever the file contains.
				X[index][0] =1;
				index++;
			}
		}
		finally
		{
			// Always release the file handle, even when parsing fails.
			raf.close();
		}
	}

	/**
	 * Trains on the training split and reports accuracy on the test split.
	 *
	 * @param args optional: {@code args[0]} is the training file and
	 *             {@code args[1]} the test file; the built-in default paths
	 *             are used for any argument that is missing
	 * @throws IOException if either data file cannot be read
	 */
	public static void main(String[] args) throws IOException
	{
		String trainFile = args.length > 0 ? args[0] : DEFAULT_TRAIN_FILE;
		String testFile = args.length > 1 ? args[1] : DEFAULT_TEST_FILE;

		double[] y = new double[TRAIN_SIZE];
		double[][] X = new double[TRAIN_SIZE][FEATURE_DIM];
		loadData(X,y,trainFile);

		SimpleSvm svm = new SimpleSvm(0.0001);
		svm.Train(X,y,7000);

		double[] test_y = new double[TEST_SIZE];
		double[][] test_X = new double[TEST_SIZE][FEATURE_DIM];
		loadData(test_X,test_y,testFile);
		svm.Test(test_X, test_y);
	}

}



本文作者:linger
本文链接:http://blog.csdn.net/lingerlanlan/article/details/38688539


你可能感兴趣的:(java,数据挖掘,机器学习,SVM,分类器)