OpenCv-C++-03-高斯混合模型GMM

1、与KMeans相比,属于软分类;
2、实现方法--------期望最大化(E-M);
3、停止条件--------收敛或规定次数达到。

OpenCV代码实现:

#include
#include

using namespace cv;
using namespace cv::ml;
using namespace std;

int main(int argc, char** argv)
{
	Mat img = Mat::zeros(500, 500, CV_8UC3);  //定义一张500*500的8位3通道图片

	RNG rng(12345);
	Scalar colorTab[] = {    //定义颜色数组
		Scalar(0,0,255),
		Scalar(255,0,255),
		Scalar(0,255,0),
		Scalar(255,0,0),
		Scalar(0,255,255),
	};
	int numCluster = rng.uniform(2, 5);  //4个分类
	printf("number of cluster:%d\n", numCluster);  //打印4个分类
	cout << "----------------------------------------" << endl;
	int sampleCount = rng.uniform(1, 1000); //1000个随机点
	Mat points(sampleCount, 2, CV_32FC1);

	Mat labels;
	Mat centers;

	//产生多高斯部分的随机采样点
	for (int k = 0; k < numCluster; k++)
	{
		Point center;
		center.x = rng.uniform(0, img.cols);
		center.y = rng.uniform(0, img.rows);

		cout << "x=" << center.x << "  " << "y=" << center.y << endl;
		Mat pointChunk = points.rowRange(k*sampleCount / numCluster,
			k == numCluster - 1 ?
			sampleCount : (k + 1)*sampleCount / numCluster); //定义随机散点
		//rng.fill函数,会以center点为中心,产生高斯分布的随机点(位置点),并把位置点保存在矩阵pointChunk中。
		rng.fill(pointChunk, RNG::NORMAL, Scalar(center.x, center.y), Scalar(img.cols*0.05, img.rows*0.05));
	}
	randShuffle(points, 1, &rng);//打乱points中值,第二个参数表示随机交换元素的数量的缩放因子,总的交换次数dst.rows*dst.cols*iterFactor,第三个参数是个随机发生器,决定选哪两个元素交换。
	Ptrem_model = EM::create();
	em_model->setClustersNumber(numCluster);
	em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);
	em_model->setTermCriteria(TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 100, 0.1));
	em_model->trainEM(points, noArray(), labels, noArray());

	//--------------------------
	//分类
	Mat sample(1, 2, CV_32FC1);
	for (int row = 0; row < img.rows; row++)
	{
		for (int col = 0; col < img.cols; col++)
		{
			sample.at(0) = (float)col;
			sample.at(1) = (float)row;
			int response = cvRound(em_model->predict2(sample, noArray())[1]); //cvRound 返回跟参数最接近的整数值
			Scalar c = colorTab[response];
			circle(img, Point(col, row), 1, c*0.75, -1);
		}

	}
	//绘制分类的点
	for (int i = 0; i < sampleCount; i++)
	{
		Point p(cvRound(points.at(i, 0)), cvRound(points.at(i, 1)));
		circle(img, p, 1, colorTab[labels.at(i)], -1);
	}
	imshow("final result", img);
	waitKey(0);
	return 0;
}


运行结果:
OpenCv-C++-03-高斯混合模型GMM_第1张图片

基于图像分割的GMM:

#include
#include

using namespace cv;
using namespace cv::ml;
using namespace std;

int main(int argc, char* argv)
{
	Mat src = imread("F:/test/toux.jpg",1);
	if (src.empty())
	{
		cout << "图像文件未找到!" << endl;
		return -1;
	}
	imshow("input image", src);
	const Scalar colors[] = {
	Scalar(255, 0, 0),
	Scalar(0, 255, 0),
	Scalar(0, 0, 255),
	Scalar(255, 255, 0)
	};
	int clustersNumber = 3;
	int width = src.cols;
	int height = src.rows;
	int channel = src.channels();
	int all_samples = width * height; //所有的样本点总数
	Mat points(all_samples,channel, CV_64FC1);
	Mat labels;
	Mat result = Mat::zeros(src.size(), CV_8UC3);
	// 图像RGB像素数据转换为样本数据 
	int index = 0;
	for (int row = 0; row < src.rows; row++)
	{
		for (int col = 0; col < src.cols; col++)
		{
			index = row * width + col;
			Vec3b bgr = src.at(row, col);
			points.at(index, 0) = bgr[0];
			points.at(index, 1) = bgr[1];
			points.at(index, 2) = bgr[2];
		}
	}
	//EM 训练
	Ptrem_model = EM::create();
	em_model->setClustersNumber(clustersNumber);
	em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);
	em_model->setTermCriteria(TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 100, 0.1));
	em_model->trainEM(points, noArray(), labels, noArray());

	// 对每个像素标记颜色与显示
	Mat sample(channel,1,CV_64FC1);
	int r=0, g=0, b = 0;
	for (int row = 0; row < src.rows; row++)
	{
		for (int col = 0; col < src.cols; col++)
		{
			index = row * width + col;
			b = src.at(row, col)[0];
			g = src.at(row, col)[1];
			r = src.at(row, col)[2];
			sample.at(0) = b;
			sample.at(1) = g;
			sample.at(2) = r;
			int response = cvRound(em_model->predict2(sample, noArray())[1]);

			Scalar c = colors[response];
			result.at(row, col)[0] = c[0];
			result.at(row, col)[1] = c[1];
			result.at(row, col)[2] = c[2];
		
		}
	
	}
	imshow("GMM Result",result);
	waitKey(0);
	return 0;
}

运行结果:
OpenCv-C++-03-高斯混合模型GMM_第2张图片

你可能感兴趣的:(OpenCv-C++学习记录)