Orthogonal Matching Pursuit (OMP) with C++ Code

1. Introduction

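This post draws on two blog posts, http://blog.csdn.net/scucj/article/details/7467955 and http://blog.csdn.net/pi9nc/article/details/26593003. OMP (orthogonal matching pursuit) was proposed back in the 1990s and is one of the classic methods for decomposing a signal into a sparse representation over an overcomplete dictionary. Both posts analyze the theory thoroughly, so it is not repeated here.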
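Stated a little more precisely, in notation of my own choosing (signal y in R^m, dictionary D in R^(m x N) with N > m, coefficient vector x), the sparse representation problem that OMP attacks greedily takes one of the two forms below. The threshold ε plays the role of `min_residual` in the code later on and K the role of `sparsity`; the code stops at whichever limit is reached first.

```latex
\min_{x \in \mathbb{R}^{N}} \lVert x \rVert_0
\quad \text{s.t.} \quad \lVert y - D x \rVert_2 \le \varepsilon
\qquad \text{or} \qquad
\min_{x \in \mathbb{R}^{N}} \lVert y - D x \rVert_2
\quad \text{s.t.} \quad \lVert x \rVert_0 \le K
```

Both ℓ0 problems are NP-hard in general, so OMP does not solve them exactly; it builds the support one atom at a time and refits the coefficients after each addition.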

2. OMP Steps

Following the description in the first blog post, the OMP algorithm proceeds as follows:

[Figures 1 and 2: the OMP algorithm steps]
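For reference, the textbook OMP iteration can be written as below. The notation is my own and may differ from the figures: D = [d_1, ..., d_N] is the dictionary, y the signal, K the sparsity limit (`sparsity` in the code), ε the residual threshold (`min_residual`), Λ_t the set of selected indices and Φ_t the matrix whose columns are the selected atoms (`phi`). The closed form in step 3 assumes Φ_t has full column rank.

```latex
\begin{aligned}
&\text{Init: } r_0 = y,\quad \Lambda_0 = \emptyset,\quad t = 1\\
&\text{1. } \lambda_t = \arg\max_{j}\ \lvert\langle d_j, r_{t-1}\rangle\rvert\\
&\text{2. } \Lambda_t = \Lambda_{t-1}\cup\{\lambda_t\},\qquad \Phi_t = [\,d_j\,]_{j\in\Lambda_t}\\
&\text{3. } x_t = \arg\min_{x} \lVert y - \Phi_t x\rVert_2 = (\Phi_t^{\top}\Phi_t)^{-1}\Phi_t^{\top} y\\
&\text{4. } r_t = y - \Phi_t x_t\\
&\text{5. Stop if } t \ge K \text{ or } \lVert r_t\rVert_2 \le \varepsilon;\ \text{otherwise } t \leftarrow t+1 \text{ and return to step 1.}
\end{aligned}
```

Because x_t is a least-squares fit, r_t is orthogonal to every column of Φ_t, which is the "orthogonal" in OMP and the reason an atom is (in exact arithmetic) never selected twice.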

3. Code
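The code below avoids recomputing Φ^TΦ from scratch at each iteration: when a new atom d is appended to the previously selected atoms Φ_{t-1}, the Gram matrix G_t = Φ_t^T Φ_t grows by one row and one column. In this notation (my own: Φ_{t-1} is what the code holds in `phi` before the update, d is `matched_vec`, G_t is `spd`), the update is the standard block identity:

```latex
G_t = \Phi_t^{\top}\Phi_t
    = \begin{bmatrix} \Phi_{t-1}^{\top} \\ d^{\top} \end{bmatrix}
      \begin{bmatrix} \Phi_{t-1} & d \end{bmatrix}
    = \begin{bmatrix} G_{t-1} & \Phi_{t-1}^{\top} d \\ d^{\top}\Phi_{t-1} & d^{\top} d \end{bmatrix},
\qquad d^{\top} d = \lVert d \rVert_2^2 .
```

Only the new row v = d^T Φ_{t-1} (the code's `v`), its transpose, and the diagonal entry ‖d‖² have to be computed. G_t stays symmetric positive definite as long as the selected atoms are linearly independent, which is what makes the Cholesky-based inverse in the code applicable.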



#include <opencv2/core.hpp>
#include <cmath>
#include <vector>

using namespace cv;
using namespace std;

// Member function of a class Test declared elsewhere.
// Assumes: target and every atom in m_patches are CV_32FC1 column vectors
// of the same length, and atoms are (ideally) normalized to unit L2 norm.
void Test::OrthMatchPursuit(
	Mat& target,
	float min_residual,
	int sparsity,
	//Store original dictionary
	vector<Mat>& m_patches,
	//Store matched patches' coefficients
	vector<float>& coefficients,
	//Store matched patches
	vector<Mat>& matched_patches,
	//Store indices of matched patches
	vector<int>& matched_indices
	)
{
	if (m_patches.empty() || sparsity <= 0)
		return;

	Mat residual = target.clone();

	//The atoms' set, one selected atom per column
	int m_vec_dims = target.rows > target.cols ? target.rows : target.cols;
	Mat ori_phi = Mat::zeros(m_vec_dims, sparsity, CV_32FC1);
	Mat phi;

	//phi.t()*phi, which is a symmetric positive-definite (SPD) matrix
	Mat ori_spd = Mat::ones(sparsity, sparsity, CV_32FC1);
	Mat spd = ori_spd(Rect(0, 0, 1, 1));

	//Reserve enough memory.
	matched_patches.reserve(sparsity);
	matched_indices.reserve(sparsity);

	float max_coefficient;
	int matched_index;
	vector<Mat>::iterator matched_patch_it;

	for (int spars = 1;; spars++)
	{
		max_coefficient = 0;
		matched_index = 0;
		//Initialize so the iterator is valid even if every correlation is 0
		matched_patch_it = m_patches.begin();
		int current_index = 0;

		for (vector<Mat>::iterator patch_it = m_patches.begin();
			patch_it != m_patches.end();
			++patch_it
			)
		{
			Mat& cur_vec = *patch_it;
			float coefficient = (float)cur_vec.dot(residual);

			//Find the atom with the maximum absolute correlation
			if (std::fabs(coefficient) > std::fabs(max_coefficient))
			{
				max_coefficient = coefficient;
				matched_patch_it = patch_it;
				matched_index = current_index;
			}
			current_index++;
		}
		matched_patches.push_back((*matched_patch_it));
		matched_indices.push_back(matched_index);

		Mat& matched_vec = (*matched_patch_it);

		//Update the SPD matrix by simply appending a single row and column to it.
		if (spars > 1)
		{
			//1 x (spars-1): inner products of the new atom with the earlier atoms
			Mat v = matched_vec.t()*phi;
			float c = (float)norm(matched_vec);

			v.copyTo(ori_spd(Rect(0, spars - 1, spars - 1, 1)));
			((Mat)v.t()).copyTo(ori_spd(Rect(spars - 1, 0, 1, spars - 1)));
			//ptr() returns uchar*, so write the float diagonal entry with at<float>()
			ori_spd.at<float>(spars - 1, spars - 1) = c*c;

			spd = ori_spd(Rect(0, 0, spars, spars));
		}

		//Add the new matched patch to the vectors' set.
		phi = ori_phi(Rect(0, 0, spars, m_vec_dims));
		matched_vec.copyTo(phi.col(spars - 1));

		Mat temp, x;
		if (spars == 1){
			//A single atom: ordinary least squares
			phi.convertTo(phi, target.type());
			solve(phi, target, x, DECOMP_SVD);
			//Seed the SPD matrix with phi^T phi = ||d||^2
			float c0 = (float)norm(matched_vec);
			ori_spd.at<float>(0, 0) = c0*c0;
		}else{
			//spd is SPD, so use Cholesky: x = (phi^T phi)^-1 phi^T target
			temp = spd.inv(DECOMP_CHOLESKY)*phi.t();
			temp.convertTo(temp, target.type());
			x = temp*target;
		}

		phi.convertTo(phi, x.type());
		residual = target - phi*x;

		float res_norm = (float)norm(residual);
		if (spars >= sparsity || res_norm <= min_residual)
		{
			//x is a spars x 1 column; copy it out as floats
			Mat x32;
			x.convertTo(x32, CV_32F);
			coefficients.assign(x32.begin<float>(), x32.end<float>());

			return;
		}
	}
}
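To experiment without OpenCV, here is a minimal, dependency-free sketch of the same greedy loop under these assumptions: the dictionary is a `std::vector` of equal-length atoms (ideally unit norm), and `omp`, `OmpResult`, `choleskySolve` and `Vec` are illustrative names, not part of any library. The least-squares step solves the normal equations (Φ^TΦ)x = Φ^T y with a hand-written Cholesky factorization, the same system the OpenCV code solves through `spd.inv(DECOMP_CHOLESKY)`. Unlike the code above, it rebuilds the Gram matrix every iteration instead of appending a row and column, and it explicitly skips atoms that were already selected.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;

// Inner product of two equal-length vectors.
static double dot(const Vec& a, const Vec& b) {
    double s = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) s += a[i] * b[i];
    return s;
}

// Solve G x = b for symmetric positive-definite G via G = L L^T.
// The lower triangle of the local copy of G is overwritten with L.
static Vec choleskySolve(std::vector<Vec> G, Vec b) {
    const std::size_t n = b.size();
    for (std::size_t j = 0; j < n; ++j) {
        double d = G[j][j];
        for (std::size_t k = 0; k < j; ++k) d -= G[j][k] * G[j][k];
        G[j][j] = std::sqrt(d);
        for (std::size_t i = j + 1; i < n; ++i) {
            double s = G[i][j];
            for (std::size_t k = 0; k < j; ++k) s -= G[i][k] * G[j][k];
            G[i][j] = s / G[j][j];
        }
    }
    for (std::size_t i = 0; i < n; ++i) {        // forward: L y = b
        for (std::size_t k = 0; k < i; ++k) b[i] -= G[i][k] * b[k];
        b[i] /= G[i][i];
    }
    for (std::size_t i = n; i-- > 0;) {          // backward: L^T x = y
        for (std::size_t k = i + 1; k < n; ++k) b[i] -= G[k][i] * b[k];
        b[i] /= G[i][i];
    }
    return b;
}

struct OmpResult {
    std::vector<std::size_t> indices;  // selected atom indices, in order
    Vec coefficients;                  // one coefficient per selected atom
    double residualNorm;               // ||y - Phi x||_2 on exit
};

// Greedy OMP: stops after `sparsity` atoms or once ||r|| <= minResidual.
OmpResult omp(const std::vector<Vec>& atoms, const Vec& y,
              std::size_t sparsity, double minResidual) {
    OmpResult res;
    std::vector<bool> used(atoms.size(), false);
    Vec r = y;  // the residual starts as the signal itself
    res.residualNorm = std::sqrt(dot(r, r));
    while (res.indices.size() < sparsity && res.residualNorm > minResidual) {
        // 1. Pick the unused atom with the largest |<d_j, r>|.
        std::size_t best = 0;
        double bestAbs = -1.0;
        for (std::size_t j = 0; j < atoms.size(); ++j) {
            if (used[j]) continue;
            const double c = std::fabs(dot(atoms[j], r));
            if (c > bestAbs) { bestAbs = c; best = j; }
        }
        if (bestAbs < 0.0) break;  // every atom is already selected
        used[best] = true;
        res.indices.push_back(best);

        // 2. Refit all coefficients: (Phi^T Phi) x = Phi^T y.
        const std::size_t k = res.indices.size();
        std::vector<Vec> G(k, Vec(k));
        Vec b(k);
        for (std::size_t p = 0; p < k; ++p) {
            const Vec& dp = atoms[res.indices[p]];
            b[p] = dot(dp, y);
            for (std::size_t q = 0; q < k; ++q)
                G[p][q] = dot(dp, atoms[res.indices[q]]);
        }
        res.coefficients = choleskySolve(G, b);

        // 3. New residual r = y - Phi x (orthogonal to the support, up to rounding).
        r = y;
        for (std::size_t p = 0; p < k; ++p) {
            const Vec& dp = atoms[res.indices[p]];
            for (std::size_t i = 0; i < r.size(); ++i)
                r[i] -= res.coefficients[p] * dp[i];
        }
        res.residualNorm = std::sqrt(dot(r, r));
    }
    return res;
}
```

For example, with the atoms e1, e2, e3 of R^3 plus (1/√2, 1/√2, 0) and y = (3, 0, 1), the loop picks index 0 and then index 2 with coefficients 3 and 1, and the residual reaches zero after two atoms. Refitting from scratch costs more per iteration than the incremental update used in the OpenCV code, but it keeps the sketch short and easy to check.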

