PS中磁性套索工具实现算法

原文:http://www.cnblogs.com/frombeijingwithlove/p/3719116.html

Photoshop中磁力套索的一种简陋实现(基于Python)

 

经常用Photoshop的人应该熟悉磁力套索(Magnetic Lasso)这个功能,就是人为引导下的抠图辅助工具。在研发领域一般不这么叫,通常管这种边缘提取的办法叫Intelligent Scissors或者Livewire。

本来是给一个图像分割项目算法评估时的Python框架,觉得有点意思,就稍稍拓展了一下,用PyQt加了个壳,在非常简陋的程度上模拟了一下的磁力套索功能。为什么简陋:1) 只实现了最基本的边缘查找。路径冷却,动态训练,鼠标位置修正都没有,更别提曲线闭合,抠图,Alpha Matting等等;2) 没考虑性能规范,只为了书写方便;3) 我对Qt了解很浅,至今不会写Signal-Slot,不知道GUI写得是否合理;4) 没调试。

PS中磁性套索工具实现算法_第1张图片

基本算法

相关算法我并没有做很深入的调研,不过相信这类应用中影响力最大的算法是来源于[1],也是本文的主要参考,基本思想是把图片看成是一个无向图,相邻像素之间就可以计算出一个局部cost,于是就转化成了最短路径问题了,接下来就是基于Dijkstra算法产生路径,就是需要提取的边缘。主要涉及的算法有两部分:1) 相邻像素的cost计算;2) 最短路径算法。

边缘检测

计算相邻像素cost的最终目的还是为了寻找边缘,所以本质还是边缘检测。基本思想是,通过各种不同手段检测边缘,并且根据检测到的强度来求加权值,作为cost。从最短路径的角度来说,就是边缘越明显的地方,cost的值越小。[1]中的建议是用三种指标求加权:1) 边缘检测算子;2) 梯度强度(Gradient Magnitude);3) 梯度方向(Gradient Direction)。本文的方法和[1]有那么一些不一样,因为懒,用了OpenCV中的Canny算子检测边缘而不是Laplacian Zero-CrossingOperator。表达式如下:

l(p,q)=wEfE(q)+wGfG(q)+wDfD(p,q)

Canny算子

基本思想是根据梯度信息,先检测出许多连通的像素,然后对于每一坨连通的像素只取其中最大值且连通的部分,将周围置零,得到初始的边缘(Edges),这个过程叫做Non-Maximum Suppression。然后用二阈值的办法将这些检测到的初始边缘分为Strong, Weak, and None三个等级,顾名思义,Strong就是很确定一定是边缘了,None就被舍弃,然后从Weak中挑选和Strong连通的作为保留的边缘,得到最后的结果,这个过程叫做Hysteresis Thresholding。这个算法太经典,更多细节一Google出来一大堆,我就不赘述了。公式如下:

fE(q)={ 0; if q is on a edge   1; if q is not on a edge

其实从权值的计算上和最大梯度有些重复,因为如果沿着最大梯度方向找出来的路径基本上就是边缘,这一项的作用我的理解主要应该是1)避免梯度都很大的区域出现离明显边缘的偏离;2) 保证提取边缘的连续性,一定程度上来讲也是保证平滑。

梯度强度

就是梯度求模而已,x和y两个方向的梯度值平方相加在开方,公式如下:

IG(q)=Ix(q)+Iy(q)−−−−−−−−−−

因为要求cost,所以反向并归一化:

fG(q)=1IG(q)max(IG)

梯度方向

这一项其实是个平滑项,会给变化剧烈的边缘赋一个比较高的cost,让提取的边缘避免噪声的影响。具体公式如下:

fD(p,q)=23π(arccos(dp(p,q))+arccos(dq(p,q)))

其中,

dp(p,q)=d⊥(p),lD(p,q)

dq(p,q)=lD(p,q),d⊥(q)

