OpenCV3 SVM训练与预测程序

#include 
#include 
#include "opencv2/imgcodecs.hpp"
#include 
#include 
#include 

using namespace cv;
using namespace cv::ml;

/*训练部分*/
void train()
{
	// Set up training data
	int labels[4] = { 1, -1, -1, -1 };
	Mat labelsMat(4, 1, CV_32SC1, labels);

	float trainingData[4][2] = { { 501, 10 },{ 255, 10 },{ 501, 255 },{ 10, 501 } };
	Mat trainingDataMat(4, 2, CV_32FC1, trainingData);

	// Set up SVM's parameters
	Ptr svm = cv::ml::SVM::create();
	svm->setType(cv::ml::SVM::C_SVC);
	svm->setKernel(cv::ml::SVM::RBF);
	svm->setTermCriteria(cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6));
	svm->setGamma(0.01);
	svm->setC(800);  //经验系数
	svm->setP(0.1);

	std::cout << "C为:" << svm->getC() << std::endl;

	// Train the SVM
	svm->train(trainingDataMat, cv::ml::ROW_SAMPLE, labelsMat);

	svm->save("f:\\train_model.xml"); //保存模型
}

/*预测部分*/
void predict()
{
	String path = "f:\\train_model.xml";
	FileStorage svm_fs(path, FileStorage::READ); //读取文件
	if (svm_fs.isOpened())
	{
		//testdata
		int labels[4] = { 1, -1, -1, -1 };
		Mat labelsMat(4, 1, CV_32SC1, labels);

		float testingData[4][2] = { { 501, 10 },{ 255, 10 },{ 501, 255 },{ 10, 501 } };
		Mat testingDataMat(4, 2, CV_32FC1, testingData);

		//Ptr svm = ml::SVM::create();
		//svm->load(path.c_str());	//从文件加载,这样是不对的—_—

		Ptr svm = ml::SVM::load(path.c_str()); 	//从文件加载

		std::cout << "C为:" << svm->getC() << std::endl;  //读取一个参数检测是否加载成功
		
		Mat result;
		for (int i = 0; i < 4; i++)
		{
			Mat sample = testingDataMat.row(i);
			float result = svm->predict(sample);
			std::cout << "结果为:" << result << std::endl;

		}
	}
}

int main(int, char**)
{
	
	train();
	
	predict();
}

       由于刚接触之前调试遇到了加载模型文件错误的问题,这段使用OpenCV3的代码是根据OpenCV2的例程改写的,遇到了模型文件加载错误的问题,具体表现为load以后svm并没有得到xml文件中的参数值,并且在predict时会报错。

       最后发现直接load就可以不需要新建一个svm对象,load本身就可以创建svm了

 /** @brief Loads and creates a serialized svm from a file
     *
     * Use SVM::save to serialize and store an SVM to disk.
     * Load the SVM from this file again, by calling this function with the path to the file.
     *
     * @param filepath path to serialized svm
     */
    CV_WRAP static Ptr load(const String& filepath);

是真滴菜。。。

你可能感兴趣的:(SVM,OpenCV)