【CNN基础】常见的loss函数及其实现(一)——TV Loss

本文主要关注潜在有效的,值得炼丹的Loss函数:
TV loss

Total Variation loss

在图像复原过程中,图像上的一点点噪声可能就会对复原的结果产生非常大的影响,因为很多复原算法都会放大噪声。这时候我们就需要在最优化问题的模型中添加一些正则项来保持图像的光滑性,TV loss是常用的一种正则项(注意是正则项,配合其他loss一起使用,约束噪声)。图片中相邻像素值的差异可以通过降低TV loss来一定程度上解决。比如降噪,对抗checkerboard等等。

1. 初始定义

       ~~~~~~       Rudin等人(Rudin1990)观察到,受噪声污染的图像的总变分比无噪图像的总变分明显的大。 那么最小化TV理论上就可以最小化噪声。图片中相邻像素值的差异可以通过降低TV loss来一定程度上解决。比如降噪,对抗checkerboard等等。总变分定义为梯度幅值的积分:
J T 0 ( u ) = ∫ Ω u ∣ ▽ u ∣ d x d y = ∫ D u u x 2 + u y 2 d x d y J_{T_0}(u)=\int\limits_{\Omega_{u}}|\triangledown_u|dxdy=\int\limits_{D_u}\sqrt{u_x^2+u_y^2}dxdy JT0(u)=Ωuudxdy=Duux2+uy2 dxdy
       ~~~~~~       其中, u x = ∂ u ∂ x u_x=\frac {\partial u}{\partial x} ux=xu u y = ∂ u ∂ y u_y=\frac {\partial u}{\partial y} uy=yu D u D_u Du是图像的支持域。限制总变分就会限制噪声。

2. 扩展定义

       ~~~~~~       带阶数的TV loss 定义如下:
ℜ V β ( f ) = ∫ Ω ( ∂ f ∂ u ( u , v ) 2 + ∂ f ∂ v ( u , v ) 2 ) β 2 \Re _{V^{\beta }}\left ( f\right )=\int_{\Omega }\left ( \frac{\partial f}{\partial u}\left ( u,v \right )^{2}+\frac{\partial f}{\partial v}\left ( u,v \right )^{2} \right )^{\frac{\beta }{2}} Vβ(f)=Ω(uf(u,v)2+vf(u,v)2)2β
       ~~~~~~       但是在图像中,连续域的积分就变成了像素离散域中求和,所以可以这么算:
ℜ V β ( x ) = ∑ i , j ( ( x i , j − 1 − x i , j ) 2 + ( x i + 1 , j − x i , j ) 2 ) β 2 \Re _{V^{\beta }}\left ( x\right )=\sum_{i,j}\left (\left ( x_{i,j-1}-x_{i,j}\right )^{2}+\left ( x_{i+1,j}-x_{i,j} \right )^{2} \right )^{\frac{\beta }{2}} Vβ(x)=i,j((xi,j1xi,j)2+(xi+1,jxi,j)2)2β
       ~~~~~~       即:求每一个像素和横向下一个像素的差的平方,加上纵向下一个像素的差的平方。然后开β/2次根。

3. 效果

       ~~~~~~       The total variation (TV) loss encourages spatial smoothness in the generated image.(总变差(TV)损失促进了生成的图像中的空间平滑性。)根据论文Nonlinear total variation based noise removal algorithms的描述,当β < 1时,会出现下图左侧的小点点的artifact。当β > 1时,图像中小点点会被消除,但是代价就是图像的清晰度。效果图如下:
【CNN基础】常见的loss函数及其实现(一)——TV Loss_第1张图片

4. 代码实现

这两种实现都默认 β = 2 \beta=2 β=2,不支持 β \beta β的调整。

4.1 pytorch

import torch
import torch.nn as nn
from torch.autograd import Variable

class TVLoss(nn.Module):
    def __init__(self,TVLoss_weight=1):
        super(TVLoss,self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self,x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self._tensor_size(x[:,:,1:,:])
        count_w = self._tensor_size(x[:,:,:,1:])
        h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
        w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
        return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size

    def _tensor_size(self,t):
        return t.size()[1]*t.size()[2]*t.size()[3]

def main():
    # x = Variable(torch.FloatTensor([[[1,2],[2,3]],[[1,2],[2,3]]]).view(1,2,2,2), requires_grad=True)
    # x = Variable(torch.FloatTensor([[[3,1],[4,3]],[[3,1],[4,3]]]).view(1,2,2,2), requires_grad=True)
    # x = Variable(torch.FloatTensor([[[1,1,1], [2,2,2],[3,3,3]],[[1,1,1], [2,2,2],[3,3,3]]]).view(1, 2, 3, 3), requires_grad=True)
    x = Variable(torch.FloatTensor([[[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]]]).view(1, 2, 3, 3),requires_grad=True)
    addition = TVLoss()
    z = addition(x)
    print x
    print z.data
    z.backward()
    print x.grad
    
if __name__ == '__main__':

4.2 tensorflow

def tv_loss(X, weight):
    with tf.variable_scope('tv_loss'):
        return weight * tf.reduce_sum(tf.image.total_variation(X))

4. 参考资料

       ~~~~~~       本章节参考以下资料,作一定的整理,方便他人阅读与研究:

  1. wiki上关于TVLoss的描述:https://en.wikipedia.org/wiki/Total_variation_denoising,https://en.wikipedia.org/wiki/Total_variation_distance_of_probability_measures
  2. CSDN博客《Total Variation》
  3. 视频教程Denoising, deconvolution and computed tomography using total variation penalty
  4. 实验——基于pytorch的噪声估计网络
  5. pytorch的TV loss实现

你可能感兴趣的:(CNN基础)