机器学习之利用 SVM（支持向量机）分类（OpenCV 3）

SVM 分类接口在 OpenCV 3 中有了很大的变动：取消了 CvSVMParams 类，参数不再通过参数结构体传入，而是在 cv::ml::SVM::create() 创建模型后，通过 setType()、setKernel()、setTermCriteria() 等 setter 方法逐项设置。

OpenCV 中的 SVM 分类实现源自 libsvm。

int main(int argc, char** argv)
{
	// visual representation
	int width = 512;
	int height = 512;
	cv::Mat image = cv::Mat::zeros(height, width, CV_8UC3);

	// training data
	int labels[4] = { 1, -1, -1, -1 };
	float trainingData[4][2] = { { 501, 10 },{ 255, 10 },{ 501, 255 },{ 10, 501 } };
	cv::Mat trainingDataMat(4, 2, CV_32FC1, trainingData);
	cv::Mat labelsMat(4, 1, CV_32SC1, labels);

	// initial SVM
	cv::Ptr svm = cv::ml::SVM::create();
	svm->setType(cv::ml::SVM::Types::C_SVC);
	svm->setKernel(cv::ml::SVM::KernelTypes::LINEAR);
	svm->setTermCriteria(cv::TermCriteria(cv::TermCriteria::MAX_ITER, 100, 1e-6));

	// train operation
	svm->train(trainingDataMat, cv::ml::SampleTypes::ROW_SAMPLE, labelsMat);

	// prediction
	cv::Vec3b green(0, 255, 0);
	cv::Vec3b blue(255, 0, 0);
	for (int i = 0; i < image.rows; i++)
	{
		for (int j = 0; j < image.cols; j++)
		{
			cv::Mat sampleMat = (cv::Mat_(1, 2) << j, i);
			float respose = svm->predict(sampleMat);
			if (respose == 1)
				image.at(i, j) = green;
			else if (respose == -1)
				image.at(i, j) = blue;
		}
	}

	int thickness = -1;
	int lineType = cv::LineTypes::LINE_8;

	cv::circle(image, cv::Point(501, 10), 5, cv::Scalar(0, 0, 0), thickness, lineType);
	cv::circle(image, cv::Point(255, 10), 5, cv::Scalar(255, 255, 255), thickness, lineType);
	cv::circle(image, cv::Point(501, 255), 5, cv::Scalar(255, 255, 255), thickness, lineType);
	cv::circle(image, cv::Point(10, 501), 5, cv::Scalar(255, 255, 255), thickness, lineType);

	thickness = 2;
	lineType = cv::LineTypes::LINE_8;

	cv::Mat sv = svm->getSupportVectors();
	for (int i = 0; i < sv.rows; i++)
	{
		const float* v = sv.ptr(i);
		cv::circle(image, cv::Point((int)v[0], (int)v[1]), 6, cv::Scalar(128, 128, 128), thickness, lineType);
	}


	cv::imshow("SVM Simple Example", image);


	cv::waitKey(0);
	return 0;
}

图 1：运行结果。绿色区域为预测为 1 的区域，蓝色区域为预测为 -1 的区域；黑色实心圆为标签 1 的训练样本，白色实心圆为标签 -1 的训练样本；灰色圆圈标出支持向量。

你可能感兴趣的:(机器学习)