支持向量机(SVM)中最核心的是4个字——“支持向量”,一旦在两类或多累样本集中定位到某些特定的点作为支持向量,就可以依据这些支持向量计算出来分类超平面,再依据超平面对类别进行归类划分就是水到渠成的事了。有必要回顾一下什么是支持向量机中的支持向量。 上图中需要对红色和蓝色的两类训练样本进行区分,实现绿线是决策面(超平面),最靠近决策面的2个实心红色样本和1个实心蓝色样本分别是两类训练样本的支持向量,决策面所在的位置是使得两类支持向量与决策面之间的间隔都达到最大时决策面所处的位置。
一般情况下,训练样本都会存在噪声,这就导致其中一类样本的一个或多个样本跑到了决策面的另一边,掺杂到另一类样本中。针对这种情况,SVM加入了松弛变量(惩罚变量)来应对,确保这些噪声样本不会被作为支持向量,而不管它们离超平面的距离有多近。包括SVM中的另一个重要概念“核函数”,也是为训练样本支持向量的确定提供支持的。 在OpenCV中,SVM的训练、归类流程如下: 1. 获取训练样本 SVM是一种有监督的学习分类方法,所以对于给出的训练样本,要明确每个样本的归类是0还是1,即每个样本都需要标注一个确切的类别标签,提供给SVM训练使用。对于样本的特征,以及特征的维度,SVM并没有限定,可以使用如Haar、角点、Sift、Surf、直方图等各种特征作为训练样本的表述参与SVM的训练。Opencv要求训练数据存储在float类型的Mat结构中。
1. SVM的核心思想
SVM的分类思想本质上和线性回归LR分类方法类似,就是求出一组权重系数,在线性表示之后可以分类。我们先使用一组trainging set来训练SVM中的权重系数,然后可以对testingset进行分类。
说的更加更大上一些:SVM就是先训练出一个分割超平面separation hyperplane, 然后该平面就是分类的决策边界,分在平面两边的就是两类。显然,经典的SVM算法只适用于两类分类问题,当然,经过改进之后,SVM也可以适用于多类分类问题。
我们希望找到离分隔超平面最近的点,确保它们离分隔面的距离尽可能远。这里点到分隔面的距离被称为间隔margin. 我们希望这个margin尽可能的大。支持向量support vector就是离分隔超平面最近的那些点,我们要最大化支持向量到分隔面的距离。
那么为了达到上面的目的,我们就要解决这样的一个问题:如何计算一个点到分隔面的距离?这里我们可以借鉴几何学中点到直线的距离,需要变动的是我们这里是点到超平面的距离。具体转换过程如下:
代码:
#include
#include
#include
#include
#include
#include
#include
using namespace std;
using namespace cv;
using namespace ml;
cv::Mat trainData;
cv::Mat trainLabel;
void get_data1();
void get_data2();
int main() {
get_data1();
get_data2();
trainData.convertTo(trainData, CV_32FC1);
trainLabel.convertTo(trainLabel, CV_32SC1);
//vector test_set;
//get_test(io, test_set);
Ptr model = SVM::create();
model->setType(SVM::C_SVC); //SVM类型
model->setKernel(SVM::LINEAR); //核函数,这里使用线性核
model->setTermCriteria(cv::TermCriteria(cv::TermCriteria::MAX_ITER, 100, 1e-6));
//model->train(trainData, ROW_SAMPLE, trainLabel);//只训练不生成xml
Ptr tData = TrainData::create(trainData, ROW_SAMPLE, trainLabel);
//训练生成xml
std::cout << "SVM: start train ..." << endl;
model->train(tData);
std::cout << "SVM: train success ..." << endl;
model->save("svm.xml");
waitKey();
getchar();
return 0;
}
void get_data1() {
ifstream fin("alist.txt", ios::in);
string s;
while (getline(fin, s)) {
if (s.length() != 0) {
string ss;
ss = "C:\\Users\\wangz\\Desktop\\\apple\\" + s;
Mat m2 = imread(ss, 0);
if (m2.empty()) {
cout << "fail" << endl;
}
Mat temp;
resize(m2, temp, Size(256, 256));
//imshow("12", m2);
trainData.push_back(temp.reshape(0, 1)); //连续放入Mat容器中
trainLabel.push_back(1);
//trainLabel.push_back(Mat(1, 1, CV_32SC1, 0)); //trainData.convertTo(trainData, CV_32FC1);
}
}
}
void get_data2() {
ifstream fin("mlist.txt", ios::in);
string s;
//Mat mm(29, 256 * 256, CV_8UC1);
while (getline(fin, s)) {
string ss;
if (s.length() != 0) {
ss = "C:\\Users\\wangz\\Desktop\\svm\\microsoft\\" + s;
Mat m2 = imread(ss, 0);
if (m2.empty()) {
cout << "fail" << endl;
}
Mat temp;
resize(m2, temp, Size(256, 256));
//imshow("12", m2);
trainData.push_back(temp.reshape(0, 1)); //连续放入Mat容器中
trainLabel.push_back(-1);
//trainLabel.push_back(Mat(1, 1, CV_32SC1, 1)); //trainData.convertTo(trainData, CV_32FC1);
}
}
}
/*
ifstream fin1(path1, ios::in);
string s1;
while (getline(fin1, s1)) {
if (s1.length() != 0) {
string slable;
slable = s1.substr(s1.find(" "));
int index = atoi(slable.c_str()); //slable - '0';
cout << index << endl;
trainLabels.push_back(Mat(1, 1, CV_32SC1, &index));
}
}
}
*/
//样本数据必须是CV_32FC1类型。opencv3版本决定的
//样本标签必须是CV_32SC1,opencv3后从int数组转换为CV_32SC1类型,而opencv2是从float数据转换。
代码二:
#include
#include
#include
#include
#include
#include
#include
using namespace std;
using namespace cv;
using namespace ml;
void get_data1(Mat& trainimage,vector& trainlab);
void get_data2(Mat& trainimage,vector& trainlab);
int main() {
Mat classes;
Mat traindata, trainimg;
vectortrainlables;
get_data1(trainimg,trainlables);
get_data2(trainimg, trainlables);
Mat(trainimg).copyTo(traindata);
traindata.convertTo(traindata, CV_32FC1);
Mat(trainlables).copyTo(classes);
classes.convertTo(classes, CV_32SC1);
//vector test_set;
//get_test(io, test_set);
Ptr model = SVM::create();
model->setType(SVM::C_SVC); //SVM类型
model->setKernel(SVM::LINEAR); //核函数,这里使用线性核
model->setTermCriteria(cv::TermCriteria(cv::TermCriteria::MAX_ITER, 100, 1e-6));
//model->train(trainData, ROW_SAMPLE, trainLabel);//只训练不生成xml
Ptr tData = TrainData::create(traindata, ROW_SAMPLE, classes);
//训练生成xml
std::cout << "SVM: start train ..." << endl;
model->train(tData);
std::cout << "SVM: train success ..." << endl;
Mat tset=imread("123.jpg",0);
resize(tset, tset, Size(256, 256));
tset.reshape(0, 1);
double k= model->predict(tset);
cout << k;
//model->save("svm.xml");
//
waitKey();
getchar();
return 0;
}
void get_data1(Mat& trainimage, vector& trainlab) {
ifstream fin("alist.txt", ios::in);
string s;
while (getline(fin, s)) {
if (s.length() != 0) {
string ss;
ss = "C:\\Users\\wangz\\Desktop\\\apple\\" + s;
Mat m2 = imread(ss, 0);
if (m2.empty()) {
cout << "fail" << endl;
}
Mat temp;
resize(m2, temp, Size(256, 256));
//imshow("12", m2);
trainimage.push_back(temp.reshape(0, 1)); //连续放入Mat容器中
trainlab.push_back(1);
//trainLabel.push_back(Mat(1, 1, CV_32SC1, 0)); //trainData.convertTo(trainData, CV_32FC1);
}
}
cout << trainimage.size() << endl;
cout << trainlab.size() << endl;
}
void get_data2(Mat& trainimage,vector& trainlab) {
ifstream fin("mlist.txt", ios::in);
string s;
//Mat mm(29, 256 * 256, CV_8UC1);
while (getline(fin, s)) {
string ss;
if (s.length() != 0) {
ss = "C:\\Users\\wangz\\Desktop\\svm\\microsoft\\" + s;
Mat m2 = imread(ss, 0);
if (m2.empty()) {
cout << "fail" << endl;
}
Mat temp;
resize(m2, temp, Size(256, 256));
//imshow("12", m2);
trainimage.push_back(temp.reshape(0, 1)); //连续放入Mat容器中
trainlab.push_back(-1);
//trainLabel.push_back(Mat(1, 1, CV_32SC1, 1)); //trainData.convertTo(trainData, CV_32FC1);
}
}
cout << trainimage.size() << endl;
cout << trainlab.size() << endl;
}
/*
ifstream fin1(path1, ios::in);
string s1;
while (getline(fin1, s1)) {
if (s1.length() != 0) {
string slable;
slable = s1.substr(s1.find(" "));
int index = atoi(slable.c_str()); //slable - '0';
cout << index << endl;
trainLabels.push_back(Mat(1, 1, CV_32SC1, &index));
}
}
}
*/
//样本数据必须是CV_32FC1类型。opencv3版本决定的
//样本标签必须是CV_32SC1,opencv3后从int数组转换为CV_32SC1类型,而opencv2是从float数据转换。
第一次用opencv训练svm出问题是很正常的如下:
百度解决方法如下:
http://www.it1352.com/481183.html
https://blog.csdn.net/galileoyuyu/article/details/82083673
参考博客:https://www.cnblogs.com/br170525/p/9236479.html
https://blog.csdn.net/mao_hui_fei/article/details/80455538
https://blog.csdn.net/im6520/article/details/75240435
https://blog.csdn.net/a1111h/article/details/72568970
https://blog.csdn.net/sinat_34474705/article/details/80502789?utm_source=blogxgwz0
https://blog.csdn.net/xuan_zizizi/article/details/71102018
hog训练的:https://blog.csdn.net/u013419097/article/details/80253977