一个偷偷写的svm库

今早刚接触一个新的库——dlib(http://dlib.net),讲真,真的很好用。按照官方的介绍,就是:These wrappers provide a portable object oriented interface for networking, multithreading, GUI development, and file browsing. Programs written using them can be compiled under POSIX or MS Windows platforms without changing the code.总的来说,dlib是一个C++库,可用于开发可移植的应用程序,涵盖网络处理、线程、图形界面、数据结构、线性代数、机器学习、XML和文本解析、数值优化、贝叶斯网络以及许多其他任务,几乎涉及到数据分析的方方面面。更重要的是,类似于OpenCV,它提供了很多很详细的example,因此学习起来应该不难。由于今天第一天接触,就根据svm分类的example把它改写成了一个二类分类库,多类分类器以后再慢慢加进去。不过,功能应该还不太完善。总之,先放上来吧,以后再慢慢改。目前支持对nu-SVM和C-SVM两种训练器进行参数调整,默认是对gamma和C调参,因为这两个参数对结果影响最大。


#include 
#include 
#include 
#include "dlib/rand/rand_kernel_abstract.h"

using namespace std;
using namespace dlib;

// SVM binary classifier (built on dlib). Change the nFeatures value before use.
// NOTE(review): this listing appears damaged by HTML extraction. Template
// arguments (<...>, e.g. on matrix, std::vector and the dlib typedefs) and
// several spans of code that sat between '<' and '>' characters are missing
// (see the fused lines flagged below), so it does not compile as shown.
namespace SVM{

	// Number of features per sample; loadData() parses this many values from
	// each data row.
	// NOTE(review): a #define is not scoped by the enclosing namespace; it is
	// visible to all code that follows it in the translation unit.
	#define nFeatures 2
	typedef matrix sample_type;// sample (feature vector) type; NOTE(review): template arguments missing, presumably matrix<double, nFeatures, 1> - confirm
	typedef radial_basis_kernel kernel_type;// kernel type (RBF); NOTE(review): template argument missing, presumably <sample_type> - confirm

	// Probabilistic decision function and its normalized wrapper; pfunct_type
	// is the type of the learned_pfunct member below.
	// NOTE(review): template arguments missing from both typedefs.
	typedef probabilistic_decision_function probabilistic_funct_type;  
	typedef normalized_function pfunct_type;

	// Selects which trainer's hyperparameter search runs (used by switch(opt) below).
	enum Trainer{CTrainer = 1, NUTrainer = 2};
	// Selects whether loadData() fills samples/labels or works on testData.
	enum LoadType {LoadSamples = 1, LoadTestData = 2};

	// Binary SVM classifier wrapper. It loads comma-separated data from a text
	// file and grid-searches hyperparameters for either a nu-SVM
	// (svm_nu_trainer) or a C-SVM (svm_c_trainer), scoring each candidate with
	// 10-fold cross validation.
	class SVMClassification{
	public:
		// Starts with empty training samples and labels.
		SVMClassification(){
			samples.clear();
			labels.clear();
		}
		~SVMClassification(){}
		
		// Loads data from the text file fn, read through Qt's QFile/QTextStream.
		//   fn  - path of the file to read.
		//   opt - LoadSamples: replace samples/labels with the file contents;
		//         any other value: clear testData instead. The code that fills
		//         testData was lost in extraction; presumably it appends there.
		// Returns false, after printing a message to cout, if the file does not
		// exist or cannot be opened.
		// File format, as parsed below:
		//   - the first line is read and discarded (presumably a header);
		//   - each following line is split on ',';
		//   - fields 1..nFeatures (field 0 is skipped) are parsed as doubles;
		//   - for LoadSamples, the last field is the label, mapped to +1.0 if it
		//     equals 1 and to -1.0 otherwise.
		// NOTE(review): nothing checks that a row has at least nFeatures+1
		// fields, so a short or blank row indexes slist out of range.
		// NOTE(review): the error messages have no separator between fn and the text.
		bool loadData(const char* fn, int opt = LoadSamples)
		{
			if(! QFile::exists(fn))
			{
				cout << fn << "does not exist!\n";
				return false;
			}
			QFile infile(fn);
			if (!infile.open(QIODevice::ReadOnly))
			{
				cout << fn << "open error!\n";
				return false;
			}
			QTextStream _in(&infile);
			// The first line is read here and overwritten before use, i.e. skipped.
			QString smsg = _in.readLine();
			QStringList slist;
			if(opt == LoadSamples)
			{
				samples.clear();
				labels.clear();
			}
			else
				testData.clear();
			

			while(! _in.atEnd())
			{
				sample_type samp;
				smsg = _in.readLine();
				slist = smsg.split(",");
				// Columns 1..nFeatures hold the features; column 0 is skipped.
				for (int i = 0; i < nFeatures; i ++)
				{
					samp(i) = slist[i+1].trimmed().toDouble();
					//cout << samp(i)<<" ";
				}
				
				if(opt == LoadSamples)
				{
					samples.push_back(samp);
					// Last column is the class label: 1 -> +1.0, anything else -> -1.0.
					labels.push_back(slist[slist.size()-1].trimmed().toInt()==1? 1.0:-1.0);
					// NOTE(review): extraction fused this debug comment with the rest
					// of loadData() and the start of the hyperparameter-search method.
					// The declaration of best_result (apparently a 1x2 matrix, per
					// "best_result(1, 2)") now sits inside this comment, so
					// best_result is undeclared in the code below.
					//cout << (slist[slist.size()-1].trimmed().toInt()==1? 1.0:-1.0)< best_result(1, 2);
			// Reset the best score and the default hyperparameters before searching.
			best_result = 0;
			best_gamma = 0.0001, best_nu = 0.0001, best_c = 5;

			// opt selects the trainer (see enum Trainer).
			switch(opt)
			{
			case NUTrainer:
				// nu-SVM grid search. Both gamma and nu start at 1e-5 and are
				// multiplied by 5 per step; gamma runs up to 1 inclusive, nu up to
				// but excluding max_nu. Each pair is scored with 10-fold
				// cross_validate_trainer, and the pair with the largest sum(result)
				// is kept.
				// NOTE(review): max_nu is not declared in the visible code; it is
				// presumably computed earlier (e.g. with dlib's maximum_nu(labels)).
				// Confirm.
				for (double gamma = 0.00001; gamma <= 1; gamma *= 5)
				{
					for (double nu = 0.00001; nu < max_nu; nu *= 5)
					{
						trainer.set_kernel(kernel_type(gamma));
						trainer.set_nu(nu);

						cout << "gamma: " << gamma << "    nu: " << nu;
						// Presumably a 1x2 row vector for a binary trainer (per dlib's
						// documentation: accuracy on the +1 and on the -1 samples).
						matrix result = cross_validate_trainer(trainer, samples, labels, 10);
						cout << "     cross validation accuracy: " << result;

						if (sum(result) > sum(best_result))
						{
							best_result = result;
							best_gamma = gamma;
							best_nu = nu;
						} 
					}
				}
				// NOTE(review): extraction fused the end of the NUTrainer report
				// with the CTrainer branch. The missing code presumably includes a
				// break, the CTrainer case label, the gamma/_c loop headers and the
				// c_trainer kernel/C setup.
				cout << "\nbest gamma: " << best_gamma <<"      best nu: " << best_nu<< "      best score: "< result = cross_validate_trainer(c_trainer, samples, labels, 10);
						cout << "     cross validation accuracy: " << result;

						// Keep the gamma/C pair with the largest cross-validation sum.
						if (sum(result) > sum(best_result))
						{
							best_result = result;
							best_gamma = gamma;
							best_c = _c;
						} 
					}
				}
				// NOTE(review): fused line. The end of the CTrainer report and the
				// code that followed were lost; given "learned_pfunct;" and the
				// message on the next line, this was apparently a load method
				// that deserializes into learned_pfunct.
				cout << "\nbest gamma: " << best_gamma <<"      best c: " << best_c<< "      best score: "<> learned_pfunct;
			// NOTE(review): fused line. The rest of the load method and the start
			// of the member declarations were lost (presumably ending in
			// something like std::vector<sample_type> samples;).
			cout <<"loaded learned function from "<< fn< samples;
		// Training labels (+1.0 / -1.0), parallel to samples.
		std::vector labels;
		// Cleared by loadData() when opt != LoadSamples; presumably holds test samples.
		std::vector testData;
		// nu-SVM and C-SVM trainers driven by the parameter search.
		svm_nu_trainer trainer;
		svm_c_trainer c_trainer;
		// NOTE(review): not used in the visible code; presumably fitted to samples
		// before training so that features are normalized.
		vector_normalizer normalizer;
		// Best hyperparameters found by the grid search.
		double best_gamma;
		double best_nu;
		double best_c;
		// Learned (normalized) probabilistic decision function.
		pfunct_type learned_pfunct; 
	protected:
	};
}



你可能感兴趣的:(算法)