lD(p,q)={ qp; ifd⊥(p),qp≥0   pq;if d⊥(p),qp<0

d⊥(p)是取p的垂直方向,另外注意上式中符号的判断会将d⊥(p)和lD(p,q)的取值限制在π/2以内。

d⊥(p)=(py,px)

斜对角方向的cost修正

在二维图像中,相邻的像素通常按照间隔欧式距离分为两种:1) 上下左右相邻,间隔为像素边长;2) 斜对角相邻,间隔为像素边长的2√倍。在计算局部cost的时候通常要把这种距离差异的影响考虑进去,比如下面这幅图:

PS中磁性套索工具实现算法_第2张图片

如果不考虑像素位置的影响,那么查找最小cost的时候会认为左上角的cost=2最小。然而如果考虑到像素间距的影响,我们来看左上角方向,和中心的差异是6-2,做个线性插值的话,则左上角距中心单位距离上的值应该为6(62)×1/2√ =3.17>3,所以正上方的才是最小cost的正确方向。

最短路径查找

在磁力套索中,一般的用法是先单击一个点,然后移动鼠标,在鼠标和一开始单击的点之间就会出现自动贴近边缘的线,这里我们定义一开始单击的像素点为种子点(seed),而磁力套索其实在考虑上部分提到的边缘相关cost的情况下查找种子点到当前鼠标的最短路径。如下图,红色的就是种子点,而移动鼠标时,最贴近边缘的种子点和鼠标坐标的连线就会实时显示,这也是为什么磁力套索也叫Livewire。

PS中磁性套索工具实现算法_第3张图片

实现最短路径的办法很多,一般而言就是动态规划了,这里介绍的是基于Dijkstra算法的一种实现,基本思想是,给定种子点后,执行Dijkstra算法将图像的所有像素遍历,得到每个像素到种子点的最短路径。以下面这幅图为例,在一个cost矩阵中,利用Dijkstra算法遍历每一个元素后,每个元素都会指向一个相邻的元素,这样任意一个像素都能找到一条到seed的路径,比如右上角的4239对应的像素,沿着箭头到了0

PS中磁性套索工具实现算法_第4张图片

算法如下:

输入:  

  s             // 种子点  

  l(q,r)        // 计算局部cost  

        

数据结构:  

  L            // 当前待处理的像素  

  N(q)         // 当前像素相邻的像素  

  e(q)         // 标记一个像素是否已经做过相邻像素展开的Bool函数  

  g(q)         // 从s到q的总cost  

        

输出:  

  p            // 记录所有路径的map  

        

算法:  

  g(s)←0; L←s;                // 将种子点作为第一点初始化  

  whileL≠Ø:                  // 遍历尚未结束  

    q←min(L);                 // 取出最小cost的像素并从待处理像素中移除  

    e(q)←TRUE;                // 将当前像素记录为已经做过相邻像素展开  

    foreach r∈N(q) and not e(r):  

      gtemp←g(q)+l(q,r);       // 计算相邻像素的总cost  

      ifr∈L and gtemp

        r←L; { fromlist.}     // 舍弃较大cost的路径  

      elseif r∉L:  

        g(r)←gtemp;            // 记录当前找到的最小路径  

        p(r)←q;  

        L←r;                   // 加入待处理以试图寻找更短的路径

 

遍历的过程会优先经过cost最低的区域,如下图:

PS中磁性套索工具实现算法_第5张图片

所有像素对应的到种子像素的最短路径都找到后,移动鼠标时就直接画出到seed的最短路径就可以了。

Python实现

算法部分直接调用了OpenCV的Canny函数和Sobel函数(求梯度),对于RGB的处理也很简陋,直接用梯度最大的值来近似。另外因为懒,cost map和path map都直接用了字典(dict),而记录展开过的像素则直接采用了集合(set)。GUI部分因为不会用QThread所以用了Python的threading,只有图像显示交互区域和状态栏提示,左键点击设置种子点,右键结束,已经提取的边缘为绿色线,正在提取的为蓝色线。

PS中磁性套索工具实现算法_第6张图片

from__future__ import division  
importcv2  
importnumpy as np  
        
SQRT_0_5= 0.70710678118654757  
        
