GrabCut in One Cut: a fast, easy-to-use image segmentation algorithm

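      This paper, published at ICCV 2013, proposes a fast segmentation algorithm that needs only simple user interaction. This post is a walkthrough of the paper.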

Tang M, Gorelick L, Veksler O, et al. GrabCut in One Cut[C]// IEEE International Conference on Computer Vision. IEEE Computer Society, 2013: 1769-1776.

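      GrabCut in One Cut builds on the classic graph cut segmentation method, a very popular energy-minimization technique. Methods of this kind relate image segmentation to the minimum cut (min cut) problem on a graph. The image to be segmented is first represented as an undirected graph G = <V, E>, where V and E are the sets of vertices and edges. Unlike an ordinary graph, a graph-cuts graph has two extra vertices, called the source and the sink and collectively referred to as terminal vertices. A graph-cuts graph therefore has two kinds of vertices and two kinds of edges.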

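      The first kind: each ordinary vertex corresponds to one pixel of the image. The connection between two neighboring vertices (that is, two neighboring pixels) is an edge. These edges are called n-links.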

      The second kind: every ordinary vertex is also connected to each of the two terminal vertices. These edges are called t-links.
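To make the graph layout concrete, here is a minimal sketch that counts the vertices and edges of the s-t graph for a small image. It assumes a 4-connected pixel grid; `LinkCounts` and `countLinks` are hypothetical names used only for this illustration.

```cpp
#include <cassert>

// Sizes of the s-t graph built for a width x height image (4-connected grid assumed).
struct LinkCounts {
    long long pixelNodes; // one ordinary vertex per pixel
    long long nLinks;     // edges between neighbouring pixels
    long long tLinks;     // edges from each pixel to the source and to the sink
};

LinkCounts countLinks(int width, int height) {
    assert(width > 0 && height > 0);
    LinkCounts c;
    c.pixelNodes = 1LL * width * height;
    // (width - 1) horizontal pairs per row, (height - 1) vertical pairs per column
    c.nLinks = 1LL * (width - 1) * height + 1LL * width * (height - 1);
    // every pixel is linked to both terminal vertices
    c.tLinks = 2 * c.pixelNodes;
    return c;
}
```

For a 3 x 3 image this gives 9 pixel vertices, 12 n-links and 18 t-links, plus the two terminal vertices.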

Figure 1. The s-t graph


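      The figure above shows the s-t graph of an image. Each pixel corresponds to one vertex of the graph, and in addition there are the source and sink vertices. The blue and red edges are t-links and the yellow edges are n-links. In image segmentation, the s vertex usually stands for the foreground object and the t vertex for the background.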

       Every edge has a weight. A cut in graph cut is a subset of the edges of the graph, and the sum of the weights of the edges in the cut is called its cost.

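      A cut is a set of edges, which may contain both kinds of edges described above, such that removing all of them separates the remaining graph into an "S" part and a "T" part. That is why it is called a "cut". The cut whose edge weights sum to the smallest value is the minimum cut, and it is the result of graph cut. The Ford-Fulkerson theorem states that the maximum flow (max flow) of a network equals the cost of its minimum cut (min cut), so the max-flow/min-cut algorithm developed by Boykov and Kolmogorov can be used to find the min cut of the s-t graph. The min cut partitions the vertices into two disjoint subsets S and T, with s ∈ S, t ∈ T and S ∪ T = V. These two subsets correspond to the foreground and background pixel sets, which amounts to segmenting the image.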
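The max-flow/min-cut relation can be tried on a tiny graph. The sketch below uses the textbook Edmonds-Karp algorithm (BFS augmenting paths) rather than the Boykov-Kolmogorov algorithm mentioned above, purely because it is short. `FlowGraph`, `maxFlow` and `sourceSide` are hypothetical names for this illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <queue>
#include <vector>

// Dense s-t graph; cap[u][v] holds the residual capacity of edge u -> v.
struct FlowGraph {
    int n;
    std::vector<std::vector<int>> cap;
    explicit FlowGraph(int nodes) : n(nodes), cap(nodes, std::vector<int>(nodes, 0)) {}
    void addEdge(int u, int v, int c) { cap[u][v] += c; }
};

// Edmonds-Karp: repeatedly augment along a shortest path found by BFS.
int maxFlow(FlowGraph& g, int s, int t) {
    assert(s != t);
    int flow = 0;
    while (true) {
        std::vector<int> parent(g.n, -1);
        parent[s] = s;
        std::queue<int> q;
        q.push(s);
        while (!q.empty() && parent[t] == -1) {
            int u = q.front(); q.pop();
            for (int v = 0; v < g.n; ++v)
                if (parent[v] == -1 && g.cap[u][v] > 0) { parent[v] = u; q.push(v); }
        }
        if (parent[t] == -1) break;  // no augmenting path is left
        int bottleneck = INT_MAX;
        for (int v = t; v != s; v = parent[v])
            bottleneck = std::min(bottleneck, g.cap[parent[v]][v]);
        for (int v = t; v != s; v = parent[v]) {
            g.cap[parent[v]][v] -= bottleneck;  // use up forward capacity
            g.cap[v][parent[v]] += bottleneck;  // allow the flow to be undone
        }
        flow += bottleneck;
    }
    return flow;
}

// After maxFlow, the vertices still reachable from s in the residual graph form the set S.
std::vector<bool> sourceSide(const FlowGraph& g, int s) {
    std::vector<bool> inS(g.n, false);
    std::queue<int> q;
    inS[s] = true;
    q.push(s);
    while (!q.empty()) {
        int u = q.front(); q.pop();
        for (int v = 0; v < g.n; ++v)
            if (!inS[v] && g.cap[u][v] > 0) { inS[v] = true; q.push(v); }
    }
    return inS;
}
```

With vertex 0 as source, vertex 3 as sink and two "pixels" 1 and 2, adding t-links 0->1 (5), 0->2 (1), 1->3 (1), 2->3 (5) and an n-link of weight 2 in both directions between 1 and 2 yields a max flow equal to the cost of the cut that puts pixel 1 with the source and pixel 2 with the sink.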

Figure 2. Minimum cut of the s-t graph


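      Edge weights are chosen so that the total weight is smallest along the boundary between foreground and background. The minimum cut is then obtained by minimizing an energy function:

$$E(L) = a\,R(L) + B(L)$$

      Here L is the labeling that defines the cut, R(L) is the region term, B(L) is the boundary term, and a is a weighting factor that sets the relative importance of the region term and the boundary term. The region term is usually written as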

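$$R(L) = \sum_{p \in P} R_p(l_p)$$

      where $R_p(l_p)$ is the penalty for assigning label $l_p$ to pixel p. Its value is usually obtained by comparing the gray level of pixel p with the gray-level histogram of the given object.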
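One common way to turn a histogram into such a penalty, used in Boykov and Jolly's interactive graph cuts, is the negative log-likelihood $-\ln \Pr(I_p \mid \text{label})$. The sketch below assumes that form; the function name `regionPenalty` and the smoothing constant are assumptions made for this illustration, not something the post specifies.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// R_p(label) = -ln Pr(I_p | label), with Pr estimated from a gray-level
// histogram of the seed pixels of that label.
double regionPenalty(const std::vector<int>& hist, int intensity) {
    assert(intensity >= 0 && intensity < static_cast<int>(hist.size()));
    long long total = 0;
    for (int count : hist) total += count;
    assert(total > 0);
    // a small epsilon avoids log(0) for gray levels never seen in the seeds
    const double eps = 1e-6;
    double prob = (hist[intensity] + eps) / (total + eps * hist.size());
    return -std::log(prob);
}
```

A gray level that is frequent in the object histogram gets a small penalty for the object label, and a rare one gets a large penalty.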

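Boundary term

$$B(L) = \sum_{\{p,q\} \in N} B_{\langle p,q \rangle}\,\delta(l_p, l_q), \qquad B_{\langle p,q \rangle} \propto \exp\left(-\frac{(I_p - I_q)^2}{2\sigma^2}\right)$$

      where N is the set of neighboring pixel pairs and $\delta(l_p, l_q)$ is 1 when $l_p \neq l_q$ and 0 otherwise. Pixel values on the two sides of an object boundary usually differ a lot, so the boundary term is designed to be smallest when two neighboring pixels differ greatly.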
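As a small illustration of this behavior, the sketch below evaluates a Gaussian boundary weight for two 3-channel pixels, in the same form as `currEdgeStrength` in the code further down (before the `EDGE_STRENGTH_WEIGHT` blending). The function name `boundaryWeight` and its parameter layout are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// exp(-||I_p - I_q||^2 / (2 * sigma2)) / dist, for two 3-channel pixels.
// sigma2 plays the role of varianceSquared in the listing below.
double boundaryWeight(const double p[3], const double q[3], double sigma2, double dist) {
    assert(sigma2 > 0.0 && dist > 0.0);
    double d2 = 0.0;
    for (int c = 0; c < 3; ++c) d2 += (p[c] - q[c]) * (p[c] - q[c]);
    return std::exp(-d2 / (2.0 * sigma2)) / dist;
}
```

Identical neighbors get the maximum weight, so cutting between them is expensive, while neighbors with very different colors get a weight close to zero, so the cut prefers to pass between them.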

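In the GrabCut in One Cut paper, the authors replace the region term with the following term, which avoids the NP-hard problem:

$$E_{L_1}(\theta^{S}, \theta^{\bar S}) = -\lVert \theta^{S} - \theta^{\bar S} \rVert_{L_1}$$

      where $\theta^{S}$ and $\theta^{\bar S}$ are the unnormalized color histograms of the foreground S and the background $\bar S$. This can be rewritten as

$$-\lVert \theta^{S} - \theta^{\bar S} \rVert_{L_1} = 2\sum_{k=1}^{K} \min\left(n_k^{S}, n_k^{\bar S}\right) - |\Omega|$$

      Here Ω is the set of all pixels, $n_k$ is the number of pixels in bin k, $n_k^{S}$ is the number of pixels in bin k that belong to the foreground, and $n_k^{\bar S}$ is the number that belong to the background, so $n_k = n_k^{S} + n_k^{\bar S}$. Since |Ω| is a constant, minimizing this term means minimizing the overlap $\sum_k \min(n_k^{S}, n_k^{\bar S})$ between the two histograms.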
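The rewrite of the L1 term as a histogram overlap can be checked numerically. The sketch below assumes the term is the negative L1 distance between the unnormalized foreground and background histograms, as in the One Cut paper. `negL1Distance` and `overlapForm` are hypothetical helper names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <vector>

// -||theta^S - theta^Sbar||_L1 computed directly from the two histograms.
long long negL1Distance(const std::vector<int>& fg, const std::vector<int>& bg) {
    assert(fg.size() == bg.size());
    long long sum = 0;
    for (std::size_t k = 0; k < fg.size(); ++k)
        sum += std::llabs(static_cast<long long>(fg[k]) - bg[k]);
    return -sum;
}

// The same quantity written as 2 * sum_k min(n_k^S, n_k^Sbar) - |Omega|.
long long overlapForm(const std::vector<int>& fg, const std::vector<int>& bg) {
    assert(fg.size() == bg.size());
    long long overlap = 0, omega = 0;
    for (std::size_t k = 0; k < fg.size(); ++k) {
        overlap += std::min(fg[k], bg[k]);  // pixels of bin k on the minority side
        omega += fg[k] + bg[k];             // every pixel falls in exactly one bin
    }
    return 2 * overlap - omega;
}
```

Both functions return the same value for any pair of histograms, because |a - b| = a + b - 2 min(a, b) holds for each bin.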

Figure 3. Auxiliary nodes


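      The authors also give a way to implement this: add one auxiliary node per histogram bin (the k-th auxiliary node for the k-th bin), connect each auxiliary node to every pixel that falls into its bin, and give these edges weight 1 (the code below scales this weight by bha_slope). When the min cut is computed, minimizing the boundary cost then also minimizes the total weight of the cut edges between the auxiliary nodes and their pixels, which is exactly the new region term.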
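Why an auxiliary node yields a min(...) term can be seen by enumerating the two sides it can end up on. The following is a minimal sketch for a single bin, with hypothetical names `auxCutCost` and `minAuxCutCost`. It only counts the edges between the auxiliary node and its pixels and ignores all other edges of the graph.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Cost of the cut edges between one auxiliary node and the pixels of its bin.
// labels[i] is true for a foreground pixel; auxIsForeground is the side of the
// cut on which the auxiliary node itself ends up.
int auxCutCost(const std::vector<bool>& labels, bool auxIsForeground, int edgeWeight) {
    assert(edgeWeight >= 0);
    int cut = 0;
    for (bool isFg : labels)
        if (isFg != auxIsForeground) cut += edgeWeight;  // this edge crosses the cut
    return cut;
}

// A minimum cut places the auxiliary node on the cheaper side, so the bin
// contributes edgeWeight * min(#foreground, #background).
int minAuxCutCost(const std::vector<bool>& labels, int edgeWeight) {
    return std::min(auxCutCost(labels, true, edgeWeight),
                    auxCutCost(labels, false, edgeWeight));
}
```

For a bin with three foreground and two background pixels and unit weights, the cheaper choice costs 2, which is min(3, 2).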


References:

http://blog.csdn.net/zouxy09/article/details/8532111

Download link for the code provided by the paper's authors:

http://vision.csd.uwo.ca/wiki/vision/upload/7/77/OneCutWithSeeds_v1.03.zip

Download link for my code:

http://download.csdn.net/download/zhangyumengs/10237656



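Below is part of my Qt rewrite of the code, with comments added to make it easier to read.

It also adds new features: cropping the image and rectangle-based segmentation.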

onecut.h

#ifndef ONECUT_H
#define ONECUT_H

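// NOTE: the angle-bracket header names were lost when this post was extracted.
// The system headers below are the ones the code in this post appears to need (an assumption).
#include <QWidget>
#include "ui_onecut.h"

#include <QImage>
#include <QPixmap>
#include <QKeyEvent>
#include <QString>

#include <vector>
#include <cmath>
#include <opencv2/opencv.hpp>
#include "graph.h"
#include "qmessagebox.h"
// custom widgets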
#include "onecutlabel.h"
#include "cutimagelabel.h"
#include "rectseglabel.h"

using namespace std;
using namespace cv;

#define NEIGHBORHOOD_4_TYPE 1

const int NEIGHBORHOOD = NEIGHBORHOOD_4_TYPE;

class onecut : public QWidget
{
	Q_OBJECT

public:
	onecut(QWidget *parent = 0);
	~onecut();
//___________________________________________________________________________________________________

	Mat inputImg, showImg, binPerPixelImg, showEdgesImg, segMask, segShowImg;

	Mat fgScribbleMask, bgScribbleMask;
	// kept so that the last operation can be undone
	Mat fgScribbleMask_last, bgScribbleMask_last;
	int lastSegState = 1;
	Mat showImg_last;
	
	int numUsedBins = 0;
	float varianceSquared = 0;
	int scribbleRadius = 10;

	float bha_slope = 0.5f;
	int numBinsPerChannel = 16;
	float EDGE_STRENGTH_WEIGHT = 0.95f;

	const float INT32_CONST = 1000;
	const float HARD_CONSTRAINT_CONST = 1000;

	int  init(Mat src);

	void destroyAll();
	// compute the color bin index of every pixel
	void generateBinIndex(Mat& bin, Mat& inImg, int binschannel, int& binsNotEmpty);
	// estimate the color variance used by the Gaussian edge weights
	void generateEdgeVariance(Mat& inputImg, Mat& showEdgesImg, float& varianceSquared);


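	// template arguments were lost in extraction; int capacities are assumed,
	// matching the integer edge capacities used in onecut.cpp
	typedef Graph<int, int, int> GraphType;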
	GraphType *myGraph;


private slots:
	void onSegImage();
	void onMouseMoveFinish(Mat bgScribbleMask,
		Mat fgScribbleMask,
		Mat showImg);
	void onFinish();
	void onCutImage();
	void onRectSeg();
	void onConfirmCut();
	void onDrawImage();
	void onLineWidthChanged(int);
signals:
	void okClicked(Mat result);

private:
	Ui::onecut ui;

	// members used for cropping
	cutImageLabel* imageLable = NULL;
	rectSegLabel* rectLabel = NULL;
	Mat rectCutImage(const Mat& src, Rect rect);

	void showImage(Mat image);
	QImage cvMatToQImage(Mat& src);
protected:
	void keyPressEvent(QKeyEvent  *event);
};

#endif
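
The header declares generateBinIndex, which quantizes each color channel into numBinsPerChannel levels. The quantization itself can be sketched without OpenCV as follows. `colorBinIndex` is a hypothetical standalone helper, and it uses integer arithmetic in place of the floor() calls of the listing.

```cpp
#include <cassert>

// Map an 8-bit BGR color to a raw bin index in [0, binsPerChannel^3).
int colorBinIndex(int b, int g, int r, int binsPerChannel) {
    assert(binsPerChannel > 0 && binsPerChannel <= 256);
    assert(b >= 0 && b < 256 && g >= 0 && g < 256 && r >= 0 && r < 256);
    // integer form of floor(value / 256.0 * binsPerChannel)
    int bb = b * binsPerChannel / 256;
    int gb = g * binsPerChannel / 256;
    int rb = r * binsPerChannel / 256;
    // blue varies fastest, then green, then red
    return bb + binsPerChannel * gb + binsPerChannel * binsPerChannel * rb;
}
```

With 16 bins per channel there are 4096 raw bins. The listing then renumbers only the bins that actually occur in the image, so that one auxiliary node is created per non-empty bin.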

onecut.cpp

#include "onecut.h"

onecut::onecut(QWidget *parent)
	: QWidget(parent)
{
	ui.setupUi(this);
	connect(ui.label_show, SIGNAL(mouseMoveFinish(Mat, Mat, Mat)),
		this, SLOT(onMouseMoveFinish(Mat ,Mat ,Mat)));
	connect(ui.button_seg, SIGNAL(clicked()),
		this, SLOT(onSegImage()));
	connect(ui.button_ok, SIGNAL(clicked()),
		this, SLOT(onFinish()));
	// cropping and scribbling buttons
	connect(ui.button_cut, SIGNAL(clicked()),
		this, SLOT(onCutImage()));
	connect(ui.button_confirmcut, SIGNAL(clicked()),
		this, SLOT(onConfirmCut()));
	connect(ui.button_draw, SIGNAL(clicked()),
		this, SLOT(onDrawImage()));
	connect(ui.button_rectseg, SIGNAL(clicked()),
		this, SLOT(onRectSeg()));

	// change the width of the scribble line
	ui.slider_linewidth->setRange(2, 10);
	ui.slider_linewidth->setValue(10);
	connect(ui.slider_linewidth, SIGNAL(valueChanged(int)),
		this, SLOT(onLineWidthChanged(int)));

}

onecut::~onecut()
{

}

void onecut::destroyAll()
{

	// clear all data
	fgScribbleMask.release();
	bgScribbleMask.release();
	inputImg.release();
	showImg.release();
	showEdgesImg.release();
	binPerPixelImg.release();
	segMask.release();
	segShowImg.release();

	delete myGraph;
}

int onecut::init(Mat src)
{
	// validate the input before using it
	if (!src.data)
	{
		return -1;
	}

	// initialize the Mats
	inputImg = src.clone();
	this->showImage(inputImg);
	showImg = inputImg.clone();
	segShowImg = inputImg.clone();

	// initialize the scribble masks
	fgScribbleMask.create(2, inputImg.size, CV_8UC1);
	fgScribbleMask = 0;
	bgScribbleMask.create(2, inputImg.size, CV_8UC1);
	bgScribbleMask = 0;

	segMask.create(2, inputImg.size, CV_8UC1);
	segMask = 0;
	showEdgesImg.create(2, inputImg.size, CV_32FC1);
	showEdgesImg = 0;
	binPerPixelImg.create(2, inputImg.size, CV_32F);

	// numBinsPerChannel = 16; numUsedBins is the number of bins whose count is non-zero
	generateBinIndex(binPerPixelImg, inputImg, numBinsPerChannel, numUsedBins);

	// compute the variance
	generateEdgeVariance(inputImg, showEdgesImg, varianceSquared);

	// one node per pixel plus one auxiliary node per non-empty bin
	myGraph = new GraphType( inputImg.rows * inputImg.cols + numUsedBins,12 * inputImg.rows * inputImg.cols);
	int currNodeId = myGraph->add_node((int)inputImg.cols * inputImg.rows + numUsedBins);

	for (int i = 0; i < inputImg.rows; i++)
	{
		for (int j = 0; j < inputImg.cols; j++)
		{
			// node id of the current pixel
			currNodeId = i * inputImg.cols + j;

			float b = (float)inputImg.at<Vec3b>(i, j)[0];
			float g = (float)inputImg.at<Vec3b>(i, j)[1];
			float r = (float)inputImg.at<Vec3b>(i, j)[2];

			for (int si = -NEIGHBORHOOD; si <= NEIGHBORHOOD; si++)
			{
				int ni = i + si;
				// skip rows outside the image
				if (ni < 0 || ni >= inputImg.rows)
					continue;

				for (int sj = 0; sj <= NEIGHBORHOOD; sj++)
				{
					int nj = j + sj;
					if (nj < 0 || nj >= inputImg.cols)
						continue;

					// skip the pixel itself and the
					// down pointed edge, this edge will be counted as an up edge for the other pixel
					if (si >= 0 && sj == 0)
						continue;

					// keep only neighbors inside a disc of radius NEIGHBORHOOD
					if ((si*si + sj*sj) > NEIGHBORHOOD*NEIGHBORHOOD)
						continue;

					// node id of the neighbor
					int nNodeId = (i + si) * inputImg.cols + (j + sj);

					float nb = (float)inputImg.at<Vec3b>(i + si, j + sj)[0];
					float ng = (float)inputImg.at<Vec3b>(i + si, j + sj)[1];
					float nr = (float)inputImg.at<Vec3b>(i + si, j + sj)[2];

					// boundary-term weight (Gaussian of the color difference)
					float currEdgeStrength = exp(-((b - nb)*(b - nb) + (g - ng)*(g - ng) + (r - nr)*(r - nr)) / (2 * varianceSquared));
					// distance between the current pixel and the neighbor
					float currDist = sqrt((float)si*(float)si + (float)sj*(float)sj);

					// blend with a constant and scale to an integer capacity
					currEdgeStrength = ((float)EDGE_STRENGTH_WEIGHT * currEdgeStrength + (float)(1 - EDGE_STRENGTH_WEIGHT)) / currDist;
					int edgeCapacity = (int)ceil(INT32_CONST*currEdgeStrength + 0.5);
					myGraph->add_edge(currNodeId, nNodeId, edgeCapacity, edgeCapacity);

				}
			}
			// link the current pixel to the auxiliary node of its bin
			int currBin = (int)binPerPixelImg.at<float>(i, j);

			myGraph->add_edge(currNodeId, (int)(currBin + inputImg.rows * inputImg.cols), (int)ceil(INT32_CONST*bha_slope + 0.5), (int)ceil(INT32_CONST*bha_slope + 0.5));
		}

	}
	//ui.label_show->setRelatedVar(bgScribbleMask, bgScribbleMaskAll, fgScribbleMask, fgScribbleMaskAll, scribbleRadius, showImg);
	ui.label_show->setRelatedVar(bgScribbleMask, fgScribbleMask, scribbleRadius, showImg);
	return 0;
}

// compute the bin index of every pixel
void onecut::generateBinIndex(Mat& bin, Mat & inImg, int binschannel, int & numUsedBins)
{
	// maps each raw bin index to a compact index; -1 means the bin has not been seen yet
	vector<int> occupiedBinNewIdx((int)pow((double)binschannel, (double)3), -1);

	int newBinIdx = 0;
	for (int i = 0; i < inImg.rows; i++)
		for (int j = 0; j < inImg.cols; j++)
		{
			float b = (float)inImg.at<Vec3b>(i, j)[0];
			float g = (float)inImg.at<Vec3b>(i, j)[1];
			float r = (float)inImg.at<Vec3b>(i, j)[2];

			// compute the raw bin index
			int bin_index = (int)(floor(b / 256.0 *(float)binschannel) + (float)binschannel * floor(g / 256.0*(float)binschannel)
				+ (float)binschannel * (float)binschannel * floor(r / 256.0*(float)binschannel));

			// this bin has not been used so far
			if (occupiedBinNewIdx[bin_index] == -1)
			{
				// assign it the next compact index
				occupiedBinNewIdx[bin_index] = newBinIdx;
				newBinIdx++;
			}

			bin.at<float>(i, j) = (float)occupiedBinNewIdx[bin_index];
		}

	double maxBin;
	minMaxLoc(bin, NULL, &maxBin);
	numUsedBins = (int)maxBin + 1;

	occupiedBinNewIdx.clear();
}

// estimate the color variance used in the edge weights
void onecut::generateEdgeVariance(Mat & inputImg, Mat & showEdgesImg, float & varianceSquared)
{
	varianceSquared = 0;
	int counter = 0;
	for (int i = 0; i < inputImg.rows; i++)
	{
		for (int j = 0; j < inputImg.cols; j++)
		{
			float b = (float)inputImg.at<Vec3b>(i, j)[0];
			float g = (float)inputImg.at<Vec3b>(i, j)[1];
			float r = (float)inputImg.at<Vec3b>(i, j)[2];
			for (int si = -NEIGHBORHOOD; si <= NEIGHBORHOOD && si + i < inputImg.rows && si + i >= 0; si++)
			{
				for (int sj = 0; sj <= NEIGHBORHOOD && sj + j < inputImg.cols; sj++)
				{
					if ((si == 0 && sj == 0) ||
						(si == 1 && sj == 0) ||
						(si == NEIGHBORHOOD && sj == 0))
						continue;

					float nb = (float)inputImg.at<Vec3b>(i + si, j + sj)[0];
					float ng = (float)inputImg.at<Vec3b>(i + si, j + sj)[1];
					float nr = (float)inputImg.at<Vec3b>(i + si, j + sj)[2];

					varianceSquared += (b - nb)*(b - nb) + (g - ng)*(g - ng) + (r - nr)*(r - nr);
					counter++;
				}

			}
		}
	}
	// varianceSquared holds the sum over all pixel pairs; divide by the pair count to get the mean
	if (counter > 0)
		varianceSquared /= counter;

}

void onecut::showImage(Mat image){
	//Show Image
	ui.label_show->clear();
	ui.label_show->setPixmap(QPixmap::fromImage(cvMatToQImage(image)));
	ui.label_show->resize(ui.label_show->pixmap()->size());
}

QImage onecut::cvMatToQImage(Mat& src){

	QImage qImage;
	if (src.channels() == 3)    // BGR image
	{
		Mat rgb;
		cvtColor(src, rgb, CV_BGR2RGB);
		// QImage does not take ownership of the buffer, so copy() the pixels
		// before rgb goes out of scope (the original leaked a heap-allocated Mat here)
		qImage = QImage((const unsigned char*)(rgb.data), rgb.cols, rgb.rows, (int)rgb.step, QImage::Format_RGB888).copy();
	}
	else                     // gray image
	{
		qImage = QImage((const uchar*)(src.data),
			src.cols, src.rows,
			(int)src.step,    // bytes per row
			QImage::Format_Indexed8);
	}
	return qImage;
}

void onecut::onSegImage(){

	if (rectLabel!= NULL)
		if (rectLabel->isVisible())
		{
			if (rectLabel->startPnt.isNull())
			{
				QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
				return;
			}
			if (rectLabel->endPnt.isNull())
			{
				QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
				return;
			}
			if (rectLabel->startPnt.x() < 0)
			{
				QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
				return;
			}
			if (rectLabel->startPnt.y() < 0)
			{
				QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
				return;
			}
			if (rectLabel->endPnt.x() < 0)
			{
				QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
				return;
			}
			if (rectLabel->endPnt.y() < 0)
			{
				QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
				return;
			}
			if (rectLabel->startPnt.x() > rectLabel->width())
			{
				QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
				return;
			}
			if (rectLabel->startPnt.y() > rectLabel->height())
			{
				QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
				return;
			}
			if (rectLabel->endPnt.x() > rectLabel->width())
			{
				QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
				return;
			}
			if (rectLabel->endPnt.y() > rectLabel->height())
			{
				QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
				return;
			}
			if (rectLabel->startPnt.x() == rectLabel->endPnt.x())
			{
				QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
				return;
			}
			if (rectLabel->startPnt.y() == rectLabel->endPnt.y())
			{
				QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
				return;
			}
			// mark everything outside the drawn rectangle as background
			// left of the rectangle
			for (int i = 0; i < inputImg.rows; i++)
				for (int j = 0; j < rectLabel->startPnt.x(); j++)
				{
					bgScribbleMask.at<uchar>(i, j) = 255;
				}

			// above the rectangle
			for (int i = 0; i < rectLabel->startPnt.y(); i++)
				for (int j = 0; j < inputImg.cols; j++)
				{
					bgScribbleMask.at<uchar>(i, j) = 255;
				}
	
			// below the rectangle (bounded by the image height to stay inside the mask)
			for (int i = rectLabel->endPnt.y(); i < inputImg.rows; i++)
				for (int j = 0; j < inputImg.cols; j++)
				{
					bgScribbleMask.at<uchar>(i, j) = 255;
				}

			// right of the rectangle
			for (int i = 0; i < inputImg.rows; i++)
				for (int j = rectLabel->endPnt.x(); j < inputImg.cols; j++)
				{
					bgScribbleMask.at<uchar>(i, j) = 255;
				}

			this->rectLabel->close();
			this->rectLabel = NULL;
		}


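	// set the hard constraints from the scribbles
	for (int i = 0; i < inputImg.rows; i++)
	{
		for (int j = 0; j < inputImg.cols; j++)
		{
			// node id of the current pixel
			int currNodeId = i * inputImg.cols + j;

			if (fgScribbleMask.at<uchar>(i, j) == 255)
				myGraph->add_tweights(currNodeId, (int)ceil(INT32_CONST * HARD_CONSTRAINT_CONST + 0.5), 0);
			else if (bgScribbleMask.at<uchar>(i, j) == 255)
				myGraph->add_tweights(currNodeId, 0, (int)ceil(INT32_CONST * HARD_CONSTRAINT_CONST + 0.5));
		}
	}
	
	// run the max-flow algorithm
	myGraph->maxflow();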

	segMask = 0;
	inputImg.copyTo(segShowImg);
	fgScribbleMask = 0;
	bgScribbleMask = 0;

	for (int i = 0; i < inputImg.rows * inputImg.cols; i++)
	{
		// pixels on the source side of the cut are foreground
		if (myGraph->what_segment((int)i) == GraphType::SOURCE)
		{
			segMask.at<uchar>(i / inputImg.cols, i%inputImg.cols) = 255;
		}
		else
		{
			// background pixels are blacked out in the displayed result
			segMask.at<uchar>(i / inputImg.cols, i%inputImg.cols) = 0;			
			segShowImg.at<Vec3b>(i / inputImg.cols, i%inputImg.cols)[0] = 0;
			segShowImg.at<Vec3b>(i / inputImg.cols, i%inputImg.cols)[1] = 0;
			segShowImg.at<Vec3b>(i / inputImg.cols, i%inputImg.cols)[2] = 0;
		}
	}

	this->showImage(segShowImg);
	this->showImg = segShowImg.clone();
	//ui.label_show->setRelatedVar(bgScribbleMask, bgScribbleMaskAll, fgScribbleMask, fgScribbleMaskAll, scribbleRadius, showImg);
	ui.label_show->setRelatedVar(bgScribbleMask, fgScribbleMask, scribbleRadius, showImg);
	this->lastSegState = 1;
}

void onecut::onMouseMoveFinish(Mat bgScribbleMask,
	Mat fgScribbleMask,
	Mat showImg){
	if (this->lastSegState == 1)
	{
		fgScribbleMask_last = this->fgScribbleMask.clone();
		bgScribbleMask_last = this->bgScribbleMask.clone();
		showImg_last = this->showImg.clone();
		this->lastSegState = 0;
	}

	this->bgScribbleMask = bgScribbleMask.clone();
	this->fgScribbleMask = fgScribbleMask.clone();
	this->showImg = showImg.clone();
}

void onecut::onFinish(){
	emit okClicked(this->segShowImg);
	this->close();
}

void onecut::onCutImage(){
	if (imageLable == NULL)
	{
		imageLable = new cutImageLabel();
		imageLable->show();
		ui.gridLayout->addWidget(imageLable, 0, 0);
	}
}

void onecut::onRectSeg(){
	rectLabel = new rectSegLabel();
	rectLabel->resize(inputImg.cols,inputImg.rows);
	rectLabel->show();
	ui.gridLayout->addWidget(rectLabel, 0, 0);
}

void onecut::onConfirmCut(){

	if (imageLable == NULL)
		return;
	if (imageLable->startPnt.isNull())
		return;
	if (imageLable->endPnt.isNull())
		return;
	if (imageLable->startPnt.x() < 0)
		return;
	if (imageLable->startPnt.y() < 0)
		return;
	if (imageLable->endPnt.x() < 0)
		return;
	if (imageLable->endPnt.y() < 0)
		return;
	if (imageLable->startPnt.x() > imageLable->width())
		return;
	if (imageLable->startPnt.y() > imageLable->height())
		return;
	if (imageLable->endPnt.x() > imageLable->width())
		return;
	if (imageLable->endPnt.y() > imageLable->height())
		return;
	if (imageLable->startPnt.x() == imageLable->endPnt.x())
		return;
	if (imageLable->startPnt.y() == imageLable->endPnt.y())
		return;

	Point cvP1(imageLable->startPnt.x(), imageLable->startPnt.y());
	Point cvP2(imageLable->endPnt.x(), imageLable->endPnt.y());
		
	Rect rect(cvP1, cvP2);
	this->showImg = this->rectCutImage(showImg, rect);
	this->inputImg = this->showImg.clone();
	this->fgScribbleMask = 0;
	this->bgScribbleMask = 0;
	//ui.label_show->setRelatedVar(bgScribbleMask, bgScribbleMaskAll, fgScribbleMask, fgScribbleMaskAll, scribbleRadius, showImg);
	ui.label_show->setRelatedVar(bgScribbleMask, fgScribbleMask, scribbleRadius, showImg);
	this->showImage(showImg);
	
	imageLable->close();
	imageLable = NULL;
}

void onecut::onDrawImage(){
	if (imageLable != NULL)
		imageLable->close();
	if (rectLabel != NULL)
		rectLabel->close();
}

void onecut::keyPressEvent(QKeyEvent  *event){
	if ((event->modifiers() == Qt::ControlModifier) && (event->key() == Qt::Key_Z))
	{
		this->fgScribbleMask = fgScribbleMask_last.clone();
		this->bgScribbleMask = bgScribbleMask_last.clone();
		this->showImg = showImg_last.clone();
		onSegImage();
	}

	if ((event->modifiers() == Qt::ControlModifier) && (event->key() == Qt::Key_R))
	{
		init(inputImg);
	}
}

void onecut::onLineWidthChanged(int value){
	scribbleRadius = value;
	ui.label_show->setRadius(value);

}

// copy the pixels inside rect into a new 3-channel image
Mat onecut::rectCutImage(const Mat& src, Rect rect){
	Mat tmp = Mat::zeros(rect.size(), src.type());
	for (int i = 0; i < tmp.rows; ++i)
		for (int j = 0; j < tmp.cols; ++j)
			for (int k = 0; k < 3; ++k)
				tmp.at<Vec3b>(i, j)[k] = saturate_cast<uchar>(src.at<Vec3b>(rect.y + i, rect.x + j)[k]);
	return tmp;
}
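
The rectangle mode in onSegImage marks everything outside the drawn rectangle as a background seed. The same idea can be sketched without Qt or OpenCV on a flat row-major mask. `backgroundOutsideRect` is a hypothetical helper, and treating the rectangle as half-open (x1 and y1 exclusive) is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major mask in which 255 marks a background seed: every pixel outside
// the rectangle [x0, x1) x [y0, y1) is set, pixels inside stay 0.
std::vector<std::uint8_t> backgroundOutsideRect(int width, int height,
                                                int x0, int y0, int x1, int y1) {
    assert(0 <= x0 && x0 < x1 && x1 <= width);
    assert(0 <= y0 && y0 < y1 && y1 <= height);
    std::vector<std::uint8_t> mask(static_cast<std::size_t>(width) * height, 0);
    for (int y = 0; y < height; ++y)
        for (int x = 0; x < width; ++x)
            if (x < x0 || x >= x1 || y < y0 || y >= y1)
                mask[static_cast<std::size_t>(y) * width + x] = 255;
    return mask;
}
```

Checking the rectangle once per pixel avoids the four overlapping loops of the listing and keeps all bounds tied to the image size.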

