OpenCV的GrabCut函数使用和源码解读

 上一文对GrabCut做了一个了解。OpenCV中的GrabCut算法是依据《"GrabCut" - Interactive Foreground Extraction using Iterated Graph Cuts》这篇文章来实现的。现在我对源码做了些注释,以便我们更深入的了解该算法。一直觉得论文和代码是有比较大的差别的,个人觉得脱离代码看论文,最多能看懂70%,剩下20%或者更多就需要通过阅读代码来获得了,那还有10%就和每个人的基础和知识储备相挂钩了。

      接触时间有限,若有错误,还望各位前辈指正,谢谢。原论文的一些浅解见上一博文:

          http://blog.csdn.net/zouxy09/article/details/8534954

 

一、GrabCut函数使用

      在OpenCV的源码目录的samples的文件夹下,有grabCut的使用例程,请参考:

opencv\samples\cpp\grabcut.cpp

grabCut函数的API说明如下:

void cv::grabCut( InputArray _img, InputOutputArray _mask, Rect rect,

                  InputOutputArray _bgdModel, InputOutputArray _fgdModel,

                  int iterCount, int mode )

/*

****参数说明:

         img——待分割的源图像,必须是83通道(CV_8UC3)图像,在处理的过程中不会被修改;

         mask——掩码图像,如果使用掩码进行初始化,那么mask保存初始化掩码信息;在执行分割的时候,也可以将用户交互所设定的前景与背景保存到mask中,然后再传入grabCut函数;在处理结束之后,mask中会保存结果。mask只能取以下四种值:

                   GCD_BGD=0),背景;

                   GCD_FGD=1),前景;

                   GCD_PR_BGD=2),可能的背景;

                   GCD_PR_FGD=3),可能的前景。

                   如果没有手工标记GCD_BGD或者GCD_FGD,那么结果只会有GCD_PR_BGDGCD_PR_FGD

         rect——用于限定需要进行分割的图像范围,只有该矩形窗口内的图像部分才被处理;

         bgdModel——背景模型,如果为null,函数内部会自动创建一个bgdModelbgdModel必须是单通道浮点型(CV_32FC1)图像,且行数只能为1,列数只能为13x5

         fgdModel——前景模型,如果为null,函数内部会自动创建一个fgdModelfgdModel必须是单通道浮点型(CV_32FC1)图像,且行数只能为1,列数只能为13x5

         iterCount——迭代次数,必须大于0

         mode——用于指示grabCut函数进行什么操作,可选的值有:

                   GC_INIT_WITH_RECT=0),用矩形窗初始化GrabCut

                   GC_INIT_WITH_MASK=1),用掩码图像初始化GrabCut

                   GC_EVAL=2),执行分割。

*/

 

二、GrabCut源码解读

       其中源码包含了gcgraph.hpp这个构建图和max flow/min cut算法的实现文件,这个文件暂时没有解读,后面再更新了。