classLivewire():  
    """  
    Asimple livewire implementation for verification using
        1.Canny edge detector + gradient magnitude + gradient direction  
        2.Dijkstra algorithm  
    """  
            
    def__init__(self, image):  
        self.image= image  
        self.x_lim= image.shape[0]  
        self.y_lim= image.shape[1]  
        #The values in cost matrix ranges from 0~1  
        self.cost_edges= 1 - cv2.Canny(image, 85, 170)/255.0  
        self.grad_x,self.grad_y, self.grad_mag = self._get_grad(image)  
        self.cost_grad_mag= 1 - self.grad_mag/np.max(self.grad_mag)  
        #Weight for (Canny edges, gradient magnitude, gradient direction)  
        self.weight= (0.425, 0.425, 0.15)  
                
        self.n_pixs= self.x_lim * self.y_lim  
        self.n_processed= 0  
            
    @classmethod  
    def_get_grad(cls, image):  
        """  
        Returnthe gradient magnitude of the image using Sobel operator
        """  
        rgb= True if len(image.shape) > 2 else False  
        grad_x= cv2.Sobel(image, cv2.CV_64F, 1, 0, ksize=3)  
        grad_y= cv2.Sobel(image, cv2.CV_64F, 0, 1, ksize=3)  
        ifrgb:  
            #A very rough approximation for quick verification...  
            grad_x= np.max(grad_x, axis=2)  
            grad_y= np.max(grad_y, axis=2)  
                    
        grad_mag= np.sqrt(grad_x**2+grad_y**2)  
        grad_x/= grad_mag  
        grad_y/= grad_mag  
                
        returngrad_x, grad_y, grad_mag  
            
    def_get_neighbors(self, p):  
        """  
        Return8 neighbors around the pixel p  
        """  
        x,y = p  
        x0= 0 if x == 0 else x - 1  
        x1= self.x_lim if x == self.x_lim - 1 else x + 2  
        y0= 0 if y == 0 else y - 1  
        y1= self.y_lim if y == self.y_lim - 1 else y + 2  
                
        return[(x, y) for x in xrange(x0, x1) for y in xrange(y0, y1) if (x, y) !=p]  
            
    def_get_grad_direction_cost(self, p, q):  
        """  
        Calculatethe gradient changes refer to the link direction  
        """  
        dp= (self.grad_y[p[0]][p[1]], -self.grad_x[p[0]][p[1]])  
        dq= (self.grad_y[q[0]][q[1]], -self.grad_x[q[0]][q[1]])  
                
        l= np.array([q[0]-p[0], q[1]-p[1]], np.float)  
        if0 not in l:  
            l*= SQRT_0_5  
                
        dp_l= np.dot(dp, l)  
        l_dq= np.dot(l, dq)  
        ifdp_l < 0:  
            dp_l= -dp_l  
            l_dq= -l_dq  
                
        #2/3pi * ...  
        return0.212206590789 * (np.arccos(dp_l)+np.arccos(l_dq))  
            
    def_local_cost(self, p, q):  
        """  
        1.Calculate the Canny edges & gradient magnitude cost taking into accountEuclidean distance  
        2.Combine with gradient direction  
        Assumption:p & q are neighbors  
        """  
        diagnol= q[0] == p[0] or q[1] == p[1]  
                
        #c0, c1 and c2 are costs from Canny operator, gradient magnitude and gradientdirection respectively  
        ifdiagnol:  
            c0=self.cost_edges[p[0]][p[1]]-SQRT_0_5*(self.cost_edges[p[0]][p[1]]-self.cost_edges[q[0]][q[1]])  
            c1=self.cost_grad_mag[p[0]][p[1]]-SQRT_0_5*(self.cost_grad_mag[p[0]][p[1]]-self.cost_grad_mag[q[0]][q[1]])  
            c2= SQRT_0_5 * self._get_grad_direction_cost(p, q)  
        else:  
            c0= self.cost_edges[q[0]][q[1]]  
            c1= self.cost_grad_mag[q[0]][q[1]]  
            c2= self._get_grad_direction_cost(p, q)  
                
        ifnp.isnan(c2):  
            c2= 0.0  
                
        w0,w1, w2 = self.weight  
        cost_pq= w0*c0 + w1*c1 + w2*c2  
                
        returncost_pq * cost_pq  
        
    defget_path_matrix(self, seed):  
        """  
        Getthe back tracking matrix of the whole image from the cost matrix  
        """  
        neighbors= []          # 8 neighbors of thepixel being processed  
        processed= set()       # Processed point  
        cost= {seed: 0.0}      # Accumulated cost, initializedwith seed to itself  
        paths= {}  
        
        self.n_processed= 0  
                
        whilecost:  
            #Expand the minimum cost point  
            p= min(cost, key=cost.get)  
            neighbors= self._get_neighbors(p)  
            processed.add(p)  
        
            #Record accumulated costs and back tracking point for newly expanded points  
            forq in [x for x in neighbors if x not in processed]:  
                temp_cost= cost[p] + self._local_cost(p, q)  
                ifq in cost:  
                    iftemp_cost < cost[q]:  
                        cost.pop(q)  
                else:  
                    cost[q]= temp_cost  
                    processed.add(q)  
                    paths[q]= p  
                    
            #Pop traversed points  
            cost.pop(p)  
                    
            self.n_processed+= 1  
                
        returnpaths  
        
