OpenCV的GrabCut函数使用和源码解读

图像分割之(四)OpenCV的GrabCut函数使用和源码解读

[email protected]

http://blog.csdn.net/zouxy09

 

      上一文对GrabCut做了一个了解。OpenCV中的GrabCut算法是依据《"GrabCut" - Interactive Foreground Extraction using Iterated Graph Cuts》这篇文章来实现的。现在我对源码做了些注释,以便我们更深入的了解该算法。一直觉得论文和代码是有比较大的差别的,个人觉得脱离代码看论文,最多能看懂70%,剩下20%或者更多就需要通过阅读代码来获得了,那还有10%就和每个人的基础和知识储备相挂钩了。

      接触时间有限,若有错误,还望各位前辈指正,谢谢。原论文的一些浅解见上一博文:

          http://blog.csdn.net/zouxy09/article/details/8534954

 

一、GrabCut函数使用

      在OpenCV的源码目录的samples的文件夹下,有grabCut的使用例程,请参考:

opencv\samples\cpp\grabcut.cpp

grabCut函数的API说明如下:

void cv::grabCut( InputArray _img, InputOutputArray _mask, Rect rect,

                  InputOutputArray _bgdModel, InputOutputArray _fgdModel,

                  int iterCount, int mode )

/*

****参数说明:

         img——待分割的源图像,必须是83通道(CV_8UC3)图像,在处理的过程中不会被修改;

         mask——掩码图像,如果使用掩码进行初始化,那么mask保存初始化掩码信息;在执行分割的时候,也可以将用户交互所设定的前景与背景保存到mask中,然后再传入grabCut函数;在处理结束之后,mask中会保存结果。mask只能取以下四种值:

                   GCD_BGD=0),背景;

                   GCD_FGD=1),前景;

                   GCD_PR_BGD=2),可能的背景;

                   GCD_PR_FGD=3),可能的前景。

                   如果没有手工标记GCD_BGD或者GCD_FGD,那么结果只会有GCD_PR_BGDGCD_PR_FGD

         rect——用于限定需要进行分割的图像范围,只有该矩形窗口内的图像部分才被处理;

         bgdModel——背景模型,如果为null,函数内部会自动创建一个bgdModelbgdModel必须是单通道浮点型(CV_32FC1)图像,且行数只能为1,列数只能为13x5

         fgdModel——前景模型,如果为null,函数内部会自动创建一个fgdModelfgdModel必须是单通道浮点型(CV_32FC1)图像,且行数只能为1,列数只能为13x5

         iterCount——迭代次数,必须大于0

         mode——用于指示grabCut函数进行什么操作,可选的值有:

                   GC_INIT_WITH_RECT=0),用矩形窗初始化GrabCut

                   GC_INIT_WITH_MASK=1),用掩码图像初始化GrabCut

                   GC_EVAL=2),执行分割。

*/

 

二、GrabCut源码解读

       其中源码包含了gcgraph.hpp这个构建图和max flow/min cut算法的实现文件,这个文件暂时没有解读,后面再更新了。

