OpenCV3.x实现KNN算法(K近邻算法),并保存训练模型

OpenCV3.x实现KNN算法(K近邻算法),并保存训练模型

     【尊重原创,转载请注明出处】https://blog.csdn.net/guyuealian/article/details/80241473

   OpenCV 3.x中cv::ml::Knearest类可以实现K-最近邻(KNN)算法,其详细用法可以参考官方说明文档:https://docs.opencv.org/3.2.0/dd/de1/classcv_1_1ml_1_1KNearest.html

(1)、cv::ml::Knearest类:继承自cv::ml::StateModel,而cv::ml::StateModel又继承自cv::Algorithm;
(2)、create函数:为static,new一个KNearestImpl用来创建一个KNearest对象;
(3)、setDefaultK/getDefaultK函数:在预测时,设置/获取的K值;
(4)、setIsClassifier/getIsClassifier函数:设置/获取应用KNN是进行分类还是回归;
(5)、setEmax/getEmax函数:在使用KDTree算法时,设置/获取Emax参数值;
(6)、setAlgorithmType/getAlgorithmType函数:设置/获取KNN算法类型,目前支持两种:brute_force和KDTree;
(7)、findNearest函数:根据输入预测分类/回归结果。

关于KNN算法思路可以参考: http://blog.csdn.net/fengbingchun/article/details/78464169  

      下面是本人使用OpenCV3.2实现的KNN算法,其中利用save()方法把训练数据保存下来,测试时重新加载load()训练模型,这样可以实现单独的测试,而不需要重新训练数据:

#include "stdafx.h"  
#include     
#include     
#include     
#include   
using namespace cv;
using namespace cv::ml;
using namespace std;
int main()

{
	float labels[10] = { 0.0, 1.0, 1.0, 2.0,2.0,0.0, 1.0,1.0, 2.0,2.0 };
	Mat labelsMat(10, 1, CV_32FC1, labels);
	// Set up training data  
	float trainArray[10][3] = { { 510, 510,10 },{ 405, 10,510 },{ 501, 45,420 },{ 10,20, 510 },{ 35,45,515 },{ 540,420,40 },{ 380,30,300 },{ 400,70,500 },{ 30,60,410 },{ 54,23,543 } };
	Mat trainDataMat(10, 3, CV_32FC1, trainArray);


	/*******************************************训练过程******************************************/
	//保存训练模型(在KNN中实质上是保存训练样本的原始数据)  
	string knnPath = "D:/SmartAlbum/image1/knn.xml";
	Ptr kclassifier = KNearest::create();
	Ptr trainData;
	trainData = TrainData::create(trainDataMat, SampleTypes::ROW_SAMPLE, labelsMat);
	kclassifier->setIsClassifier(true);
	kclassifier->setAlgorithmType(KNearest::Types::BRUTE_FORCE);
	kclassifier->setDefaultK(1);
	kclassifier->train(trainData);
	kclassifier->save(knnPath);//会把trainDataMat的原始数据全部保存为*.xml文件  

	/*******************************************测试过程******************************************/
        //加载训练模型(在KNN中,实质上就是加载训练样本的原始数据)  
	const int K = 4;//testModel->getDefaultK()  
	Ptr testModel = StatModel::load(knnPath);
	Mat sampleMat = (Mat_(1, 3) << 310, 5, 339);//测试样本  
	Mat matResults(0, 0, CV_32F);//保存测试结果  
	testModel->findNearest(sampleMat, K, matResults);//knn分类预测  
	cout << "matResults=" << matResults << endl;


	system("pause");
	waitKey();
}

你可能感兴趣的:(机器学习,OpenCV)