livewire.py
from__future__ import division  
importtime  
importcv2  
fromPyQt4 import QtGui, QtCore  
fromthreading import Thread  
fromlivewire import Livewire  
        
classImageWin(QtGui.QWidget):  
    def__init__(self):  
        super(ImageWin,self).__init__()  
        self.setupUi()  
        self.active= False
        self.seed_enabled= True
        self.seed= None
        self.path_map= {}  
        self.path= []  
                
    defsetupUi(self):  
        self.hbox= QtGui.QVBoxLayout(self)  
                
        #Load and initialize image  
        self.image_path= ''  
        whileself.image_path == '':  
            self.image_path= QtGui.QFileDialog.getOpenFileName(self, '', '', '(*.bmp *.jpg*.png)')  
        self.image= QtGui.QPixmap(self.image_path)  
        self.cv2_image= cv2.imread(str(self.image_path))  
        self.lw= Livewire(self.cv2_image)  
        self.w,self.h = self.image.width(), self.image.height()  
                
        self.canvas= QtGui.QLabel(self)  
        self.canvas.setMouseTracking(True)  
        self.canvas.setPixmap(self.image)  
                
        self.status_bar= QtGui.QStatusBar(self)  
        self.status_bar.showMessage('Leftclick to set a seed')  
                
        self.hbox.addWidget(self.canvas)  
        self.hbox.addWidget(self.status_bar)  
        self.setLayout(self.hbox)  
            
    defmousePressEvent(self,event):              
        ifself.seed_enabled:  
            pos= event.pos()  
            x,y = pos.x()-self.canvas.x(), pos.y()-self.canvas.y()  
                    
            ifx < 0:  
                x= 0
            ifx >= self.w:  
                x= self.w - 1
            ify < 0:  
                y= 0
            ify >= self.h:  
                y= self.h - 1
        
            #Get the mouse cursor position  
            p= y, x  
            seed= self.seed  
                    
            #Export bitmap  
            ifevent.buttons() == QtCore.Qt.MidButton:  
                filepath= QtGui.QFileDialog.getSaveFileName(self, 'Save image audio to', '', '*.bmpn*.jpgn*.png')  
                image= self.image.copy()  
                        
                draw= QtGui.QPainter()  
                draw.begin(image)  
                draw.setPen(QtCore.Qt.blue)  
                ifself.path_map:  
                    whilep != seed:  
                        draw.drawPoint(p[1],p[0])  
                        forq in self.lw._get_neighbors(p):  
                            draw.drawPoint(q[1],q[0])  
                        p= self.path_map[p]  
                ifself.path:  
                    draw.setPen(QtCore.Qt.green)  
                    forp in self.path:  
                        draw.drawPoint(p[1],p[0])  
                        forq in self.lw._get_neighbors(p):  
                            draw.drawPoint(q[1],q[0])  
                draw.end()  
                        
                image.save(filepath,quality=100)  
                    
            else:  
                self.seed= p  
                        
                ifself.path_map:  
                    whilep != seed:  
                        p= self.path_map[p]  
                        self.path.append(p)  
                        
                #Calculate path map  
                ifevent.buttons() == QtCore.Qt.LeftButton:  
                    Thread(target=self._cal_path_matrix).start()  
                    Thread(target=self._update_path_map_progress).start()  
                        
                #Finish current task and reset  
                elifevent.buttons() == QtCore.Qt.RightButton:  
                    self.path_map= {}  
                    self.status_bar.showMessage('Leftclick to set a seed')  
                    self.active= False
            
    defmouseMoveEvent(self, event):  
        ifself.active and event.buttons() == QtCore.Qt.NoButton:  
            pos= event.pos()  
            x,y = pos.x()-self.canvas.x(), pos.y()-self.canvas.y()  
        
            ifx < 0 or x >= self.w or y < 0 or y >= self.h:  
                pass
            else:  
                #Draw livewire  
                p= y, x  
                path= []  
                whilep != self.seed:  
                    p= self.path_map[p]  
                    path.append(p)  
                        
                image= self.image.copy()  
                draw= QtGui.QPainter()  
                draw.begin(image)  
                draw.setPen(QtCore.Qt.blue)  
                forp in path:  
                    draw.drawPoint(p[1],p[0])  
                ifself.path:  
                    draw.setPen(QtCore.Qt.green)  
                    forp in self.path:  
                        draw.drawPoint(p[1],p[0])  
                draw.end()  
                self.canvas.setPixmap(image)  
            
    def_cal_path_matrix(self):  
        self.seed_enabled= False
        self.active= False
        self.status_bar.showMessage('Calculatingpath map...')  
        path_matrix= self.lw.get_path_matrix(self.seed)  
        self.status_bar.showMessage(r'Left:new seed / Right: finish')  
        self.seed_enabled= True
        self.active= True
                
        self.path_map= path_matrix  
            
    def_update_path_map_progress(self):  
        whilenot self.seed_enabled:  
            time.sleep(0.1)  
            message= 'Calculating path map... {:.1f}%'.format(self.lw.n_processed/self.lw.n_pixs*100.0)  
            self.status_bar.showMessage(message)  
        self.status_bar.showMessage(r'Left:new seed / Right: finish')  
        