[cpp]  view plain copy
  1. /*M/////////////////////////////////////////////////////////////////////////////////////// 
  2. // 
  3. //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 
  4. // 
  5. //  By downloading, copying, installing or using the software you agree to this license. 
  6. //  If you do not agree to this license, do not download, install, 
  7. //  copy or use the software. 
  8. // 
  9. // 
  10. //                        Intel License Agreement 
  11. //                For Open Source Computer Vision Library 
  12. // 
  13. // Copyright (C) 2000, Intel Corporation, all rights reserved. 
  14. // Third party copyrights are property of their respective owners. 
  15. // 
  16. // Redistribution and use in source and binary forms, with or without modification, 
  17. // are permitted provided that the following conditions are met: 
  18. // 
  19. //   * Redistribution's of source code must retain the above copyright notice, 
  20. //     this list of conditions and the following disclaimer. 
  21. // 
  22. //   * Redistribution's in binary form must reproduce the above copyright notice, 
  23. //     this list of conditions and the following disclaimer in the documentation 
  24. //     and/or other materials provided with the distribution. 
  25. // 
  26. //   * The name of Intel Corporation may not be used to endorse or promote products 
  27. //     derived from this software without specific prior written permission. 
  28. // 
  29. // This software is provided by the copyright holders and contributors "as is" and 
  30. // any express or implied warranties, including, but not limited to, the implied 
  31. // warranties of merchantability and fitness for a particular purpose are disclaimed. 
  32. // In no event shall the Intel Corporation or contributors be liable for any direct, 
  33. // indirect, incidental, special, exemplary, or consequential damages 
  34. // (including, but not limited to, procurement of substitute goods or services; 
  35. // loss of use, data, or profits; or business interruption) however caused 
  36. // and on any theory of liability, whether in contract, strict liability, 
  37. // or tort (including negligence or otherwise) arising in any way out of 
  38. // the use of this software, even if advised of the possibility of such damage. 
  39. // 
  40. //M*/  
  41.   
  42. #include "precomp.hpp"  
  43. #include "gcgraph.hpp"  
  44. #include <limits>  
  45.   
  46. using namespace cv;  
  47.   
  48. /* 
  49. This is implementation of image segmentation algorithm GrabCut described in 
  50. "GrabCut — Interactive Foreground Extraction using Iterated Graph Cuts". 
  51. Carsten Rother, Vladimir Kolmogorov, Andrew Blake. 
  52.  */  
  53.   
  54. /* 
  55.  GMM - Gaussian Mixture Model 
  56. */  
  57. class GMM  
  58. {  
  59. public:  
  60.     static const int componentsCount = 5;  
  61.   
  62.     GMM( Mat& _model );  
  63.     double operator()( const Vec3d color ) const;  
  64.     double operator()( int ci, const Vec3d color ) const;  
  65.     int whichComponent( const Vec3d color ) const;  
  66.   
  67.     void initLearning();  
  68.     void addSample( int ci, const Vec3d color );  
  69.     void endLearning();  
  70.   
  71. private:  
  72.     void calcInverseCovAndDeterm( int ci );  
  73.     Mat model;  
  74.     double* coefs;  
  75.     double* mean;  
  76.     double* cov;  
  77.   
  78.     double inverseCovs[componentsCount][3][3]; //协方差的逆矩阵  
  79.     double covDeterms[componentsCount];  //协方差的行列式  
  80.   
  81.     double sums[componentsCount][3];  
  82.     double prods[componentsCount][3][3];  
  83.     int sampleCounts[componentsCount];  
  84.     int totalSampleCount;  
  85. };  
  86.   
  87. //背景和前景各有一个对应的GMM(混合高斯模型)  
  88. GMM::GMM( Mat& _model )  
  89. {  
  90.     //一个像素的(唯一对应)高斯模型的参数个数或者说一个高斯模型的参数个数  
  91.     //一个像素RGB三个通道值,故3个均值,3*3个协方差,共用一个权值 
  92.    //协方差用来度量两个随机变量的关系,如果为正,则正相关;否则,负相关
  93.     const int modelSize = 3/*mean*/ + 9/*covariance*/ + 1/*component weight*/;  
  94.     if( _model.empty() )  
  95.     {  
  96.         //一个GMM共有componentsCount个高斯模型,一个高斯模型有modelSize个模型参数  
  97.         _model.create( 1, modelSize*componentsCount, CV_64FC1 );  
  98.         _model.setTo(Scalar(0));  
  99.     }  
  100.     else if( (_model.type() != CV_64FC1) || (_model.rows != 1) || (_model.cols != modelSize*componentsCount) )  
  101.         CV_Error( CV_StsBadArg, "_model must have CV_64FC1 type, rows == 1 and cols == 13*componentsCount" );  
  102.   
  103.     model = _model;  
  104.   
  105.     //注意这些模型参数的存储方式:先排完componentsCount个coefs,再3*componentsCount个mean。  
  106.     //再3*3*componentsCount个cov。  
  107.     coefs = model.ptr<double>(0);  //GMM的每个像素的高斯模型的权值变量起始存储指针  
  108.     mean = coefs + componentsCount; //均值变量起始存储指针  
  109.     cov = mean + 3*componentsCount;  //协方差变量起始存储指针  
  110.   
  111.     forint ci = 0; ci < componentsCount; ci++ )  
  112.         if( coefs[ci] > 0 )  
  113.              //计算GMM中第ci个高斯模型的协方差的逆Inverse和行列式Determinant  
  114.              //为了后面计算每个像素属于该高斯模型的概率(也就是数据能量项)  
  115.              calcInverseCovAndDeterm( ci );   
  116. }  
  117.   
  118. //计算一个像素(由color=(B,G,R)三维double型向量来表示)属于这个GMM混合高斯模型的概率。  
  119. //也就是把这个像素像素属于componentsCount个高斯模型的概率与对应的权值相乘再相加,  
  120. //具体见上一节的公式(10.a)。结果从res返回。  
  121. //这个相当于计算Gibbs能量的第一个能量项(取负后)。  
  122. double GMM::operator()( const Vec3d color ) const  
  123. {  
  124.     double res = 0;  
  125.     forint ci = 0; ci < componentsCount; ci++ )  
  126.         res += coefs[ci] * (*this)(ci, color );  
  127.     return res;  
  128. }  
  129.   
  130. //计算一个像素(由color=(B,G,R)三维double型向量来表示)属于第ci个高斯模型的概率。  
  131. //具体过程,即高阶的高斯密度模型计算式,具体见上一节的公式(10.b)。结果从res返回  
  132. double GMM::operator()( int ci, const Vec3d color ) const  
  133. {  
  134.     double res = 0;  
  135.     if( coefs[ci] > 0 )  
  136.     {  
  137.         CV_Assert( covDeterms[ci] > std::numeric_limits<double>::epsilon() );  
  138.         Vec3d diff = color;  
  139.         double* m = mean + 3*ci;  
  140.         diff[0] -= m[0]; diff[1] -= m[1]; diff[2] -= m[2];  
  141.         double mult = diff[0]*(diff[0]*inverseCovs[ci][0][0] + diff[1]*inverseCovs[ci][1][0] + diff[2]*inverseCovs[ci][2][0])  
  142.                    + diff[1]*(diff[0]*inverseCovs[ci][0][1] + diff[1]*inverseCovs[ci][1][1] + diff[2]*inverseCovs[ci][2][1])  
  143.                    + diff[2]*(diff[0]*inverseCovs[ci][0][2] + diff[1]*inverseCovs[ci][1][2] + diff[2]*inverseCovs[ci][2][2]);  
  144.         res = 1.0f/sqrt(covDeterms[ci]) * exp(-0.5f*mult);  
  145.     }  
  146.     return res;  
  147. }  
  148.   
  149. //返回这个像素最有可能属于GMM中的哪个高斯模型(概率最大的那个)  
  150. int GMM::whichComponent( const Vec3d color ) const  
  151. {  
  152.     int k = 0;  
  153.     double max = 0;  
  154.   
  155.     forint ci = 0; ci < componentsCount; ci++ )  
  156.     {  
  157.         double p = (*this)( ci, color );  
  158.         if( p > max )  
  159.         {  
  160.             k = ci;  //找到概率最大的那个,或者说计算结果最大的那个  
  161.             max = p;  
  162.         }  
  163.     }  
  164.     return k;  
  165. }  
  166.   
  167. //GMM参数学习前的初始化,主要是对要求和的变量置零  
  168. void GMM::initLearning()  
  169. {  
  170.     forint ci = 0; ci < componentsCount; ci++)  
  171.     {  
  172.         sums[ci][0] = sums[ci][1] = sums[ci][2] = 0;  
  173.         prods[ci][0][0] = prods[ci][0][1] = prods[ci][0][2] = 0;  
  174.         prods[ci][1][0] = prods[ci][1][1] = prods[ci][1][2] = 0;  
  175.         prods[ci][2][0] = prods[ci][2][1] = prods[ci][2][2] = 0;  
  176.         sampleCounts[ci] = 0;  
  177.     }  
  178.     totalSampleCount = 0;  
  179. }  
  180.   
  181. //增加样本,即为前景或者背景GMM的第ci个高斯模型的像素集(这个像素集是来用估  
  182. //计计算这个高斯模型的参数的)增加样本像素。计算加入color这个像素后,像素集  
  183. //中所有像素的RGB三个通道的和sums(用来计算均值),还有它的prods(用来计算协方差),  
  184. //并且记录这个像素集的像素个数和总的像素个数(用来计算这个高斯模型的权值)。  
  185. void GMM::addSample( int ci, const Vec3d color )  
  186. {  
  187.     sums[ci][0] += color[0]; sums[ci][1] += color[1]; sums[ci][2] += color[2];  
  188.     prods[ci][0][0] += color[0]*color[0]; prods[ci][0][1] += color[0]*color[1]; prods[ci][0][2] += color[0]*color[2];  
  189.     prods[ci][1][0] += color[1]*color[0]; prods[ci][1][1] += color[1]*color[1]; prods[ci][1][2] += color[1]*color[2];  
  190.     prods[ci][2][0] += color[2]*color[0]; prods[ci][2][1] += color[2]*color[1]; prods[ci][2][2] += color[2]*color[2];  
  191.     sampleCounts[ci]++;  
  192.     totalSampleCount++;  
  193. }  
  194.   
  195. //从图像数据中学习GMM的参数:每一个高斯分量的权值、均值和协方差矩阵;  
  196. //这里相当于论文中“Iterative minimisation”的step 2  
  197. void GMM::endLearning()  
  198. {  
  199.     const double variance = 0.01;  
  200.     forint ci = 0; ci < componentsCount; ci++ )  
  201.     {  
  202.         int n = sampleCounts[ci]; //第ci个高斯模型的样本像素个数  
  203.         if( n == 0 )  
  204.             coefs[ci] = 0;  
  205.         else  
  206.         {  
  207.             //计算第ci个高斯模型的权值系数  
  208.             coefs[ci] = (double)n/totalSampleCount;   
  209.   
  210.             //计算第ci个高斯模型的均值  
  211.             double* m = mean + 3*ci;  
  212.             m[0] = sums[ci][0]/n; m[1] = sums[ci][1]/n; m[2] = sums[ci][2]/n;  
  213.   
  214.             //计算第ci个高斯模型的协方差  
  215.             double* c = cov + 9*ci;  
  216.             c[0] = prods[ci][0][0]/n - m[0]*m[0]; c[1] = prods[ci][0][1]/n - m[0]*m[1]; c[2] = prods[ci][0][2]/n - m[0]*m[2];  
  217.             c[3] = prods[ci][1][0]/n - m[1]*m[0]; c[4] = prods[ci][1][1]/n - m[1]*m[1]; c[5] = prods[ci][1][2]/n - m[1]*m[2];  
  218.             c[6] = prods[ci][2][0]/n - m[2]*m[0]; c[7] = prods[ci][2][1]/n - m[2]*m[1]; c[8] = prods[ci][2][2]/n - m[2]*m[2];  
  219.   
  220.             //计算第ci个高斯模型的协方差的行列式  
  221.             double dtrm = c[0]*(c[4]*c[8]-c[5]*c[7]) - c[1]*(c[3]*c[8]-c[5]*c[6]) + c[2]*(c[3]*c[7]-c[4]*c[6]);  
  222.             if( dtrm <= std::numeric_limits<double>::epsilon() )  
  223.             {  
  224.                 //相当于如果行列式小于等于0,(对角线元素)增加白噪声,避免其变  
  225.                 //为退化(降秩)协方差矩阵(不存在逆矩阵,但后面的计算需要计算逆矩阵)。  
  226.                 // Adds the white noise to avoid singular covariance matrix.  
  227.                 c[0] += variance;  
  228.                 c[4] += variance;  
  229.                 c[8] += variance;  
  230.             }  
  231.               
  232.             //计算第ci个高斯模型的协方差的逆Inverse和行列式Determinant  
  233.             calcInverseCovAndDeterm(ci);  
  234.         }  
  235.     }  
  236. }  
  237.   
  238. //计算协方差的逆Inverse和行列式Determinant  
  239. void GMM::calcInverseCovAndDeterm( int ci )  
  240. {  
  241.     if( coefs[ci] > 0 )  
  242.     {  
  243.         //取第ci个高斯模型的协方差的起始指针  
  244.         double *c = cov + 9*ci;  
  245.         double dtrm =  
  246.               covDeterms[ci] = c[0]*(c[4]*c[8]-c[5]*c[7]) - c[1]*(c[3]*c[8]-c[5]*c[6])   
  247.                                 + c[2]*(c[3]*c[7]-c[4]*c[6]);  
  248.   
  249.         //在C++中,每一种内置的数据类型都拥有不同的属性, 使用<limits>库可以获  
  250.         //得这些基本数据类型的数值属性。因为浮点算法的截断,所以使得,当a=2,  
  251.         //b=3时 10*a/b == 20/b不成立。那怎么办呢?  
  252.         //这个小正数(epsilon)常量就来了,小正数通常为可用给定数据类型的  
  253.         //大于1的最小值与1之差来表示。若dtrm结果不大于小正数,那么它几乎为零。  
  254.         //所以下式保证dtrm>0,即行列式的计算正确(协方差对称正定,故行列式大于0)。  
  255.         CV_Assert( dtrm > std::numeric_limits<double>::epsilon() );  
  256.         //三阶方阵的求逆  
  257.         inverseCovs[ci][0][0] =  (c[4]*c[8] - c[5]*c[7]) / dtrm;  
  258.         inverseCovs[ci][1][0] = -(c[3]*c[8] - c[5]*c[6]) / dtrm;  
  259.         inverseCovs[ci][2][0] =  (c[3]*c[7] - c[4]*c[6]) / dtrm;  
  260.         inverseCovs[ci][0][1] = -(c[1]*c[8] - c[2]*c[7]) / dtrm;  
  261.         inverseCovs[ci][1][1] =  (c[0]*c[8] - c[2]*c[6]) / dtrm;  
  262.         inverseCovs[ci][2][1] = -(c[0]*c[7] - c[1]*c[6]) / dtrm;  
  263.         inverseCovs[ci][0][2] =  (c[1]*c[5] - c[2]*c[4]) / dtrm;  
  264.         inverseCovs[ci][1][2] = -(c[0]*c[5] - c[2]*c[3]) / dtrm;  
  265.         inverseCovs[ci][2][2] =  (c[0]*c[4] - c[1]*c[3]) / dtrm;  
  266.     }  
  267. }  
  268.   
  269. //计算beta,也就是Gibbs能量项中的第二项(平滑项)中的指数项的beta,用来调整  
  270. //高或者低对比度时,两个邻域像素的差别的影响的,例如在低对比度时,两个邻域  
  271. //像素的差别可能就会比较小,这时候需要乘以一个较大的beta来放大这个差别,  
  272. //在高对比度时,则需要缩小本身就比较大的差别。  
  273. //所以我们需要分析整幅图像的对比度来确定参数beta,具体的见论文公式(5)。  
  274. /* 
  275.   Calculate beta - parameter of GrabCut algorithm. 
  276.   beta = 1/(2*avg(sqr(||color[i] - color[j]||))) 
  277. */  
  278. static double calcBeta( const Mat& img )  
  279. {  
  280.     double beta = 0;  
  281.     forint y = 0; y < img.rows; y++ )  
  282.     {  
  283.         forint x = 0; x < img.cols; x++ )  
  284.         {  
  285.             //计算四个方向邻域两像素的差别,也就是欧式距离或者说二阶范数  
  286.             //(当所有像素都算完后,就相当于计算八邻域的像素差了)  
  287.             Vec3d color = img.at<Vec3b>(y,x);  
  288.             if( x>0 ) // left  >0的判断是为了避免在图像边界的时候还计算,导致越界  
  289.             {  
  290.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y,x-1);  
  291.                 beta += diff.dot(diff);  //矩阵的点乘,也就是各个元素平方的和  
  292.             }  
  293.             if( y>0 && x>0 ) // upleft  
  294.             {  
  295.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x-1);  
  296.                 beta += diff.dot(diff);  
  297.             }  
  298.             if( y>0 ) // up  
  299.             {  
  300.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x);  
  301.                 beta += diff.dot(diff);  
  302.             }  
  303.             if( y>0 && x<img.cols-1) // upright  
  304.             {  
  305.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x+1);  
  306.                 beta += diff.dot(diff);  
  307.             }  
  308.         }  
  309.     }  
  310.     if( beta <= std::numeric_limits<double>::epsilon() )  
  311.         beta = 0;  
  312.     else  
  313.         beta = 1.f / (2 * beta/(4*img.cols*img.rows - 3*img.cols - 3*img.rows + 2) ); //论文公式(5)  
  314.   
  315.     return beta;  
  316. }  
  317.   
  318. //计算图每个非端点顶点(也就是每个像素作为图的一个顶点,不包括源点s和汇点t)与邻域顶点  
  319. //的边的权值。由于是无向图,我们计算的是八邻域,那么对于一个顶点,我们计算四个方向就行,  
  320. //在其他的顶点计算的时候,会把剩余那四个方向的权值计算出来。这样整个图算完后,每个顶点  
  321. //与八邻域的顶点的边的权值就都计算出来了。  
  322. //这个相当于计算Gibbs能量的第二个能量项(平滑项),具体见论文中公式(4)  
  323. /* 
  324.   Calculate weights of noterminal vertices of graph. 
  325.   beta and gamma - parameters of GrabCut algorithm. 
  326.  */  
  327. static void calcNWeights( const Mat& img, Mat& leftW, Mat& upleftW, Mat& upW,   
  328.                             Mat& uprightW, double beta, double gamma )  
  329. {  
  330.     //gammaDivSqrt2相当于公式(4)中的gamma * dis(i,j)^(-1),那么可以知道,  
  331.     //当i和j是垂直或者水平关系时,dis(i,j)=1,当是对角关系时,dis(i,j)=sqrt(2.0f)。  
  332.     //具体计算时,看下面就明白了  
  333.     const double gammaDivSqrt2 = gamma / std::sqrt(2.0f);  
  334.     //每个方向的边的权值通过一个和图大小相等的Mat来保存  
  335.     leftW.create( img.rows, img.cols, CV_64FC1 );  
  336.     upleftW.create( img.rows, img.cols, CV_64FC1 );  
  337.     upW.create( img.rows, img.cols, CV_64FC1 );  
  338.     uprightW.create( img.rows, img.cols, CV_64FC1 );  
  339.     forint y = 0; y < img.rows; y++ )  
  340.     {  
  341.         forint x = 0; x < img.cols; x++ )  
  342.         {  
  343.             Vec3d color = img.at<Vec3b>(y,x);  
  344.             if( x-1>=0 ) // left  //避免图的边界  
  345.             {  
  346.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y,x-1);  
  347.                 leftW.at<double>(y,x) = gamma * exp(-beta*diff.dot(diff));  
  348.             }  
  349.             else  
  350.                 leftW.at<double>(y,x) = 0;  
  351.             if( x-1>=0 && y-1>=0 ) // upleft  
  352.             {  
  353.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x-1);  
  354.                 upleftW.at<double>(y,x) = gammaDivSqrt2 * exp(-beta*diff.dot(diff));  
  355.             }  
  356.             else  
  357.                 upleftW.at<double>(y,x) = 0;  
  358.             if( y-1>=0 ) // up  
  359.             {  
  360.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x);  
  361.                 upW.at<double>(y,x) = gamma * exp(-beta*diff.dot(diff));  
  362.             }  
  363.             else  
  364.                 upW.at<double>(y,x) = 0;  
  365.             if( x+1<img.cols && y-1>=0 ) // upright  
  366.             {  
  367.                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x+1);  
  368.                 uprightW.at<double>(y,x) = gammaDivSqrt2 * exp(-beta*diff.dot(diff));  
  369.             }  
  370.             else  
  371.                 uprightW.at<double>(y,x) = 0;  
  372.         }  
  373.     }  
  374. }  
  375.   
  376. //检查mask的正确性。mask为通过用户交互或者程序设定的,它是和图像大小一样的单通道灰度图,  
  377. //每个像素只能取GC_BGD or GC_FGD or GC_PR_BGD or GC_PR_FGD 四种枚举值,分别表示该像素  
  378. //(用户或者程序指定)属于背景、前景、可能为背景或者可能为前景像素。具体的参考:  
  379. //ICCV2001“Interactive Graph Cuts for Optimal Boundary & Region Segmentation of Objects in N-D Images”  
  380. //Yuri Y. Boykov Marie-Pierre Jolly   
  381. /* 
  382.   Check size, type and element values of mask matrix. 
  383.  */  
  384. static void checkMask( const Mat& img, const Mat& mask )  
  385. {  
  386.     if( mask.empty() )  
  387.         CV_Error( CV_StsBadArg, "mask is empty" );  
  388.     if( mask.type() != CV_8UC1 )  
  389.         CV_Error( CV_StsBadArg, "mask must have CV_8UC1 type" );  
  390.     if( mask.cols != img.cols || mask.rows != img.rows )  
  391.         CV_Error( CV_StsBadArg, "mask must have as many rows and cols as img" );  
  392.     forint y = 0; y < mask.rows; y++ )  
  393.     {  
  394.         forint x = 0; x < mask.cols; x++ )  
  395.         {  
  396.             uchar val = mask.at<uchar>(y,x);  
  397.             if( val!=GC_BGD && val!=GC_FGD && val!=GC_PR_BGD && val!=GC_PR_FGD )  
  398.                 CV_Error( CV_StsBadArg, "mask element value must be equel"  
  399.                     "GC_BGD or GC_FGD or GC_PR_BGD or GC_PR_FGD" );  
  400.         }  
  401.     }  
  402. }  
  403.   
  404. //通过用户框选目标rect来创建mask,rect外的全部作为背景,设置为GC_BGD,  
  405. //rect内的设置为 GC_PR_FGD(可能为前景)  
  406. /* 
  407.   Initialize mask using rectangular. 
  408. */  
  409. static void initMaskWithRect( Mat& mask, Size imgSize, Rect rect )  
  410. {  
  411.     mask.create( imgSize, CV_8UC1 );  
  412.     mask.setTo( GC_BGD );  
  413.   
  414.     rect.x = max(0, rect.x);  
  415.     rect.y = max(0, rect.y);  
  416.     rect.width = min(rect.width, imgSize.width-rect.x);  
  417.     rect.height = min(rect.height, imgSize.height-rect.y);  
  418.   
  419.     (mask(rect)).setTo( Scalar(GC_PR_FGD) );  
  420. }  
  421.   
  422. //通过k-means算法来初始化背景GMM和前景GMM模型  
  423. /* 
  424.   Initialize GMM background and foreground models using kmeans algorithm. 
  425. */  
  426. static void initGMMs( const Mat& img, const Mat& mask, GMM& bgdGMM, GMM& fgdGMM )  
  427. {  
  428.     const int kMeansItCount = 10;  //迭代次数  
  429.     const int kMeansType = KMEANS_PP_CENTERS; //Use kmeans++ center initialization by Arthur and Vassilvitskii  
  430.   
  431.     Mat bgdLabels, fgdLabels; //记录背景和前景的像素样本集中每个像素对应GMM的哪个高斯模型,论文中的kn  
  432.     vector<Vec3f> bgdSamples, fgdSamples; //背景和前景的像素样本集  
  433.     Point p;  
  434.     for( p.y = 0; p.y < img.rows; p.y++ )  
  435.     {  
  436.         for( p.x = 0; p.x < img.cols; p.x++ )  
  437.         {  
  438.             //mask中标记为GC_BGD和GC_PR_BGD的像素都作为背景的样本像素  
  439.             if( mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD )  
  440.                 bgdSamples.push_back( (Vec3f)img.at<Vec3b>(p) );  
  441.             else // GC_FGD | GC_PR_FGD  
  442.                 fgdSamples.push_back( (Vec3f)img.at<Vec3b>(p) );  
  443.         }  
  444.     }  
  445.     CV_Assert( !bgdSamples.empty() && !fgdSamples.empty() );  
  446.       
  447.     //kmeans中参数_bgdSamples为:每行一个样本  
  448.     //kmeans的输出为bgdLabels,里面保存的是输入样本集中每一个样本对应的类标签(样本聚为componentsCount类后)  
  449.     Mat _bgdSamples( (int)bgdSamples.size(), 3, CV_32FC1, &bgdSamples[0][0] );  
  450.     kmeans( _bgdSamples, GMM::componentsCount, bgdLabels,  
  451.             TermCriteria( CV_TERMCRIT_ITER, kMeansItCount, 0.0), 0, kMeansType );  
  452.     Mat _fgdSamples( (int)fgdSamples.size(), 3, CV_32FC1, &fgdSamples[0][0] );  
  453.     kmeans( _fgdSamples, GMM::componentsCount, fgdLabels,  
  454.             TermCriteria( CV_TERMCRIT_ITER, kMeansItCount, 0.0), 0, kMeansType );  
  455.   
  456.     //经过上面的步骤后,每个像素所属的高斯模型就确定的了,那么就可以估计GMM中每个高斯模型的参数了。  
  457.     bgdGMM.initLearning();  
  458.     forint i = 0; i < (int)bgdSamples.size(); i++ )  
  459.         bgdGMM.addSample( bgdLabels.at<int>(i,0), bgdSamples[i] );  
  460.     bgdGMM.endLearning();  
  461.   
  462.     fgdGMM.initLearning();  
  463.     forint i = 0; i < (int)fgdSamples.size(); i++ )  
  464.         fgdGMM.addSample( fgdLabels.at<int>(i,0), fgdSamples[i] );  
  465.     fgdGMM.endLearning();  
  466. }  
  467.   
  468. //论文中:迭代最小化算法step 1:为每个像素分配GMM中所属的高斯模型,kn保存在Mat compIdxs中  
  469. /* 
  470.   Assign GMMs components for each pixel. 
  471. */  
  472. static void assignGMMsComponents( const Mat& img, const Mat& mask, const GMM& bgdGMM,   
  473.                                     const GMM& fgdGMM, Mat& compIdxs )  
  474. {  
  475.     Point p;  
  476.     for( p.y = 0; p.y < img.rows; p.y++ )  
  477.     {  
  478.         for( p.x = 0; p.x < img.cols; p.x++ )  
  479.         {  
  480.             Vec3d color = img.at<Vec3b>(p);  
  481.             //通过mask来判断该像素属于背景像素还是前景像素,再判断它属于前景或者背景GMM中的哪个高斯分量  
  482.             compIdxs.at<int>(p) = mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD ?  
  483.                 bgdGMM.whichComponent(color) : fgdGMM.whichComponent(color);  
  484.         }  
  485.     }  
  486. }  
  487.   
  488. //论文中:迭代最小化算法step 2:从每个高斯模型的像素样本集中学习每个高斯模型的参数  
  489. /* 
  490.   Learn GMMs parameters. 
  491. */  
  492. static void learnGMMs( const Mat& img, const Mat& mask, const Mat& compIdxs, GMM& bgdGMM, GMM& fgdGMM )  
  493. {  
  494.     bgdGMM.initLearning();  
  495.     fgdGMM.initLearning();  
  496.     Point p;  
  497.     forint ci = 0; ci < GMM::componentsCount; ci++ )  
  498.     {  
  499.         for( p.y = 0; p.y < img.rows; p.y++ )  
  500.         {  
  501.             for( p.x = 0; p.x < img.cols; p.x++ )  
  502.             {  
  503.                 if( compIdxs.at<int>(p) == ci )  
  504.                 {  
  505.                     if( mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD )  
  506.                         bgdGMM.addSample( ci, img.at<Vec3b>(p) );  
  507.                     else  
  508.                         fgdGMM.addSample( ci, img.at<Vec3b>(p) );  
  509.                 }  
  510.             }  
  511.         }  
  512.     }  
  513.     bgdGMM.endLearning();  
  514.     fgdGMM.endLearning();  
  515. }  
  516.   
  517. //通过计算得到的能量项构建图,图的顶点为像素点,图的边由两部分构成,  
  518. //一类边是:每个顶点与Sink汇点t(代表背景)和源点Source(代表前景)连接的边,  
  519. //这类边的权值通过Gibbs能量项的第一项能量项来表示。  
  520. //另一类边是:每个顶点与其邻域顶点连接的边,这类边的权值通过Gibbs能量项的第二项能量项来表示。  
  521. /* 
  522.   Construct GCGraph 
  523. */  
  524. static void constructGCGraph( const Mat& img, const Mat& mask, const GMM& bgdGMM, const GMM& fgdGMM, double lambda,  
  525.                        const Mat& leftW, const Mat& upleftW, const Mat& upW, const Mat& uprightW,  
  526.                        GCGraph<double>& graph )  
  527. {  
  528.     int vtxCount = img.cols*img.rows;  //顶点数,每一个像素是一个顶点  
  529.     int edgeCount = 2*(4*vtxCount - 3*(img.cols + img.rows) + 2);  //边数,需要考虑图边界的边的缺失  
  530.     //通过顶点数和边数创建图。这些类型声明和函数定义请参考gcgraph.hpp  
  531.     graph.create(vtxCount, edgeCount);  
  532.     Point p;  
  533.     for( p.y = 0; p.y < img.rows; p.y++ )  
  534.     {  
  535.         for( p.x = 0; p.x < img.cols; p.x++)  
  536.         {  
  537.             // add node  
  538.             int vtxIdx = graph.addVtx();  //返回这个顶点在图中的索引  
  539.             Vec3b color = img.at<Vec3b>(p);  
  540.   
  541.             // set t-weights              
  542.             //计算每个顶点与Sink汇点t(代表背景)和源点Source(代表前景)连接的权值。  
  543.             //也即计算Gibbs能量(每一个像素点作为背景像素或者前景像素)的第一个能量项  
  544.             double fromSource, toSink;  
  545.             if( mask.at<uchar>(p) == GC_PR_BGD || mask.at<uchar>(p) == GC_PR_FGD )  
  546.             {  
  547.                 //对每一个像素计算其作为背景像素或者前景像素的第一个能量项,作为分别与t和s点的连接权值  
  548.                 fromSource = -log( bgdGMM(color) );  
  549.                 toSink = -log( fgdGMM(color) );  
  550.             }  
  551.             else if( mask.at<uchar>(p) == GC_BGD )  
  552.             {  
  553.                 //对于确定为背景的像素点,它与Source点(前景)的连接为0,与Sink点的连接为lambda  
  554.                 fromSource = 0;  
  555.                 toSink = lambda;  
  556.             }  
  557.             else // GC_FGD  
  558.             {  
  559.                 fromSource = lambda;  
  560.                 toSink = 0;  
  561.             }  
  562.             //设置该顶点vtxIdx分别与Source点和Sink点的连接权值  
  563.             graph.addTermWeights( vtxIdx, fromSource, toSink );  
  564.   
  565.             // set n-weights  n-links  
  566.             //计算两个邻域顶点之间连接的权值。  
  567.             //也即计算Gibbs能量的第二个能量项(平滑项)  
  568.             if( p.x>0 )  
  569.             {  
  570.                 double w = leftW.at<double>(p);  
  571.                 graph.addEdges( vtxIdx, vtxIdx-1, w, w );  
  572.             }  
  573.             if( p.x>0 && p.y>0 )  
  574.             {  
  575.                 double w = upleftW.at<double>(p);  
  576.                 graph.addEdges( vtxIdx, vtxIdx-img.cols-1, w, w );  
  577.             }  
  578.             if( p.y>0 )  
  579.             {  
  580.                 double w = upW.at<double>(p);  
  581.                 graph.addEdges( vtxIdx, vtxIdx-img.cols, w, w );  
  582.             }  
  583.             if( p.x<img.cols-1 && p.y>0 )  
  584.             {  
  585.                 double w = uprightW.at<double>(p);  
  586.                 graph.addEdges( vtxIdx, vtxIdx-img.cols+1, w, w );  
  587.             }  
  588.         }  
  589.     }  
  590. }  
  591.   
  592. //论文中:迭代最小化算法step 3:分割估计:最小割或者最大流算法  
  593. /* 
  594.   Estimate segmentation using MaxFlow algorithm 
  595. */  
  596. static void estimateSegmentation( GCGraph<double>& graph, Mat& mask )  
  597. {  
  598.     //通过最大流算法确定图的最小割,也即完成图像的分割  
  599.     graph.maxFlow();  
  600.     Point p;  
  601.     for( p.y = 0; p.y < mask.rows; p.y++ )  
  602.     {  
  603.         for( p.x = 0; p.x < mask.cols; p.x++ )  
  604.         {  
  605.             //通过图分割的结果来更新mask,即最后的图像分割结果。注意的是,永远都  
  606.             //不会更新用户指定为背景或者前景的像素  
  607.             if( mask.at<uchar>(p) == GC_PR_BGD || mask.at<uchar>(p) == GC_PR_FGD )  
  608.             {  
  609.                 if( graph.inSourceSegment( p.y*mask.cols+p.x /*vertex index*/ ) )  
  610.                     mask.at<uchar>(p) = GC_PR_FGD;  
  611.                 else  
  612.                     mask.at<uchar>(p) = GC_PR_BGD;  
  613.             }  
  614.         }  
  615.     }  
  616. }  
  617.   
  618. //最后的成果:提供给外界使用的伟大的API:grabCut   
  619. /* 
  620. ****参数说明: 
  621.     img——待分割的源图像,必须是8位3通道(CV_8UC3)图像,在处理的过程中不会被修改; 
  622.     mask——掩码图像,如果使用掩码进行初始化,那么mask保存初始化掩码信息;在执行分割 
  623.         的时候,也可以将用户交互所设定的前景与背景保存到mask中,然后再传入grabCut函 
  624.         数;在处理结束之后,mask中会保存结果。mask只能取以下四种值: 
  625.         GCD_BGD(=0),背景; 
  626.         GCD_FGD(=1),前景; 
  627.         GCD_PR_BGD(=2),可能的背景; 
  628.         GCD_PR_FGD(=3),可能的前景。 
  629.         如果没有手工标记GCD_BGD或者GCD_FGD,那么结果只会有GCD_PR_BGD或GCD_PR_FGD; 
  630.     rect——用于限定需要进行分割的图像范围,只有该矩形窗口内的图像部分才被处理; 
  631.     bgdModel——背景模型,如果为null,函数内部会自动创建一个bgdModel;bgdModel必须是 
  632.         单通道浮点型(CV_32FC1)图像,且行数只能为1,列数只能为13x5; 
  633.     fgdModel——前景模型,如果为null,函数内部会自动创建一个fgdModel;fgdModel必须是 
  634.         单通道浮点型(CV_32FC1)图像,且行数只能为1,列数只能为13x5; 
  635.     iterCount——迭代次数,必须大于0; 
  636.     mode——用于指示grabCut函数进行什么操作,可选的值有: 
  637.         GC_INIT_WITH_RECT(=0),用矩形窗初始化GrabCut; 
  638.         GC_INIT_WITH_MASK(=1),用掩码图像初始化GrabCut; 
  639.         GC_EVAL(=2),执行分割。 
  640. */  
  641. void cv::grabCut( InputArray _img, InputOutputArray _mask, Rect rect,  
  642.                   InputOutputArray _bgdModel, InputOutputArray _fgdModel,  
  643.                   int iterCount, int mode )  
  644. {  
  645.     Mat img = _img.getMat();  
  646.     Mat& mask = _mask.getMatRef();  
  647.     Mat& bgdModel = _bgdModel.getMatRef();  
  648.     Mat& fgdModel = _fgdModel.getMatRef();  
  649.   
  650.     if( img.empty() )  
  651.         CV_Error( CV_StsBadArg, "image is empty" );  
  652.     if( img.type() != CV_8UC3 )  
  653.         CV_Error( CV_StsBadArg, "image mush have CV_8UC3 type" );  
  654.   
  655.     GMM bgdGMM( bgdModel ), fgdGMM( fgdModel );  
  656.     Mat compIdxs( img.size(), CV_32SC1 );  
  657.   
  658.     if( mode == GC_INIT_WITH_RECT || mode == GC_INIT_WITH_MASK )  
  659.     {  
  660.         if( mode == GC_INIT_WITH_RECT )  
  661.             initMaskWithRect( mask, img.size(), rect );  
  662.         else // flag == GC_INIT_WITH_MASK  
  663.             checkMask( img, mask );  
  664.         initGMMs( img, mask, bgdGMM, fgdGMM );  
  665.     }  
  666.   
  667.     if( iterCount <= 0)  
  668.         return;  
  669.   
  670.     if( mode == GC_EVAL )  
  671.         checkMask( img, mask );  
  672.   
  673.     const double gamma = 50;  
  674.     const double lambda = 9*gamma;  
  675.     const double beta = calcBeta( img );  
  676.   
  677.     Mat leftW, upleftW, upW, uprightW;  
  678.     calcNWeights( img, leftW, upleftW, upW, uprightW, beta, gamma );  
  679.   
  680.     forint i = 0; i < iterCount; i++ )  
  681.     {  
  682.         GCGraph<double> graph;  
  683.         assignGMMsComponents( img, mask, bgdGMM, fgdGMM, compIdxs );  
  684.         learnGMMs( img, mask, compIdxs, bgdGMM, fgdGMM );  
  685.         constructGCGraph(img, mask, bgdGMM, fgdGMM, lambda, leftW, upleftW, upW, uprightW, graph );  
  686.         estimateSegmentation( graph, mask );  
  687.     }  
  688. }  



GrabCut资源汇总:


https://mywebspace.wisc.edu/pwang6/personal/
is a very high quality code. Do not give out the source code though. The user cannot change the setting of the parameters like the "k" in the GMM. However, it works very smoothly in vista and XP both 32 bit and 64 bit.

http://www.cs.cmu.edu/~mohitg/segmentation.htm
Provides the source code for C++ and Matlab. However, the quality is not so good. It could crash sometimes, and the parameter setting is not straight-forward.

http://research.justintalbot.org/papers/GrabCut.zip
Provides the C++ code for GrabCut implementation. This source code has very high quality, and it provides the step-by-step of the GMM learning and Graph Cutting. The user could see the change of the energy and the color modes.

The user needs to modify the source code to change the setting of the "k" in the GMM.

To compile this code, OpenCV(http://opencv.willowgarage.com/wiki/) and the GLUT (http://www.opengl.org/resources/libraries/glut/) both need to be installed.
VC and Linux both could compile the source code and get it running.

The details of this implementation are described in:
http://students.cs.byu.edu/~jtalbot/research/Grabcut.pdf

Have fun!


你可能感兴趣的:(OpenCV的GrabCut函数使用和源码解读)