libSVM是一个非常有名的SVM开源库,最近我在做分类任务,最后需要用到SVM进行分类,可是网上对于libSVM的介绍大多是matlab的,还有就是使用DOS命令调用的,直接使用libSVM的函数进行编程的介绍非常少,我来大体介绍一下我使用的情况吧。
我对libSVM的了解也不是很深入,只是单纯地利用它做训练和识别而已。
// Whole training set passed to svm_train().
// The sample count is the field `l` (the training code sets prob.l), not `n`.
struct svm_problem
{
    int l;               // total number of samples
    double *y;           // class label of each sample; y[i] belongs to x[i]
    struct svm_node **x; // features of all samples: x[i] points to the svm_node array of sample i
};
其中svm_node类型的定义如下:
// A single (index, value) feature entry of one sample.
struct svm_node
{
    int index;    // dimension number of this feature (the code below numbers them from 1; -1 ends a sample)
    double value; // value of the feature in that dimension
};
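借用网上的一张图进行表示(原图已缺失,大意如下:prob.y[i] 存放第 i 个样本的类别;prob.x[i] 指向第 i 个样本的 svm_node 数组,数组中依次存放该样本各个特征的 (index, value),最后以一个 index = -1 的节点作为结束标记):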
#include <cstdio>
#include <cstring>
#include <iostream>
#include <list>
#include <string>
#include <vector>
#include "svm.h"
class ClassificationSVM
{
public:
ClassificationSVM();
~ClassificationSVM();
void train(const std::string& modelFileName);
void predict(const std::string& featureaFileName, const std::string& modelFileName);
private:
void setParam();
void readTrainData(const std::string& featureFileName);
private:
svm_parameter param;
svm_problem prob;//all the data for train
std::list dataList;//list of features of all the samples
std::list typeList;//list of type of all the samples
int sampleNum;
//bool* judgeRight;
};
其中setParam函数用于设置SVM的训练参数,实现如下:
// Fills `param` with the training parameters: a C-SVC with an RBF kernel.
// Every field of svm_parameter is assigned explicitly so that nothing is left
// indeterminate when svm_check_parameter()/svm_train() read it.
void ClassificationSVM::setParam()
{
    param.svm_type = C_SVC;
    param.kernel_type = RBF;
    param.degree = 3;
    param.gamma = 0.5;
    param.coef0 = 0;
    param.nu = 0.5;
    param.cache_size = 40;   // kernel cache size (MB, per the libSVM README)
    param.C = 500;
    param.eps = 1e-3;
    param.p = 0.1;
    param.shrinking = 1;
    // Previously never assigned; svm_check_parameter() rejects any value
    // other than 0 or 1, and 0 means no probability model is trained.
    param.probability = 0;
    param.nr_weight = 0;     // no per-class weight entries
    param.weight = NULL;
    param.weight_label = NULL;
}
void ClassificationSVM::readTrainData(const string& featureFileName)
{
FILE *fp = fopen(featureFileName.c_str(), "r");
if (fp == NULL)
{
cout << "open feature file error!" << endl;
return;
}
fseek(fp, 0L, SEEK_END);
long end = ftell(fp);
fseek(fp, 0L, SEEK_SET);
long start = ftell(fp);
//读取文件,直到文件末尾
while (start != end)
{
//FEATUREDIM是自定义变量,表示特征的维度
svm_node* features = new svm_node[FEATUREDIM + 1];//因为需要结束标记,因此申请空间时特征维度+1
for (int k = 0; k < FEATUREDIM; k++)
{
double value = 0;
fscanf(fp, "%lf", &value);
features[k].index = k + 1;//特征标号,从1开始
features[k].value = value;//特征值
}
features[FEATUREDIM].index = -1;//结束标记
char c;
fscanf(fp, "\n", &c);
char name[100];
fgets(name, 100, fp);
name[strlen(name) - 1] = '\0';
//negative sample type is 0
int type = 0;
//positive sample type is 1
if (featureFileName == "PositiveFeatures.txt")
type = 1;
dataList.push_back(features);
typeList.push_back(type);
sampleNum++;
start = ftell(fp);
}
fclose(fp);
}
其中dataList和typeList分别存放每个样本的特征数组和该样本对应的类别标签(正样本为1,负样本为0)。
void ClassificationSVM::train(const string& modelFileName)
{
cout << "reading positivie features..." << endl;
readTrainData("PositiveFeatures.txt");
cout << "reading negative features..." << endl;
readTrainData("NegativeFeatures.txt");
cout << sampleNum << endl;
prob.l = sampleNum;//number of training samples
prob.x = new svm_node *[prob.l];//features of all the training samples
prob.y = new double[prob.l];//type of all the training samples
int index = 0;
while (!dataList.empty())
{
prob.x[index] = dataList.front();
prob.y[index] = typeList.front();
dataList.pop_front();
typeList.pop_front();
index++;
}
cout << "start training" << endl;
svm_model *svmModel = svm_train(&prob, ¶m);
cout << "save model" << endl;
svm_save_model(modelFileName.c_str(), svmModel);
cout << "done!" << endl;
}
prob是svm_problem类型的对象,train函数把之前读取的特征和类别标签全部放入这个对象中(prob.l为样本数,prob.x为各样本的特征,prob.y为各样本的类别),再交给svm_train进行训练。
void ClassificationSVM::predict(const string& featureFileName, const string& modelFileName)
{
std::vector judgeRight;
svm_model *svmModel = svm_load_model(modelFileName.c_str());
FILE *fp;
if ((fp = fopen(featureFileName.c_str(), "rt")) == NULL)
return;
fseek(fp, 0L, SEEK_END);
long end = ftell(fp);
fseek(fp, 0L, SEEK_SET);
long start = ftell(fp);
while (start != end)
{
svm_node* input = new svm_node[FEATUREDIM + 1];
for (int k = 0; k
分类时的代码与之前的代码非常类似,不多做赘述。