gui.py
importsys  
fromPyQt4 import QtGui  
from guiimport ImageWin  
        
defmain():  
    app= QtGui.QApplication(sys.argv)  
    window= ImageWin()  
    window.setMouseTracking(True)  
    window.setWindowTitle('LivewireDemo')  
    window.show()  
    window.setFixedSize(window.size())  
    sys.exit(app.exec_())  
        
if__name__ == '__main__':  
    main()      
        
main.py

蛋疼地上传到了Github(传送门),欢迎fork。

效率的改进

因为这个代码的原型只是为了用C++开发之前的Python评估和验证,所以完全没考虑效率,执行速度是完全不行的,基本上400x400的图片就不能忍了……至于基于Python版本的效率提升我没有仔细想过,只是大概看来有这么几个比较明显的地方:

1) 取出当前最小cost像素操作

p = min(cost, key=cost.get)

这个虽然写起来很爽但显然是不行的,至少得用个min heap什么的。因为我是用dict同时表示待处理像素和cost了,也懒得想一下怎么和Python的heapq结合起来,所以直接用了粗暴省事的min()。

2) 梯度方向的计算

三角函数的计算应该是尽量避免的,另外原文可能是为了将值域扩展到>π所以把q-p也用上了,其实这一项本来权重就小,那怕直接用两个像素各自的梯度方向向量做点积然后归一化一下结果也是还行的。即使要用arccos,也可以考虑写个look-up table近似。当然我最后想说的是个人觉得其实这项真没那么必要,直接自适应spilne或者那怕三点均值平滑去噪效果就不错了。

3) 计算相邻像素的位置

如果两个像素相邻,则他们各自周围的8个相邻像素也会重合。的我的办法比较原始,可以考率不用模块化直接计算。

4) 替换部分数据结构