[cpp]  view plain copy
  1. /*M/////////////////////////////////////////////////////////////////////////////////////// 
  2. // 
  3. //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 
  4. // 
  5. //  By downloading, copying, installing or using the software you agree to this license. 
  6. //  If you do not agree to this license, do not download, install, 
  7. //  copy or use the software. 
  8. // 
  9. // 
  10. //                        Intel License Agreement 
  11. //                For Open Source Computer Vision Library 
  12. // 
  13. // Copyright (C) 2000, Intel Corporation, all rights reserved. 
  14. // Third party copyrights are property of their respective owners. 
  15. // 
  16. // Redistribution and use in source and binary forms, with or without modification, 
  17. // are permitted provided that the following conditions are met: 
  18. // 
  19. //   * Redistribution's of source code must retain the above copyright notice, 
  20. //     this list of conditions and the following disclaimer. 
  21. // 
  22. //   * Redistribution's in binary form must reproduce the above copyright notice, 
  23. //     this list of conditions and the following disclaimer in the documentation 
  24. //     and/or other materials provided with the distribution. 
  25. // 
  26. //   * The name of Intel Corporation may not be used to endorse or promote products 
  27. //     derived from this software without specific prior written permission. 
  28. // 
  29. // This software is provided by the copyright holders and contributors "as is" and 
  30. // any express or implied warranties, including, but not limited to, the implied 
  31. // warranties of merchantability and fitness for a particular purpose are disclaimed. 
  32. // In no event shall the Intel Corporation or contributors be liable for any direct, 
  33. // indirect, incidental, special, exemplary, or consequential damages 
  34. // (including, but not limited to, procurement of substitute goods or services; 
  35. // loss of use, data, or profits; or business interruption) however caused 
  36. // and on any theory of liability, whether in contract, strict liability, 
  37. // or tort (including negligence or otherwise) arising in any way out of 
  38. // the use of this software, even if advised of the possibility of such damage. 
  39. // 
  40. //M*/  
  41.   
  42. #include "precomp.hpp"  
  43. #include "gcgraph.hpp"  
  44. #include <limits>  
  45.   
  46. using namespace cv;  
  47.   
  48. /* 
  49. This is implementation of image segmentation algorithm GrabCut described in 
  50. "GrabCut — Interactive Foreground Extraction using Iterated Graph Cuts". 
  51. Carsten Rother, Vladimir Kolmogorov, Andrew Blake. 
  52.  */  
  53.   
  54. /* 
  55.  GMM - Gaussian Mixture Model 
  56. */  
  57. class GMM  
  58. {  
  59. public:  
  60.     static const int componentsCount = 5;  
  61.   
  62.     GMM( Mat& _model );  
  63.     double operator()( const Vec3d color ) const;  
  64.     double operator()( int ci, const Vec3d color ) const;  
  65.     int whichComponent( const Vec3d color ) const;  
  66.   
  67.     void initLearning();  
  68.     void addSample( int ci, const Vec3d color );  
  69.     void endLearning();  
  70.   
  71. private:  
  72.     void calcInverseCovAndDeterm( int ci );  
  73.     Mat model;  
  74.     double* coefs;  
  75.     double* mean;  
  76.     double* cov;  
  77.   
  78.     double inverseCovs[componentsCount][3][3]; //协方差的逆矩阵  
  79.     double covDeterms[componentsCount];  //协方差的行列式  
  80.   
  81.     double sums[componentsCount][3];  
  82.     double prods[componentsCount][3][3];  
  83.     int sampleCounts[componentsCount];  
  84.     int totalSampleCount;  
  85. };  
  86.   
  87. //背景和前景各有一个对应的GMM(混合高斯模型)  
  88. GMM::GMM( Mat& _model )  
  89. {  
  90.     //一个像素的(唯一对应)高斯模型的参数个数或者说一个高斯模型的参数个数  
  91.     //一个像素RGB三个通道值,故3个均值,3*3个协方差,共用一个权值  
  92.     const int modelSize = 3/*mean*/ + 9/*covariance*/ + 1/*component weight*/;  
  93.     if( _model.empty() )  
  94.     {  
  95.         //一个GMM共有componentsCount个高斯模型,一个高斯模型有modelSize个模型参数  
  96.         _model.create( 1, modelSize*componentsCount, CV_64FC1 );  
  97.         _model.setTo(Scalar(0));  
  98.     }  
  99.     else if( (_model.type() != CV_64FC1) || (_model.rows != 1) || (_model.cols != modelSize*componentsCount) )  
  100.         CV_Error( CV_StsBadArg, "_model must have CV_64FC1 type, rows == 1 and cols == 13*componentsCount" );  
  101.   
  102.     model = _model;  
  103.   
  104.     //注意这些模型参数的存储方式:先排完componentsCount个coefs,再3*componentsCount个mean。  
  105.     //再3*3*componentsCount个cov。  
  106.     coefs = model.ptr<double>(0);  //GMM的每个像素的高斯模型的权值变量起始存储指针  
  107.     mean = coefs + componentsCount; //均值变量起始存储指针  
  108.     cov = mean + 3*componentsCount;  //协方差变量起始存储指针  
  109.   
  110.     forint ci = 0; ci < componentsCount; ci++ )  
  111.         if( coefs[ci] > 0 )  
  112.              //计算GMM中第ci个高斯模型的协方差的逆Inverse和行列式Determinant  
  113.              //为了后面计算每个像素属于该高斯模型的概率(也就是数据能量项)  
  114.              calcInverseCovAndDeterm( ci );   
  115. }  
  116.   
  117. //计算一个像素(由color=(B,G,R)三维double型向量来表示)属于这个GMM混合高斯模型的概率。  
  118. //也就是把这个像素像素属于componentsCount个高斯模型的概率与对应的权值相乘再相加,  
  119. //具体见论文的公式(10)。结果从res返回。  
  120. //这个相当于计算Gibbs能量的第一个能量项(取负后)。  
  121. double GMM::operator()( const Vec3d color ) const  
  122. {  
  123.     double res = 0;  
  124.     forint ci = 0; ci < componentsCount; ci++ )  
  125.         res += coefs[ci] * (*this)(ci, color );  
  126.     return res;  
  127. }  
  128.   
  129. //计算一个像素(由color=(B,G,R)三维double型向量来表示)属于第ci个高斯模型的概率。  
  130. //具体过程,即高阶的高斯密度模型计算式,具体见论文的公式(10)。结果从res返回  
  131. double GMM::operator()( int ci, const Vec3d color ) const  
  132. {  
  133.     double res = 0;  
  134.     if( coefs[ci] > 0 )  
  135.     {  
  136.         CV_Assert( covDeterms[ci] > std::numeric_limits<double>::epsilon() );  
  137.         Vec3d diff = color;  
  138.         double* m = mean + 3*ci;  
  139.         diff[0] -= m[0]; diff[1] -= m[1]; diff[2] -= m[2];  
  140.         double mult = diff[0]*(diff[0]*inverseCovs[ci][0][0] + diff[1]*inverseCovs[ci][1][0] + diff[2]*inverseCovs[ci][2][0])  
  141.                    + diff[1]*(diff[0]*inverseCovs[ci][0][1] + diff[1]*inverseCovs[ci][1][1] + diff[2]*inverseCovs[ci][2][1])  
  142.                    + diff[2]*(diff[0]*inverseCovs[ci][0][2] + diff[1]*inverseCovs[ci][1][2] + diff[2]*inverseCovs[ci][2][2]);  
  143.         res = 1.0f/sqrt(covDeterms[ci]) * exp(-0.5f*mult);  
  144.     }  
  145.     return res;  
  146. }  
  147.   
  148. //返回这个像素最有可能属于GMM中的哪个高斯模型(概率最大的那个)  
  149. int GMM::whichComponent( const Vec3d color ) const  
  150. {  
  151.     int k = 0;  
  152.     double max = 0;  
  153.   
  154.     forint ci = 0; ci < componentsCount; ci++ )  
  155.     {  
  156.         double p = (*this)( ci, color );  
  157.         if( p > max )  
  158.         {  
  159.             k = ci;  //找到概率最大的那个,或者说计算结果最大的那个  
  160.             max = p;  
  161.         }  
  162.     }  
  163.     return k;  
  164. }  
  165.   
  166. //GMM参数学习前的初始化,主要是对要求和的变量置零  
  167. void GMM::initLearning()  
  168. {  
  169.     forint ci = 0; ci < componentsCount; ci++)  
  170.     {  
  171.         sums[ci][0] = sums[ci][1] = sums[ci][2] = 0;  
  172.         prods[ci][0][0] = prods[ci][0][1] = prods[ci][0][2] = 0;  
  173.         prods[ci][1][0] = prods[ci][1][1] = prods[ci][1][2] = 0;  
  174.         prods[ci][2][0] = prods[ci][2][1] = prods[ci][2][2] = 0;  
  175.         sampleCounts[ci] = 0;  
  176.     }  
  177.     totalSampleCount = 0;  
  178. }  
  179.   
  180. //增加样本,即为前景或者背景GMM的第ci个高斯模型的像素集(这个像素集是来用估  
  181. //计计算这个高斯模型的参数的)增加样本像素。计算加入color这个像素后,像素集  
  182. //中所有像素的RGB三个通道的和sums(用来计算均值),还有它的prods(用来计算协方差),  
  183. //并且记录这个像素集的像素个数和总的像素个数(用来计算这个高斯模型的权值)。  
  184. void GMM::addSample( int ci, const Vec3d color )  
  185. {  
  186.     sums[ci][0] += color[0]; sums[ci][1] += color[1]; sums[ci][2] += color[2];  
  187.     prods[ci][0][0] += color[0]*color[0]; prods[ci][0][1] += color[0]*color[1]; prods[ci][0][2] += color[0]*color[2];  
  188.     prods[ci][1][0] += color[1]*color[0]; prods[ci][1][1] += color[1]*color[1]; prods[ci][1][2] += color[1]*color[2];  
  189.     prods[ci][2][0] += color[2]*color[0]; prods[ci][2][1] += color[2]*color[1]; prods[ci][2][2] += color[2]*color[2];  
  190.     sampleCounts[ci]++;  
  191.     totalSampleCount++;  
  192. }  
  193.   
  194. //从图像数据中学习GMM的参数:每一个高斯分量的权值、均值和协方差矩阵;  
  195. //这里相当于论文中“Iterative minimisation”的step 2  
  196. void GMM::endLearning()  
  197. {  
  198.     const double variance = 0.01;  
  199.     forint ci = 0; ci < componentsCount; ci++ )  
  200.     {  
  201.         int n = sampleCounts[ci]; //第ci个高斯模型的样本像素个数  
  202.         if( n == 0 )  
  203.             coefs[ci] = 0;  
  204.         else  
  205.         {  
  206.             //计算第ci个高斯模型的权值系数  
  207.             coefs[ci] = (double)n/totalSampleCount;   
  208.   
  209.             //计算第ci个高斯模型的均值  
  210.             double* m = mean + 3*ci;  
  211.             m[0] = sums[ci][0]/n; m[1] = sums[ci][1]/n; m[2] = sums[ci][2]/n;  
  212.   
  213.             //计算第ci个高斯模型的协方差  
  214.             double* c = cov + 9*ci;  
  215.             c[0] = prods[ci][0][0]/n - m[0]*m[0]; c[1] = prods[ci][0][1]/n - m[0]*m[1]; c[2] = prods[ci][0][2]/n - m[0]*m[2];  
  216.             c[3] = prods[ci][1][0]/n - m[1]*m[0]; c[4] = prods[ci][1][1]/n - m[1]*m[1]; c[5] = prods[ci][1][2]/n - m[1]*m[2];  
  217.             c[6] = prods[ci][2][0]/n - m[2]*m[0]; c[7] = prods[ci][2][1]/n - m[2]*m[1]; c[8] = prods[ci][2][2]/n - m[2]*m[2];  
  218.   
  219.             //计算第ci个高斯模型的协方差的行列式  
  220.             double dtrm = c[0]*(c[4]*c[8]-c[5]*c[7]) - c[1]*(c[3]*c[8]-c[5]*c[6]) + c[2]*(c[3]*c[7]-c[4]*c[6]);  
  221.             if( dtrm <= std::numeric_limits<double>::epsilon() )  
  222.             {  
  223.                 //相当于如果行列式小于等于0,(对角线元素)增加白噪声,避免其变  
  224.                 //为退化(降秩)协方差矩阵(不存在逆矩阵,但后面的计算需要计算逆矩阵)。  
  225.                 // Adds the white noise to avoid singular covariance matrix.  
  226.                 c[0] += variance;  
  227.                 c[4] += variance;  
  228.                 c[8] += variance;  
  229.             }  
  230.               
  231.             //计算第ci个高斯模型的协方差的逆Inverse和行列式Determinant  
  232.             calcInverseCovAndDeterm(ci);  
  233.         }  
  234.     }  
  235. }  
  236.   
  237. //计算协方差的逆Inverse和行列式Determinant  
  238. void GMM::calcInverseCovAndDeterm( int ci )  
  239. {  
  240.     if( coefs[ci] > 0 )  
  241.     {  
  242.         //取第ci个高斯模型的协方差的起始指针  
  243.         double *c = cov + 9*ci;  
  244.         double dtrm =  
  245.               covDeterms[ci] = c[0]*(c[4]*c[8]-c[5]*c[7]) - c[1]*(c[3]*c[8]-c[5]*c[6])   
  246.                                 + c[2]*(c[3]*c[7]-c[4]*c[6]);  
  247.   
  248.         //在C++中,每一种内置的数据类型都拥有不同的属性, 使用<limits>库可以获  
  249.         //得这些基本数据类型的数值属性。因为浮点算法的截断,所以使得,当a=2,  
  250.         //b=3时 10*a/b == 20/b不成立。那怎么办呢?  
  251.         //这个小正数(epsilon)常量就来了,小正数通常为可用给定数据类型的  
  252.         //大于1的最小值与1之差来表示。若dtrm结果不大于小正数,那么它几乎为零。  
  253.         //所以下式保证dtrm>0,即行列式的计算正确(协方差对称正定,故行列式大于0)。  
  254.         CV_Assert( dtrm > std::numeric_limits<double>::epsilon() );  
  255.         //三阶方阵的求逆  
  256.         inverseCovs[ci][0][0] =  (c[4]*c[8] - c[5]*c[7]) / dtrm;  
  257.         inverseCovs[ci][1][0] = -(c[3]*c[8] - c[5]*c[6]) / dtrm;  
  258.         inverseCovs[ci][2][0] =  (c[3]*c[7] - c[4]*c[6]) / dtrm;  
  259.         inverseCovs[ci][0][1] = -(c[1]*c[8] - c[2]*c[7]) / dtrm;  
  260.         inverseCovs[ci][1][1] =  (c[0]*c[8] - c[2]*c[6]) / dtrm;  
  261.         inverseCovs[ci][2][1] = -(c[0]*c[7] - c[1]*c[6]) / dtrm;  
  262.         inverseCovs[ci][0][2] =  (c[1]*c[5] - c[2]*c[4]) / dtrm;  
  263.         inverseCovs[ci][1][2] = -(c[0]*c[5] - c[2]*c[3]) / dtrm;  
  264.         inverseCovs[ci][2][2] =  (c[0]*c[4] - c[1]*c[3]) / dtrm;  
  265.     }  
  266. }  
  267.   
  268. //计算beta,也就是Gibbs能量项中的第二项(平滑项)中的指数项的beta,用来调整  
  269. //高或者低对比度时,两个邻域像素的差别的影响的,例如在低对比度时,两个邻域  
  270. //像素的差别可能就会比较小,这时候需要乘以一个较大的beta来放大这个差别,  
  271. //在高对比度时,则需要缩小本身就比较大的差别。  
  272. //所以我们需要分析整幅图像的对比度来确定参数beta,具体的见论文公式(5)。  
  273. /* 
  274.   Calculate beta - parameter of GrabCut algorithm. 
  275.   beta = 1/(2*avg(sqr(||color[i] - color[j]||))) 
  276. */  
  277. static double calcBeta( const Mat& img )  
  278. {  
  279.     double beta = 0;  
  280.     forint y = 0; y < img.rows; y++ )  
  281.     {  
  282.         forint x = 0; x < img.cols; x++ )  
  283.         {  
  284.             //计算四个方向邻域两像素的差别,也就是欧式距离或者说二阶范数  
  285.             //(当所有像素都算完后,就相当于计算八邻域的像素差了)  
  286.             Vec3d color = img.at<Vec3b>(y,x);  
  287.             if( x>0 ) // left  >0的判断是为了避免在图像边界的时候还计算,导致越界  
  288.             {  
  289.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y,x-1);  
  290.                 beta += diff.dot(diff);  //矩阵的点乘,也就是各个元素平方的和  
  291.             }  
  292.             if( y>0 && x>0 ) // upleft  
  293.             {  
  294.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x-1);  
  295.                 beta += diff.dot(diff);  
  296.             }  
  297.             if( y>0 ) // up  
  298.             {  
  299.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x);  
  300.                 beta += diff.dot(diff);  
  301.             }  
  302.             if( y>0 && x<img.cols-1) // upright  
  303.             {  
  304.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x+1);  
  305.                 beta += diff.dot(diff);  
  306.             }  
  307.         }  
  308.     }  
  309.     if( beta <= std::numeric_limits<double>::epsilon() )  
  310.         beta = 0;  
  311.     else  
  312.         beta = 1.f / (2 * beta/(4*img.cols*img.rows - 3*img.cols - 3*img.rows + 2) ); //论文公式(5)  
  313.   
  314.     return beta;  
  315. }  
  316.   
  317. //计算图每个非端点顶点(也就是每个像素作为图的一个顶点,不包括源点s和汇点t)与邻域顶点  
  318. //的边的权值。由于是无向图,我们计算的是八邻域,那么对于一个顶点,我们计算四个方向就行,  
  319. //在其他的顶点计算的时候,会把剩余那四个方向的权值计算出来。这样整个图算完后,每个顶点  
  320. //与八邻域的顶点的边的权值就都计算出来了。  
  321. //这个相当于计算Gibbs能量的第二个能量项(平滑项),具体见论文中公式(4)  
  322. /* 
  323.   Calculate weights of noterminal vertices of graph. 
  324.   beta and gamma - parameters of GrabCut algorithm. 
  325.  */  
  326. static void calcNWeights( const Mat& img, Mat& leftW, Mat& upleftW, Mat& upW,   
  327.                             Mat& uprightW, double beta, double gamma )  
  328. {  
  329.     //gammaDivSqrt2相当于公式(4)中的gamma * dis(i,j)^(-1),那么可以知道,  
  330.     //当i和j是垂直或者水平关系时,dis(i,j)=1,当是对角关系时,dis(i,j)=sqrt(2.0f)。  
  331.     //具体计算时,看下面就明白了  
  332.     const double gammaDivSqrt2 = gamma / std::sqrt(2.0f);  
  333.     //每个方向的边的权值通过一个和图大小相等的Mat来保存  
  334.     leftW.create( img.rows, img.cols, CV_64FC1 );  
  335.     upleftW.create( img.rows, img.cols, CV_64FC1 );  
  336.     upW.create( img.rows, img.cols, CV_64FC1 );  
  337.     uprightW.create( img.rows, img.cols, CV_64FC1 );  
  338.     forint y = 0; y < img.rows; y++ )  
  339.     {  
  340.         forint x = 0; x < img.cols; x++ )  
  341.         {  
  342.             Vec3d color = img.at<Vec3b>(y,x);  
  343.             if( x-1>=0 ) // left  //避免图的边界  
  344.             {  
  345.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y,x-1);  
  346.                 leftW.at<double>(y,x) = gamma * exp(-beta*diff.dot(diff));  
  347.             }  
  348.             else  
  349.                 leftW.at<double>(y,x) = 0;  
  350.             if( x-1>=0 && y-1>=0 ) // upleft  
  351.             {  
  352.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x-1);  
  353.                 upleftW.at<double>(y,x) = gammaDivSqrt2 * exp(-beta*diff.dot(diff));  
  354.             }  
  355.             else  
  356.                 upleftW.at<double>(y,x) = 0;  
  357.             if( y-1>=0 ) // up  
  358.             {  
  359.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x);  
  360.                 upW.at<double>(y,x) = gamma * exp(-beta*diff.dot(diff));  
  361.             }  
  362.             else  
  363.                 upW.at<double>(y,x) = 0;  
  364.             if( x+1<img.cols && y-1>=0 ) // upright  
  365.             {  
  366.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x+1);  
  367.                 uprightW.at<double>(y,x) = gammaDivSqrt2 * exp(-beta*diff.dot(diff));  
  368.             }  
  369.             else  
  370.                 uprightW.at<double>(y,x) = 0;  
  371.         }  
  372.     }  
  373. }  
  374.   
  375. //检查mask的正确性。mask为通过用户交互或者程序设定的,它是和图像大小一样的单通道灰度图,  
  376. //每个像素只能取GC_BGD or GC_FGD or GC_PR_BGD or GC_PR_FGD 四种枚举值,分别表示该像素  
  377. //(用户或者程序指定)属于背景、前景、可能为背景或者可能为前景像素。具体的参考:  
  378. //ICCV2001“Interactive Graph Cuts for Optimal Boundary & Region Segmentation of Objects in N-D Images”  
  379. //Yuri Y. Boykov Marie-Pierre Jolly   
  380. /* 
  381.   Check size, type and element values of mask matrix. 
  382.  */  
  383. static void checkMask( const Mat& img, const Mat& mask )  
  384. {  
  385.     if( mask.empty() )  
  386.         CV_Error( CV_StsBadArg, "mask is empty" );  
  387.     if( mask.type() != CV_8UC1 )  
  388.         CV_Error( CV_StsBadArg, "mask must have CV_8UC1 type" );  
  389.     if( mask.cols != img.cols || mask.rows != img.rows )  
  390.         CV_Error( CV_StsBadArg, "mask must have as many rows and cols as img" );  
  391.     forint y = 0; y < mask.rows; y++ )  
  392.     {  
  393.         forint x = 0; x < mask.cols; x++ )  
  394.         {  
  395.             uchar val = mask.at<uchar>(y,x);  
  396.             if( val!=GC_BGD && val!=GC_FGD && val!=GC_PR_BGD && val!=GC_PR_FGD )  
  397.                 CV_Error( CV_StsBadArg, "mask element value must be equel"  
  398.                     "GC_BGD or GC_FGD or GC_PR_BGD or GC_PR_FGD" );  
  399.         }  
  400.     }  
  401. }  
  402.   
  403. //通过用户框选目标rect来创建mask,rect外的全部作为背景,设置为GC_BGD,  
  404. //rect内的设置为 GC_PR_FGD(可能为前景)  
  405. /* 
  406.   Initialize mask using rectangular. 
  407. */  
  408. static void initMaskWithRect( Mat& mask, Size imgSize, Rect rect )  
  409. {  
  410.     mask.create( imgSize, CV_8UC1 );  
  411.     mask.setTo( GC_BGD );  
  412.   
  413.     rect.x = max(0, rect.x);  
  414.     rect.y = max(0, rect.y);  
  415.     rect.width = min(rect.width, imgSize.width-rect.x);  
  416.     rect.height = min(rect.height, imgSize.height-rect.y);  
  417.   
  418.     (mask(rect)).setTo( Scalar(GC_PR_FGD) );  
  419. }  
  420.   
  421. //通过k-means算法来初始化背景GMM和前景GMM模型  
  422. /* 
  423.   Initialize GMM background and foreground models using kmeans algorithm. 
  424. */  
  425. static void initGMMs( const Mat& img, const Mat& mask, GMM& bgdGMM, GMM& fgdGMM )  
  426. {  
  427.     const int kMeansItCount = 10;  //迭代次数  
  428.     const int kMeansType = KMEANS_PP_CENTERS; //Use kmeans++ center initialization by Arthur and Vassilvitskii  
  429.   
  430.     Mat bgdLabels, fgdLabels; //记录背景和前景的像素样本集中每个像素对应GMM的哪个高斯模型,论文中的kn  
  431.     vector<Vec3f> bgdSamples, fgdSamples; //背景和前景的像素样本集  
  432.     Point p;  
  433.     for( p.y = 0; p.y < img.rows; p.y++ )  
  434.     {  
  435.         for( p.x = 0; p.x < img.cols; p.x++ )  
  436.         {  
  437.             //mask中标记为GC_BGD和GC_PR_BGD的像素都作为背景的样本像素  
  438.             if( mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD )  
  439.                 bgdSamples.push_back( (Vec3f)img.at<Vec3b>(p) );  
  440.             else // GC_FGD | GC_PR_FGD  
  441.                 fgdSamples.push_back( (Vec3f)img.at<Vec3b>(p) );  
  442.         }  
  443.     }  
  444.     CV_Assert( !bgdSamples.empty() && !fgdSamples.empty() );  
  445.       
  446.     //kmeans中参数_bgdSamples为:每行一个样本  
  447.     //kmeans的输出为bgdLabels,里面保存的是输入样本集中每一个样本对应的类标签(样本聚为componentsCount类后)  
  448.     Mat _bgdSamples( (int)bgdSamples.size(), 3, CV_32FC1, &bgdSamples[0][0] );  
  449.     kmeans( _bgdSamples, GMM::componentsCount, bgdLabels,  
  450.             TermCriteria( CV_TERMCRIT_ITER, kMeansItCount, 0.0), 0, kMeansType );  
  451.     Mat _fgdSamples( (int)fgdSamples.size(), 3, CV_32FC1, &fgdSamples[0][0] );  
  452.     kmeans( _fgdSamples, GMM::componentsCount, fgdLabels,  
  453.             TermCriteria( CV_TERMCRIT_ITER, kMeansItCount, 0.0), 0, kMeansType );  
  454.   
  455.     //经过上面的步骤后,每个像素所属的高斯模型就确定的了,那么就可以估计GMM中每个高斯模型的参数了。  
  456.     bgdGMM.initLearning();  
  457.     forint i = 0; i < (int)bgdSamples.size(); i++ )  
  458.         bgdGMM.addSample( bgdLabels.at<int>(i,0), bgdSamples[i] );  
  459.     bgdGMM.endLearning();  
  460.   
  461.     fgdGMM.initLearning();  
  462.     forint i = 0; i < (int)fgdSamples.size(); i++ )  
  463.         fgdGMM.addSample( fgdLabels.at<int>(i,0), fgdSamples[i] );  
  464.     fgdGMM.endLearning();  
  465. }  
  466.   
  467. //论文中:迭代最小化算法step 1:为每个像素分配GMM中所属的高斯模型,kn保存在Mat compIdxs中  
  468. /* 
  469.   Assign GMMs components for each pixel. 
  470. */  
  471. static void assignGMMsComponents( const Mat& img, const Mat& mask, const GMM& bgdGMM,   
  472.                                     const GMM& fgdGMM, Mat& compIdxs )  
  473. {  
  474.     Point p;  
  475.     for( p.y = 0; p.y < img.rows; p.y++ )  
  476.     {  
  477.         for( p.x = 0; p.x < img.cols; p.x++ )  
  478.         {  
  479.             Vec3d color = img.at<Vec3b>(p);  
  480.             //通过mask来判断该像素属于背景像素还是前景像素,再判断它属于前景或者背景GMM中的哪个高斯分量  
  481.             compIdxs.at<int>(p) = mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD ?  
  482.                 bgdGMM.whichComponent(color) : fgdGMM.whichComponent(color);  
  483.         }  
  484.     }  
  485. }  
  486.   
  487. //论文中:迭代最小化算法step 2:从每个高斯模型的像素样本集中学习每个高斯模型的参数  
  488. /* 
  489.   Learn GMMs parameters. 
  490. */  
  491. static void learnGMMs( const Mat& img, const Mat& mask, const Mat& compIdxs, GMM& bgdGMM, GMM& fgdGMM )  
  492. {  
  493.     bgdGMM.initLearning();  
  494.     fgdGMM.initLearning();  
  495.     Point p;  
  496.     forint ci = 0; ci < GMM::componentsCount; ci++ )  
  497.     {  
  498.         for( p.y = 0; p.y < img.rows; p.y++ )  
  499.         {  
  500.             for( p.x = 0; p.x < img.cols; p.x++ )  
  501.             {  
  502.                 if( compIdxs.at<int>(p) == ci )  
  503.                 {  
  504.                     if( mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD )  
  505.                         bgdGMM.addSample( ci, img.at<Vec3b>(p) );  
  506.                     else  
  507.                         fgdGMM.addSample( ci, img.at<Vec3b>(p) );  
  508.                 }  
  509.             }  
  510.         }  
  511.     }  
  512.     bgdGMM.endLearning();  
  513.     fgdGMM.endLearning();  
  514. }  
  515.   
  516. //通过计算得到的能量项构建图,图的顶点为像素点,图的边由两部分构成,  
  517. //一类边是:每个顶点与Sink汇点t(代表背景)和源点Source(代表前景)连接的边,  
  518. //这类边的权值通过Gibbs能量项的第一项能量项来表示。  
  519. //另一类边是:每个顶点与其邻域顶点连接的边,这类边的权值通过Gibbs能量项的第二项能量项来表示。  
  520. /* 
  521.   Construct GCGraph 
  522. */  
  523. static void constructGCGraph( const Mat& img, const Mat& mask, const GMM& bgdGMM, const GMM& fgdGMM, double lambda,  
  524.                        const Mat& leftW, const Mat& upleftW, const Mat& upW, const Mat& uprightW,  
  525.                        GCGraph<double>& graph )  
  526. {  
  527.     int vtxCount = img.cols*img.rows;  //顶点数,每一个像素是一个顶点  
  528.     int edgeCount = 2*(4*vtxCount - 3*(img.cols + img.rows) + 2);  //边数,需要考虑图边界的边的缺失  
  529.     //通过顶点数和边数创建图。这些类型声明和函数定义请参考gcgraph.hpp  
  530.     graph.create(vtxCount, edgeCount);  
  531.     Point p;  
  532.     for( p.y = 0; p.y < img.rows; p.y++ )  
  533.     {  
  534.         for( p.x = 0; p.x < img.cols; p.x++)  
  535.         {  
  536.             // add node  
  537.             int vtxIdx = graph.addVtx();  //返回这个顶点在图中的索引  
  538.             Vec3b color = img.at<Vec3b>(p);  
  539.   
  540.             // set t-weights              
  541.             //计算每个顶点与Sink汇点t(代表背景)和源点Source(代表前景)连接的权值。  
  542.             //也即计算Gibbs能量(每一个像素点作为背景像素或者前景像素)的第一个能量项  
  543.             double fromSource, toSink;  
  544.             if( mask.at<uchar>(p) == GC_PR_BGD || mask.at<uchar>(p) == GC_PR_FGD )  
  545.             {  
  546.                 //对每一个像素计算其作为背景像素或者前景像素的第一个能量项,作为分别与t和s点的连接权值  
  547.                 fromSource = -log( bgdGMM(color) );  
  548.                 toSink = -log( fgdGMM(color) );  
  549.             }  
  550.             else if( mask.at<uchar>(p) == GC_BGD )  
  551.             {  
  552.                 //对于确定为背景的像素点,它与Source点(前景)的连接为0,与Sink点的连接为lambda  
  553.                 fromSource = 0;  
  554.                 toSink = lambda;  
  555.             }  
  556.             else // GC_FGD  
  557.             {  
  558.                 fromSource = lambda;  
  559.                 toSink = 0;  
  560.             }  
  561.             //设置该顶点vtxIdx分别与Source点和Sink点的连接权值  
  562.             graph.addTermWeights( vtxIdx, fromSource, toSink );  
  563.   
  564.             // set n-weights  n-links  
  565.             //计算两个邻域顶点之间连接的权值。  
  566.             //也即计算Gibbs能量的第二个能量项(平滑项)  
  567.             if( p.x>0 )  
  568.             {  
  569.                 double w = leftW.at<double>(p);  
  570.                 graph.addEdges( vtxIdx, vtxIdx-1, w, w );  
  571.             }  
  572.             if( p.x>0 && p.y>0 )  
  573.             {  
  574.                 double w = upleftW.at<double>(p);  
  575.                 graph.addEdges( vtxIdx, vtxIdx-img.cols-1, w, w );  
  576.             }  
  577.             if( p.y>0 )  
  578.             {  
  579.                 double w = upW.at<double>(p);  
  580.                 graph.addEdges( vtxIdx, vtxIdx-img.cols, w, w );  
  581.             }  
  582.             if( p.x<img.cols-1 && p.y>0 )  
  583.             {  
  584.                 double w = uprightW.at<double>(p);  
  585.                 graph.addEdges( vtxIdx, vtxIdx-img.cols+1, w, w );  
  586.             }  
  587.         }  
  588.     }  
  589. }  
  590.   
  591. //论文中:迭代最小化算法step 3:分割估计:最小割或者最大流算法  
  592. /* 
  593.   Estimate segmentation using MaxFlow algorithm 
  594. */  
  595. static void estimateSegmentation( GCGraph<double>& graph, Mat& mask )  
  596. {  
  597.     //通过最大流算法确定图的最小割,也即完成图像的分割  
  598.     graph.maxFlow();  
  599.     Point p;  
  600.     for( p.y = 0; p.y < mask.rows; p.y++ )  
  601.     {  
  602.         for( p.x = 0; p.x < mask.cols; p.x++ )  
  603.         {  
  604.             //通过图分割的结果来更新mask,即最后的图像分割结果。注意的是,永远都  
  605.             //不会更新用户指定为背景或者前景的像素  
  606.             if( mask.at<uchar>(p) == GC_PR_BGD || mask.at<uchar>(p) == GC_PR_FGD )  
  607.             {  
  608.                 if( graph.inSourceSegment( p.y*mask.cols+p.x /*vertex index*/ ) )  
  609.                     mask.at<uchar>(p) = GC_PR_FGD;  
  610.                 else  
  611.                     mask.at<uchar>(p) = GC_PR_BGD;  
  612.             }  
  613.         }  
  614.     }  
  615. }  
  616.   
  617. //最后的成果:提供给外界使用的伟大的API:grabCut   
  618. /* 
  619. ****参数说明: 
  620.     img——待分割的源图像,必须是8位3通道(CV_8UC3)图像,在处理的过程中不会被修改; 
  621.     mask——掩码图像,如果使用掩码进行初始化,那么mask保存初始化掩码信息;在执行分割 
  622.         的时候,也可以将用户交互所设定的前景与背景保存到mask中,然后再传入grabCut函 
  623.         数;在处理结束之后,mask中会保存结果。mask只能取以下四种值: 
  624.         GCD_BGD(=0),背景; 
  625.         GCD_FGD(=1),前景; 
  626.         GCD_PR_BGD(=2),可能的背景; 
  627.         GCD_PR_FGD(=3),可能的前景。 
  628.         如果没有手工标记GCD_BGD或者GCD_FGD,那么结果只会有GCD_PR_BGD或GCD_PR_FGD; 
  629.     rect——用于限定需要进行分割的图像范围,只有该矩形窗口内的图像部分才被处理; 
  630.     bgdModel——背景模型,如果为null,函数内部会自动创建一个bgdModel;bgdModel必须是 
  631.         单通道浮点型(CV_32FC1)图像,且行数只能为1,列数只能为13x5; 
  632.     fgdModel——前景模型,如果为null,函数内部会自动创建一个fgdModel;fgdModel必须是 
  633.         单通道浮点型(CV_32FC1)图像,且行数只能为1,列数只能为13x5; 
  634.     iterCount——迭代次数,必须大于0; 
  635.     mode——用于指示grabCut函数进行什么操作,可选的值有: 
  636.         GC_INIT_WITH_RECT(=0),用矩形窗初始化GrabCut; 
  637.         GC_INIT_WITH_MASK(=1),用掩码图像初始化GrabCut; 
  638.         GC_EVAL(=2),执行分割。 
  639. */  
  640. void cv::grabCut( InputArray _img, InputOutputArray _mask, Rect rect,  
  641.                   InputOutputArray _bgdModel, InputOutputArray _fgdModel,  
  642.                   int iterCount, int mode )  
  643. {  
  644.     Mat img = _img.getMat();  
  645.     Mat& mask = _mask.getMatRef();  
  646.     Mat& bgdModel = _bgdModel.getMatRef();  
  647.     Mat& fgdModel = _fgdModel.getMatRef();  
  648.   
  649.     if( img.empty() )  
  650.         CV_Error( CV_StsBadArg, "image is empty" );  
  651.     if( img.type() != CV_8UC3 )  
  652.         CV_Error( CV_StsBadArg, "image mush have CV_8UC3 type" );  
  653.   
  654.     GMM bgdGMM( bgdModel ), fgdGMM( fgdModel );  
  655.     Mat compIdxs( img.size(), CV_32SC1 );  
  656.   
  657.     if( mode == GC_INIT_WITH_RECT || mode == GC_INIT_WITH_MASK )  
  658.     {  
  659.         if( mode == GC_INIT_WITH_RECT )  
  660.             initMaskWithRect( mask, img.size(), rect );  
  661.         else // flag == GC_INIT_WITH_MASK  
  662.             checkMask( img, mask );  
  663.         initGMMs( img, mask, bgdGMM, fgdGMM );  
  664.     }  
  665.   
  666.     if( iterCount <= 0)  
  667.         return;  
  668.   
  669.     if( mode == GC_EVAL )  
  670.         checkMask( img, mask );  
  671.   
  672.     const double gamma = 50;  
  673.     const double lambda = 9*gamma;  
  674.     const double beta = calcBeta( img );  
  675.   
  676.     Mat leftW, upleftW, upW, uprightW;  
  677.     calcNWeights( img, leftW, upleftW, upW, uprightW, beta, gamma );  
  678.   
  679.     forint i = 0; i < iterCount; i++ )  
  680.     {  
  681.         GCGraph<double> graph;  
  682.         assignGMMsComponents( img, mask, bgdGMM, fgdGMM, compIdxs );  
  683.         learnGMMs( img, mask, compIdxs, bgdGMM, fgdGMM );  
  684.         constructGCGraph(img, mask, bgdGMM, fgdGMM, lambda, leftW, upleftW, upW, uprightW, graph );  
  685.         estimateSegmentation( graph, mask );  
  686.     }  
  687. }  


 

你可能感兴趣的:(OpenCV的GrabCut函数使用和源码解读)