2013年ICCV上的这篇论文,提出了一种快速的、基于简单交互的分割算法,本篇博文是对该论文的解读。
Tang M, Gorelick L, Veksler O, et al. GrabCut in One Cut[C]// IEEE International Conference on Computer Vision. IEEE Computer Society, 2013:1769-1776.
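grabcut in onecut 基于传统的graph cut分割方法,这是一种非常流行的能量优化算法,这类方法把图像分割问题与图的最小割问题相关联。首先用一个无向图$G=(V,E)$表示要分割的图像,其中V和E分别是顶点和边的集合。与普通的图不同,graph cut的图在普通顶点之外还多了两个终端顶点s(source)和t(sink),因此图中有两种顶点,也有两种边。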
第一种顶点和边是:第一种普通顶点对应于图像中的每个像素。每两个邻域顶点(对应于图像中每两个邻域像素)的连接就是一条边。这种边也叫n-links。
第二种顶点和边是:每个普通顶点和这2个终端顶点之间都有连接,组成第二种边。这种边也叫t-links。
图1 s-t图
上图就是一个图像对应的S-T图,每个像素对应于图中的一个相应顶点,在这些顶点之外还有source顶点与sink顶点。蓝色和红色的边表示t-links,黄色的边表示n-links。在图像分割中,s顶点一般表示为前景目标,t顶点一般表示为背景目标。
每条边都有权重,graph cut 中的cut是指图中的边集合的一个子集,cut中的所有边的权重和被叫做cost(代价)。
Graph Cuts中的Cuts是指这样一个边的集合,很显然这些边集合包括了上面2种边,该集合中所有边的断开会导致残留“S”和“T”图的分开,所以就称为“割”。如果一个割,它的边的所有权值之和最小,那么这个就称为最小割,也就是图割的结果。而福特-富尔克森(Ford-Fulkerson)定理表明,网络的最大流max flow与最小割min cut相等。所以由Boykov和Kolmogorov发明的max-flow/min-cut算法就可以用来获得s-t图的最小割。这个最小割把图的顶点划分为两个不相交的子集S和T,其中s∈S,t∈T且S∪T=V。这两个子集就对应于图像的前景像素集和背景像素集,那就相当于完成了图像分割。
图2 s-t最小割示意图
边的权重的确立,遵循这样一种原则:前景与背景的分界处的权值应当最小。最小割通过最小化如下能量函数得到:
$$E(L) = a\,R(L) + B(L)$$
公式中,L表示一种标签分配(对应一个图割),R(L)为区域项,B(L)为边界项,a是权重因子,表示区域项与边界项所占的比例差别。区域项往往由下面的公式表示:
$$R(L) = \sum_{p \in P} R_p(l_p)$$
该能量项是为各个像素分配标签的惩罚之和,其中$R_p(l_p)$表示为像素p分配标签$l_p$的惩罚,该能量项的值往往通过对比像素p的灰度与给定的目标的灰度直方图来获得。
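边界项
$$B(L) = \sum_{\{p,q\} \in N} B_{\{p,q\}} \cdot \delta(l_p, l_q), \qquad B_{\{p,q\}} \propto \exp\left(-\frac{(I_p - I_q)^2}{2\sigma^2}\right)$$
其中N为邻域像素对的集合,当$l_p \neq l_q$时$\delta(l_p, l_q)=1$,否则为0。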
由于边界两侧点的像素值差别往往比较大,因此,边界项的作用就是当两邻域像素的差别很大时,边界项的值应当最小。
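在grabcut in one cut论文中,作者将区域项替换成了下式,有效地避免了NP-hard问题:
$$E_{L_1}(\theta^{S}, \theta^{\bar S}) = -\left\| \theta^{S} - \theta^{\bar S} \right\|_{L_1}$$
其中$\theta^{S}$和$\theta^{\bar S}$分别是前景和背景的(未归一化的)颜色直方图。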
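上式可以转换成下面的式子:
$$E_{L_1}(\theta^{S}, \theta^{\bar S}) = \sum_{k=1}^{K} \min\left(n_k^{S}, n_k^{\bar S}\right) - \frac{|\Omega|}{2}$$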
|Ω|表示的是图像中像素的总数,$n_k$表示的是第k个bin中像素的数量,$n_k^{S}$表示的是该bin中属于前景的像素数量,$n_k^{\bar S}$表示的是属于背景的像素数量。
图3 附加顶点示意图
同时作者给出了实现方法:通过增加辅助节点$A_k$(k表示第k个bin),将辅助节点与相应bin中的像素相连,同时设置权值为1。这样在最小化前景背景时,当边界项的权值最小,同时也会相应地最小化辅助节点与相应边的连接权重,也就是最小化了区域项。
http://blog.csdn.net/zouxy09/article/details/8532111
http://vision.csd.uwo.ca/wiki/vision/upload/7/77/OneCutWithSeeds_v1.03.zip
http://download.csdn.net/download/zhangyumengs/10237656
下面给出本人改写的部分qt代码,加入了注释,方便阅读。
同时加入了新的功能,可以进行裁剪图像,方形分割。
onecut.h
#ifndef ONECUT_H
#define ONECUT_H
#include
#include "ui_onecut.h"
#include
#include
#include
#include
#include
#include
#include
#include "graph.h"
#include "qmessagebox.h"
//自定义控件
#include "onecutlabel.h"
#include "cutimagelabel.h"
#include "rectseglabel.h"
using namespace std;
using namespace cv;
// Radius (in pixels) of the neighbourhood used when n-links are created.
// With a radius of 1 and the circular cut-off applied in init(), every pixel
// is linked to its four axis-aligned neighbours.
// The macro must not end in ';' or it cannot be used inside an expression.
#define NEIGHBORHOOD_4_TYPE 1
const int NEIGHBORHOOD = NEIGHBORHOOD_4_TYPE;
// Interactive "GrabCut in One Cut" segmentation widget.
// The user paints foreground/background scribbles (or draws a rectangle) on
// the displayed image; onSegImage() turns them into hard constraints on an
// s-t graph and runs max-flow to obtain the segmentation.
class onecut : public QWidget
{
    Q_OBJECT
public:
    onecut(QWidget *parent = 0);
    ~onecut();
    //___________________________________________________________________________________________________
    // inputImg       : image being segmented
    // showImg        : image handed to the scribble label for display
    // binPerPixelImg : colour-bin index of every pixel (stored as float)
    // showEdgesImg   : float image allocated in init()
    // segMask        : 255 for pixels on the source side of the cut, 0 otherwise
    // segShowImg     : inputImg with the non-source pixels set to black
    Mat inputImg, showImg, binPerPixelImg, showEdgesImg, segMask, segShowImg;
    // Pixels marked 255 are hard foreground / background constraints.
    Mat fgScribbleMask, bgScribbleMask;
    // Snapshots used to undo the last operation (Ctrl+Z).
    Mat fgScribbleMask_last, bgScribbleMask_last;
    // 1 right after a segmentation, 0 once new scribbles have been recorded.
    int lastSegState = 1;
    Mat showImg_last;
    // Number of colour bins that contain at least one pixel.
    int numUsedBins = 0;
    // Mean squared colour difference between neighbouring pixels.
    float varianceSquared = 0;
    int scribbleRadius = 10;
    // Scales the capacity of the edges between pixels and auxiliary bin nodes.
    float bha_slope = 0.5f;
    int numBinsPerChannel = 16;
    // Share of the n-link weight that depends on colour similarity.
    float EDGE_STRENGTH_WEIGHT = 0.95f;
    // Factor applied before float weights are rounded to integer capacities.
    const float INT32_CONST = 1000;
    // Extra factor applied to the t-link capacity of scribbled pixels.
    const float HARD_CONSTRAINT_CONST = 1000;
    int init(Mat src);
    void destroyAll();
    // Computes a colour-bin index for every pixel.
    void generateBinIndex(Mat& bin, Mat& inImg, int binschannel, int& binsNotEmpty);
    // Computes the variance used by the Gaussian edge weight.
    void generateEdgeVariance(Mat& inputImg, Mat& showEdgesImg, float& varianceSquared);
    // Capacities are passed as int throughout the file, hence <int, int, int>.
    // NOTE(review): the template arguments were missing in the listing -
    // confirm against graph.h.
    typedef Graph<int, int, int> GraphType;
    // Starts out null so that destroyAll() is safe before init() has run.
    GraphType *myGraph = NULL;
private slots:
    void onSegImage();
    void onMouseMoveFinish(Mat bgScribbleMask,
                           Mat fgScribbleMask,
                           Mat showImg);
    void onFinish();
    void onCutImage();
    void onRectSeg();
    void onConfirmCut();
    void onDrawImage();
    void onLineWidthChanged(int);
signals:
    void okClicked(Mat result);
private:
    Ui::onecut ui;
    // Overlay widgets used for cropping and for rectangle segmentation.
    cutImageLabel* imageLable = NULL;
    rectSegLabel* rectLabel = NULL;
    Mat rectCutImage(const Mat& src, Rect rect);
    void showImage(Mat image);
    QImage cvMatToQImage(Mat& src);
protected:
    void keyPressEvent(QKeyEvent *event);
};
#endif
onecut.cpp
#include "onecut.h"
// Builds the generated UI and wires every control to its slot.
onecut::onecut(QWidget *parent)
    : QWidget(parent)
{
    ui.setupUi(this);

    // Scribble label: reports the updated masks after a stroke.
    connect(ui.label_show, SIGNAL(mouseMoveFinish(Mat, Mat, Mat)), this, SLOT(onMouseMoveFinish(Mat ,Mat ,Mat)));

    // Segmentation and result buttons.
    connect(ui.button_seg, SIGNAL(clicked()), this, SLOT(onSegImage()));
    connect(ui.button_ok, SIGNAL(clicked()), this, SLOT(onFinish()));

    // Crop / draw / rectangle-segmentation buttons.
    connect(ui.button_cut, SIGNAL(clicked()), this, SLOT(onCutImage()));
    connect(ui.button_confirmcut, SIGNAL(clicked()), this, SLOT(onConfirmCut()));
    connect(ui.button_draw, SIGNAL(clicked()), this, SLOT(onDrawImage()));
    connect(ui.button_rectseg, SIGNAL(clicked()), this, SLOT(onRectSeg()));

    // Slider controlling the scribble line width; configured before the
    // connection is made so the initial value does not invoke the slot.
    ui.slider_linewidth->setRange(2, 10);
    ui.slider_linewidth->setValue(10);
    connect(ui.slider_linewidth, SIGNAL(valueChanged(int)), this, SLOT(onLineWidthChanged(int)));
}
// Destructor. The body is empty: nothing is released explicitly here.
// NOTE(review): myGraph is only deleted in destroyAll() - confirm that
// callers invoke destroyAll() before the widget is destroyed.
onecut::~onecut()
{
}
void onecut::destroyAll()
{
// clear all data
fgScribbleMask.release();
bgScribbleMask.release();
inputImg.release();
showImg.release();
showEdgesImg.release();
binPerPixelImg.release();
segMask.release();
segShowImg.release();
delete myGraph;
}
// Prepares everything needed to segment `src`: copies the image, clears the
// scribble masks, computes the per-pixel colour bins and the edge variance,
// and builds the graph (one node per pixel plus one auxiliary node per used
// colour bin, n-links between neighbouring pixels, and one edge from every
// pixel to the auxiliary node of its bin).
//
// @param src  image to segment; must be a non-empty 8-bit 3-channel image
//             because pixels are read as Vec3b.
// @return 0 on success, -1 if the input is unusable.
//
// NOTE(review): the loop headers and template arguments of this function were
// missing in the listing and have been restored from the surrounding code -
// confirm against the original project.
int onecut::init(Mat src)
{
    // Validate before anything is shown or allocated.
    if (src.empty() || src.type() != CV_8UC3)
    {
        return -1;
    }

    // Initialise the images.
    inputImg = src.clone();
    this->showImage(inputImg);
    showImg = inputImg.clone();
    segShowImg = inputImg.clone();

    // Initialise the scribble masks and the result buffers.
    fgScribbleMask.create(2, inputImg.size, CV_8UC1);
    fgScribbleMask = 0;
    bgScribbleMask.create(2, inputImg.size, CV_8UC1);
    bgScribbleMask = 0;
    segMask.create(2, inputImg.size, CV_8UC1);
    segMask = 0;
    showEdgesImg.create(2, inputImg.size, CV_32FC1);
    showEdgesImg = 0;
    binPerPixelImg.create(2, inputImg.size, CV_32F);

    // numUsedBins receives the number of bins that contain at least one pixel.
    generateBinIndex(binPerPixelImg, inputImg, numBinsPerChannel, numUsedBins);
    // Variance used by the Gaussian edge weight.
    generateEdgeVariance(inputImg, showEdgesImg, varianceSquared);
    // A flat image yields a zero variance; never divide by it.
    const float safeVariance = (varianceSquared > 0) ? varianceSquared : 1.0f;

    // One node per pixel plus one auxiliary node per used bin.
    // NOTE(review): a graph allocated by an earlier call is not freed here;
    // doing so is only safe once myGraph is guaranteed to start out null.
    const int numPixels = inputImg.rows * inputImg.cols;
    myGraph = new GraphType(numPixels + numUsedBins, 12 * numPixels);
    myGraph->add_node(numPixels + numUsedBins);

    const int auxCapacity = (int)ceil(INT32_CONST*bha_slope + 0.5);

    for (int i = 0; i < inputImg.rows; i++)
    {
        for (int j = 0; j < inputImg.cols; j++)
        {
            const int currNodeId = i * inputImg.cols + j;
            const Vec3b& px = inputImg.at<Vec3b>(i, j);
            const float b = (float)px[0];
            const float g = (float)px[1];
            const float r = (float)px[2];

            for (int si = -NEIGHBORHOOD; si <= NEIGHBORHOOD; si++)
            {
                const int ni = i + si;
                // Stay inside the image.
                if (ni < 0 || ni >= inputImg.rows)
                    continue;
                for (int sj = 0; sj <= NEIGHBORHOOD; sj++)
                {
                    const int nj = j + sj;
                    if (nj < 0 || nj >= inputImg.cols)
                        continue;
                    // Skip the pixel itself and the downward edge; the latter
                    // is added as the upward edge of the pixel below.
                    if (si >= 0 && sj == 0)
                        continue;
                    // Restrict the neighbourhood to a disc.
                    if ((si*si + sj*sj) > NEIGHBORHOOD*NEIGHBORHOOD)
                        continue;

                    const int nNodeId = ni * inputImg.cols + nj;
                    const Vec3b& npx = inputImg.at<Vec3b>(ni, nj);
                    const float nb = (float)npx[0];
                    const float ng = (float)npx[1];
                    const float nr = (float)npx[2];

                    // Colour-similarity part of the boundary term.
                    float currEdgeStrength = exp(-((b - nb)*(b - nb) + (g - ng)*(g - ng) + (r - nr)*(r - nr)) / (2 * safeVariance));
                    // Spatial distance to the neighbour.
                    const float currDist = sqrt((float)si*(float)si + (float)sj*(float)sj);
                    // Blend with a constant part and attenuate by distance.
                    currEdgeStrength = ((float)EDGE_STRENGTH_WEIGHT * currEdgeStrength + (float)(1 - EDGE_STRENGTH_WEIGHT)) / currDist;
                    const int edgeCapacity = (int)ceil(INT32_CONST*currEdgeStrength + 0.5);
                    myGraph->add_edge(currNodeId, nNodeId, edgeCapacity, edgeCapacity);
                }
            }

            // Edge between the pixel and the auxiliary node of its colour bin.
            const int currBin = (int)binPerPixelImg.at<float>(i, j);
            myGraph->add_edge(currNodeId, currBin + numPixels, auxCapacity, auxCapacity);
        }
    }

    ui.label_show->setRelatedVar(bgScribbleMask, fgScribbleMask, scribbleRadius, showImg);
    return 0;
}
// Assigns every pixel of an 8-bit BGR image to a colour bin and stores the
// bin's compact index in `bin`.
//
// Each channel is quantised into `binschannel` levels. Only bins that occur
// in the image receive an index; indices are handed out consecutively in the
// order the bins are first seen.
//
// @param bin          float image of the same size as inImg; receives the index
// @param inImg        source image, read as Vec3b
// @param binschannel  number of quantisation levels per channel
// @param numUsedBins  receives the number of distinct bins found
//
// NOTE(review): the loop headers and template arguments were missing in the
// listing and have been restored - confirm against the original project.
void onecut::generateBinIndex(Mat& bin, Mat & inImg, int binschannel, int & numUsedBins)
{
    numUsedBins = 0;
    if (binschannel <= 0)
        return;

    // -1 marks a bin that has not been seen yet.
    vector<int> occupiedBinNewIdx((size_t)binschannel * binschannel * binschannel, -1);
    int newBinIdx = 0;

    for (int i = 0; i < inImg.rows; i++)
    {
        for (int j = 0; j < inImg.cols; j++)
        {
            const Vec3b& px = inImg.at<Vec3b>(i, j);
            // Integer form of floor(value / 256 * binschannel).
            const int bBin = px[0] * binschannel / 256;
            const int gBin = px[1] * binschannel / 256;
            const int rBin = px[2] * binschannel / 256;
            const int bin_index = bBin + binschannel * gBin + binschannel * binschannel * rBin;

            // First pixel in this bin: hand out the next compact index.
            if (occupiedBinNewIdx[bin_index] == -1)
            {
                occupiedBinNewIdx[bin_index] = newBinIdx;
                newBinIdx++;
            }
            bin.at<float>(i, j) = (float)occupiedBinNewIdx[bin_index];
        }
    }

    // Indices are consecutive, so the counter equals (largest index + 1) and,
    // unlike a min/max search, is also valid for an empty image.
    numUsedBins = newBinIdx;
}
// Computes the mean squared colour difference between neighbouring pixels.
//
// For every pixel the neighbours at row offsets -NEIGHBORHOOD..NEIGHBORHOOD
// and column offsets 0..NEIGHBORHOOD are visited, except the pixel itself and
// the downward neighbours (si == 1 or si == NEIGHBORHOOD with sj == 0).
//
// @param inputImg         8-bit BGR image, read as Vec3b
// @param showEdgesImg     not used by this function
// @param varianceSquared  receives the mean; set to 1 when no pair was
//                         counted or every counted pair is identical, so the
//                         value can always be used as a divisor
//
// NOTE(review): the loop headers and template arguments were missing in the
// listing and have been restored - confirm against the original project.
void onecut::generateEdgeVariance(Mat & inputImg, Mat & showEdgesImg, float & varianceSquared)
{
    (void)showEdgesImg;

    // Accumulate in double: a float sum loses precision on large images.
    double sum = 0;
    long long counter = 0;

    for (int i = 0; i < inputImg.rows; i++)
    {
        for (int j = 0; j < inputImg.cols; j++)
        {
            const Vec3b& px = inputImg.at<Vec3b>(i, j);
            const float b = (float)px[0];
            const float g = (float)px[1];
            const float r = (float)px[2];

            for (int si = -NEIGHBORHOOD; si <= NEIGHBORHOOD; si++)
            {
                // Skip rows outside the image. Testing this in the loop
                // condition would end the loop at the first miss, so the top
                // row would never contribute.
                if (si + i < 0 || si + i >= inputImg.rows)
                    continue;
                for (int sj = 0; sj <= NEIGHBORHOOD && sj + j < inputImg.cols; sj++)
                {
                    if ((si == 0 && sj == 0) ||
                        (si == 1 && sj == 0) ||
                        (si == NEIGHBORHOOD && sj == 0))
                        continue;
                    const Vec3b& npx = inputImg.at<Vec3b>(i + si, j + sj);
                    const float nb = (float)npx[0];
                    const float ng = (float)npx[1];
                    const float nr = (float)npx[2];
                    sum += (b - nb)*(b - nb) + (g - ng)*(g - ng) + (r - nr)*(r - nr);
                    counter++;
                }
            }
        }
    }

    // Average over all counted pairs; fall back to 1 instead of producing
    // 0 or NaN.
    varianceSquared = (counter > 0 && sum > 0) ? (float)(sum / (double)counter) : 1.0f;
}
// Displays `image` in the scribble label and resizes the label to fit it.
void onecut::showImage(Mat image){
    // Convert first, then swap the label contents.
    const QImage frame = cvMatToQImage(image);
    const QPixmap pixmap = QPixmap::fromImage(frame);

    ui.label_show->clear();
    ui.label_show->setPixmap(pixmap);
    ui.label_show->resize(ui.label_show->pixmap()->size());
}
// Converts an OpenCV image to a QImage that owns its pixels.
//
// 3-channel images are treated as BGR and converted to RGB888; everything
// else is exposed as an 8-bit indexed image, as before.
//
// @param src  image to convert
// @return a deep copy of the pixels; a null QImage for an empty input
QImage onecut::cvMatToQImage(Mat& src){
    if (src.empty())
        return QImage();

    if (src.channels() == 3) // BGR image
    {
        // Local buffer: it is freed on return, so the QImage must be copied.
        Mat rgb;
        cvtColor(src, rgb, CV_BGR2RGB);
        return QImage((const uchar*)(rgb.data), rgb.cols, rgb.rows,
                      static_cast<int>(rgb.step), QImage::Format_RGB888).copy();
    }

    // Gray image. step is the real row stride, which also covers padded rows
    // and sub-matrices.
    return QImage((const uchar*)(src.data), src.cols, src.rows,
                  static_cast<int>(src.step), QImage::Format_Indexed8).copy();
}
void onecut::onSegImage(){
if (rectLabel!= NULL)
if (rectLabel->isVisible())
{
if (rectLabel->startPnt.isNull())
{
QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
return;
}
if (rectLabel->endPnt.isNull())
{
QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
return;
}
if (rectLabel->startPnt.x() < 0)
{
QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
return;
}
if (rectLabel->startPnt.y() < 0)
{
QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
return;
}
if (rectLabel->endPnt.x() < 0)
{
QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
return;
}
if (rectLabel->endPnt.y() < 0)
{
QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
return;
}
if (rectLabel->startPnt.x() > rectLabel->width())
{
QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
return;
}
if (rectLabel->startPnt.y() > rectLabel->height())
{
QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
return;
}
if (rectLabel->endPnt.x() > rectLabel->width())
{
QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
return;
}
if (rectLabel->endPnt.y() > rectLabel->height())
{
QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
return;
}
if (rectLabel->startPnt.x() == rectLabel->endPnt.x())
{
QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
return;
}
if (rectLabel->startPnt.y() == rectLabel->endPnt.y())
{
QMessageBox::information(NULL, QString("Information"), QString("draw error, please draw again!"));
return;
}
for (int i = 0; i < inputImg.rows; i++)
for (int j = 0; j < rectLabel->startPnt.x(); j++)
{
bgScribbleMask.at(i, j) = 255;
}
for (int i = 0; i < rectLabel->startPnt.y(); i++)
for (int j = 0; j < inputImg.cols; j++)
{
bgScribbleMask.at(i, j) = 255;
}
for (int i = rectLabel->endPnt.y(); i < rectLabel->height(); i++)
for (int j = 0; j < inputImg.cols; j++)
{
bgScribbleMask.at(i, j) = 255;
}
for (int i = 0; i < inputImg.rows; i++)
for (int j = rectLabel->endPnt.x(); j < inputImg.cols; j++)
{
bgScribbleMask.at(i, j) = 255;
}
this->rectLabel->close();
this->rectLabel = NULL;
}
//设置硬约束
for (int i = 0; i(i, j) == 255)
myGraph->add_tweights(currNodeId, (int)ceil(INT32_CONST * HARD_CONSTRAINT_CONST + 0.5), 0);
else if (bgScribbleMask.at(i, j) == 255)
myGraph->add_tweights(currNodeId, 0, (int)ceil(INT32_CONST * HARD_CONSTRAINT_CONST + 0.5));
}
}
//执行最大流算法
myGraph->maxflow();
segMask = 0;
inputImg.copyTo(segShowImg);
fgScribbleMask = 0;
bgScribbleMask = 0;
for (int i = 0; iwhat_segment((int)i) == GraphType::SOURCE)
{
segMask.at(i / inputImg.cols, i%inputImg.cols) = 255;
}
else
{
segMask.at(i / inputImg.cols, i%inputImg.cols) = 0;
(uchar)segShowImg.at(i / inputImg.cols, i%inputImg.cols)[0] = 0;
(uchar)segShowImg.at(i / inputImg.cols, i%inputImg.cols)[1] = 0;
(uchar)segShowImg.at(i / inputImg.cols, i%inputImg.cols)[2] = 0;
}
}
this->showImage(segShowImg);
this->showImg = segShowImg.clone();
//ui.label_show->setRelatedVar(bgScribbleMask, bgScribbleMaskAll, fgScribbleMask, fgScribbleMaskAll, scribbleRadius, showImg);
ui.label_show->setRelatedVar(bgScribbleMask, fgScribbleMask, scribbleRadius, showImg);
this->lastSegState = 1;
}
// Slot called after a scribble stroke: stores the masks and the preview
// reported by the label. Before the first stroke that follows a segmentation,
// the current state is saved so that it can be restored by undo.
void onecut::onMouseMoveFinish(Mat bgScribbleMask,
                               Mat fgScribbleMask,
                               Mat showImg){
    // The parameters shadow the members of the same name, hence this->.
    const bool firstStrokeAfterSeg = (this->lastSegState == 1);
    if (firstStrokeAfterSeg)
    {
        this->lastSegState = 0;
        this->showImg_last = this->showImg.clone();
        this->bgScribbleMask_last = this->bgScribbleMask.clone();
        this->fgScribbleMask_last = this->fgScribbleMask.clone();
    }

    this->showImg = showImg.clone();
    this->fgScribbleMask = fgScribbleMask.clone();
    this->bgScribbleMask = bgScribbleMask.clone();
}
void onecut::onFinish(){
emit okClicked(this->segShowImg);
this->close();
}
void onecut::onCutImage(){
if (imageLable == NULL)
{
imageLable = new cutImageLabel();
imageLable->show();
ui.gridLayout->addWidget(imageLable, 0, 0);
}
}
// Slot for the rectangle-segmentation button: shows an overlay, sized like
// the input image, on which the user draws the rectangle.
void onecut::onRectSeg(){
    // Dispose of a previous overlay; otherwise every click stacks another
    // widget in the layout and the old one is never reachable again.
    if (rectLabel != NULL)
    {
        rectLabel->close();
        rectLabel->deleteLater();
        rectLabel = NULL;
    }

    rectLabel = new rectSegLabel();
    rectLabel->resize(inputImg.cols, inputImg.rows);
    rectLabel->show();
    ui.gridLayout->addWidget(rectLabel, 0, 0);
}
void onecut::onConfirmCut(){
if (imageLable == NULL)
return;
if (imageLable->startPnt.isNull())
return;
if (imageLable->endPnt.isNull())
return;
if (imageLable->startPnt.x() < 0)
return;
if (imageLable->startPnt.y() < 0)
return;
if (imageLable->endPnt.x() < 0)
return;
if (imageLable->endPnt.y() < 0)
return;
if (imageLable->startPnt.x() > imageLable->width())
return;
if (imageLable->startPnt.y() > imageLable->height())
return;
if (imageLable->endPnt.x() > imageLable->width())
return;
if (imageLable->endPnt.y() > imageLable->height())
return;
if (imageLable->startPnt.x() == imageLable->endPnt.x())
return;
if (imageLable->startPnt.y() == imageLable->endPnt.y())
return;
Point cvP1(imageLable->startPnt.x(), imageLable->startPnt.y());
Point cvP2(imageLable->endPnt.x(), imageLable->endPnt.y());
Rect rect(cvP1, cvP2);
this->showImg = this->rectCutImage(showImg, rect);
this->inputImg = this->showImg.clone();
this->fgScribbleMask = 0;
this->bgScribbleMask = 0;
//ui.label_show->setRelatedVar(bgScribbleMask, bgScribbleMaskAll, fgScribbleMask, fgScribbleMaskAll, scribbleRadius, showImg);
ui.label_show->setRelatedVar(bgScribbleMask, fgScribbleMask, scribbleRadius, showImg);
this->showImage(showImg);
imageLable->close();
imageLable = NULL;
}
// Slot for the draw button: leaves crop / rectangle mode by disposing of both
// overlays so that scribbling on the image is possible again.
void onecut::onDrawImage(){
    // The pointers are reset as well: onCutImage() only creates an overlay
    // while imageLable is NULL, and a closed overlay must not be used by
    // onConfirmCut().
    if (imageLable != NULL)
    {
        imageLable->close();
        imageLable->deleteLater();
        imageLable = NULL;
    }
    if (rectLabel != NULL)
    {
        rectLabel->close();
        rectLabel->deleteLater();
        rectLabel = NULL;
    }
}
// Keyboard shortcuts:
//   Ctrl+Z  restores the scribbles saved before the last stroke and
//           segments again;
//   Ctrl+R  re-initialises the segmentation of the current input image.
// Every other key is passed on to QWidget.
void onecut::keyPressEvent(QKeyEvent *event){
    const bool ctrlOnly = (event->modifiers() == Qt::ControlModifier);

    if (ctrlOnly && event->key() == Qt::Key_Z)
    {
        // Without a snapshot there is nothing to restore; installing empty
        // masks would make onSegImage() read outside of them.
        if (fgScribbleMask_last.empty() || bgScribbleMask_last.empty() || showImg_last.empty())
            return;
        this->fgScribbleMask = fgScribbleMask_last.clone();
        this->bgScribbleMask = bgScribbleMask_last.clone();
        this->showImg = showImg_last.clone();
        onSegImage();
        return;
    }

    if (ctrlOnly && event->key() == Qt::Key_R)
    {
        if (!inputImg.empty())
            init(inputImg);
        return;
    }

    QWidget::keyPressEvent(event);
}
void onecut::onLineWidthChanged(int value){
scribbleRadius = value;
ui.label_show->setRadius(value);
}
// Returns an independent copy of the part of `src` covered by `rect`.
//
// @param src   source image
// @param rect  region to extract; the part outside `src` is ignored
// @return the cropped image, or an empty Mat when `rect` does not overlap `src`
Mat onecut::rectCutImage(const Mat& src, Rect rect){
    // Clamp first: reading pixels outside the source is undefined.
    const Rect clipped = rect & Rect(0, 0, src.cols, src.rows);
    if (clipped.area() <= 0)
        return Mat();

    // clone() detaches the result from the buffer of src.
    return src(clipped).clone();
}