比如path map其实本质是给出每个像素在最短路径上的上一个像素,是个矩阵。其实可以考虑用线性的数据结构代替,不过如果真这样做一般来说都是在C/C++代码里了。

5)numpy

我印象中对numpy的调用顺序也会影响到效率,连续调用numpy的内置方法似乎会带来效率的整体提升,不过话还是说回来,实际应用中如果到了这一步,应该也属于C/C++代码范畴了。

6) 算法层面的改进

这块没有深入研究,第一感觉是实际应用中没必要一上来就计算整幅图像,可以根据seed位置做一些区块划分,鼠标本身也会留下轨迹,也或许可以考虑只在鼠标轨迹方向进行启发式搜索。另外计算路径的时候也许可以考虑借鉴有点类似于Image Pyramid的思想,没必要一上来就对全分辨率下的路径进行查找。由于后来做的项目没有采用这个算法,所以我也没有继续研究,虽然挺好奇的,其实有好多现成的代码,比如GIMP,不过没有精力去看了。

更多的改进

虽然都没做,大概介绍一下,都是考虑了实用性的改进。

路径冷却(Path Cooling)

用过Photoshop和GIMP磁力套索的人都知道,即使鼠标不点击图片,在移动过程中也会自动生成一些将抠图轨迹固定住的点,这些点其实就是新的种子点,而这种使用过程中自动生成新的种子点的方法叫Path cooling。这个方法的基本思路如下:随着鼠标移动过程中如果一定时间内一段路径都保持固定不变,那么就把这段路径中离种子最远的点设置为新的种子,其实背后隐藏的还是动态规划的思想,贝尔曼最优。这个名字也是比较形象的,路径冷却。

动态训练(InteractiveDynamic Training)

PS中磁性套索工具实现算法_第7张图片

单纯的最短路径查找在使用的时候常常出现找到的边缘不是想要的边缘的问题,比如上图,绿色的线是上一段提取的边缘,蓝色的是当前正在提取的边缘。左图中,镜子外面Lena的帽子边缘是我们想要提取的,然而由于镜子里的Lena的帽子边缘的cost更低,所以实际提取出的蓝色线段如右图中贴到右边了。所以InteractiveDynamic Training的思想是,认为绿色的线段是正确提取的边缘,然后利用绿色线段作为训练数据来给当前提取边缘的cost函数附加一个修正值。

[1]中采用的方法是统计前一段边缘上点的梯度强度的直方图,然后按照梯度出现频率给当前图中的像素加权。举例来说如果绿色线段中的所有像素对应的梯度强度都是在50到100之间的话,那么可以将50到100以10为单位分为5个bin,统计每个bin里的出现频率,也就是直方图,然后对当前检测到的梯度强度,做个线性加权。比方说50~60区间内对应的像素最多有10个,那么把10作为最大值,并且对当前检测到的梯度强度处于50~60之间的像素均乘上系数1.0;如果训练数据中70~80之间有5个,那么cost加权系数为5/10=0.5,则对所有当前检测到的梯度强度处于70~80之间的像素均乘上系数0.5;如果训练数据中100以上没有,所以cost附加为0/10=0,则加权系数为0,这样即使检测到更强的边缘也不会偏离前一段边缘了。这是基本思想,当然实际的实现没有这么简单,除了边缘上的像素还要考虑垂直边缘上左边和右边的两个像素点,这样保证了边缘的pattern。另外随着鼠标越来越远离训练边缘,检测到的边缘的pattern可能会出现不一样,所以Training可能会起反作用,所以这种Training的作用范围也需要考虑到鼠标离种子点的距离,最后还要有一些平滑去噪的处理,具体都在[1]里有讲到,挺繁琐的(那会好像还没有SIFT),不详述了。

种子点位置的修正(Cursor Snap)

虽然这个算法可以自动找出种子点和鼠标之间最贴近边缘的路径,不过,人的手,常常抖,所以种子点未必能很好地设置到边缘上。所以可以在用户设置完种子点位置之后,自动在其坐标周围小范围内,比如7x7的区域内搜索cost最低的像素,作为真正的种子点位置,这个过程叫做Cursor snap。

 


你可能感兴趣的:(算法)