DBNet详解

文章目录

  • 创新点
  • 算法的整体架构
  • 自适应阈值(Adaptive threshhold)
  • 二值化
    • 标准二值化
    • 可微二值(differentiable Binarization)
    • 直观展示
  • 可形变卷积(Deformable convolution)
  • 标签的生成
    • PSENet标签生成
    • DBNet标签生成
  • 损失函数
  • 后处理
  • 代码阅读
    • 数据预处理
      • 入口
      • AugmentDetectionData(数据增强类)
      • RandomCropData(数据裁剪类)
      • MakeICDARData(数据重新组织类)
      • MakeSegDetectionData(生成概率图和对应mask类)
      • MakeBorderMap(生成阈值图和对应Mask类)
      • NormalizeImage
      • FilterKeys
    • 模型结构
      • 骨干网络和FPN
      • head部分(decoder)
        • binary
        • thresh
        • step_function
    • 损失函数
      • binary loss
      • thresh loss
      • thresh_binary loss
    • 逻辑推理
  • 补充
    • 语义分割中的loss function
      • cross entropy loss
      • weighted loss
      • focal loss
      • dice soft loss
        • Dice系数计算
        • Dice loss
        • 梯度分析
        • 总结
      • soft IOU loss
    • 总结
        • 总结
      • soft IOU loss
    • 总结

创新点

​ 本文的最大创新点。在基于分割的文本检测网络中,最终的二值化map都是使用的固定阈值来获取,并且阈值不同对性能影响较大。本文中,对每一个像素点进行自适应二值化,二值化阈值由网络学习得到,彻底将二值化这一步骤加入到网络里一起训练,这样最终的输出图对于阈值就会非常鲁棒。

和常规基于语义分割算法的区别是多了一条threshold map分支,该分支的主要目的是和分割图联合得到更接近二值化的二值图,属于辅助分支。其余操作就没啥了。整个核心知识就这些了。

算法的整体架构

DBNet详解_第1张图片

  • 首先,图像输入特征提取主干,提取特征;
  • 其次,特征金字塔上采样到相同的尺寸,并进行特征级联得到特征F;
  • 然后,特征F用于预测概率图(probability map P)和阈值图(threshold map T)
  • 最后,通过P和F计算近似二值图(approximate binary map B)

在训练期间对P,T,B进行监督训练,P和B是用的相同的监督信号(label)。在推理时,只需要P或B就可以得到文本框。

网络输出:

1.probability map, w*h*1 , 代表像素点是文本的概率

2.threshhold map, w*h*1, 每个像素点的阈值

3.binary map, w*h*1, 由1,2计算得到,计算公式为DB公式

自适应阈值(Adaptive threshhold)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-V5RaeceH-1610966579215)(C:\F\notebook\DB\20200922201346491.png)]

文中指出传统的文本检测算法主要是图中蓝色线,处理流程如下:

  • 首先,通过设置一个固定阈值将分割网络训练得到的概率图(segmentation map)转化为二值图(binarization map);
  • 然后,使用一些启发式技术(例如像素聚类)将像素分组为文本实例。

而DBNet使用红色线,思路:

通过网络去预测图片每个位置处的阈值,而不是采用一个固定的值,这样就可以很好将背景与前景分离出来,但是这样的操作会给训练带来梯度不可微的情况,对此对于二值化提出了一个叫做Differentiable Binarization来解决不可微的问题。

​ 阈值图(threshhold map)使用流程如图2所示,使用阈值map和不使用阈值map的效果对比如图6所示,从图6©中可以看到,即使没用带监督的阈值map,阈值map也会突出显示文本边界区域,这说明边界型阈值map对最终结果是有利的。所以,本文在阈值map上选择监督训练,已达到更好的表现

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4gDERWPU-1610966579221)(C:\F\notebook\DB\20200922201612829.png)]

二值化

标准二值化

​ 一般使用分割网络(segmentation network)产生的概率图(probability map P),将P转化为一个二值图P,当像素为1的时候,认定其为有效的文本区域,同时二值处理过程:

img

i和j代表了坐标点的坐标,t是预定义的阈值;

可微二值(differentiable Binarization)

公式1是不可微的,所以没法直接用于训练,本文提出可微的二值化函数,如下(其实就是一个带系数的sigmoid):

img

img就是近似二值图;T代表从网络中学习到的自适应阈值图;k是膨胀因子(经验性设置k=50).

​ 这个近似的二值化函数的表现类似于标准的二值化函数,如图4所表示,但是因为可微,所以可以直接用于网络训练,基于自适应阈值可微二值化不仅可以帮助区分文本区域和背景,而且可以将连接紧密的文本实例分离出来。

DBNet详解_第2张图片

​ 为了说明DB模块的引入对于联合训练的优势,作者对该函数进行梯度分析,也就是对approximate
binary map进行求导分析,由于是sigmod输出,故假设Loss是bce,对于label为0或者1的位置,其Loss函数可以重写为:
DBNet详解_第3张图片
x表示probability map-threshold map,最后一层关于x的梯度很容易计算:
DBNet详解_第4张图片

​ 看上图右边,(b)图是当label=1,x预测值从-1到1的梯度,可以发现,当k=50时候梯度远远大于k=1,错误的区域梯度更大,对于label=0的情况分析也是一样的。故:
(1) 通过增加参数K,就可以达到增大梯度的目的,加快收敛
(2) 在预测错误位置,梯度也是显著增加

总之通过引入DB模块,通过参数K可以达到增加梯度幅值,更加有利优化,可以使得三个输出图优化更好,最终分割结果会优异。而DB模块本身就是带参数的sigmod函数,实现如下:
image.png-21.3kB

直观展示

p可以理解,就是有文字的区域有值0.9以上,没有文字区域黑的为0 .

T是一个只有文字边界才有值的,其他地方为0 .

DBNet详解_第5张图片

​ 分别是原图,gt图,threshold map图。 这里再说下threshold map图,非文字边界处都是灰色的,这是因为统一加了0.3,所有最小值是0.3.

这里其实还看不清,我们把src+gt+threshold map看看。
DBNet详解_第6张图片

可以看到:

  • p的ground truth是标注缩水之后
  • T的ground truth是文字块边缘分别向内向外收缩和扩张
  • p与T是公式里面的那两个变量。

再看这个公式与曲线图:

P和T我们就用ground truth带入来理解: DBNet详解_第7张图片

​ P网络学的文字块内部, T网络学的文字边缘,两者计算得到B。 B的ground truth也是标注缩水之后,和p用的同一个。 在实际操作中,作者把除了文字块边缘的区域置为0.3.应该就是为了当在非文字区域, P=0,T=0.3,x=p-T<0这样拉到负半轴更有利于区分。

可形变卷积(Deformable convolution)

​ 可变形卷积可以提供模型一个灵活的感受野,这对于不同纵横比的文本很有利,本文应用可变形卷积,使用3×3卷积核在ResNet-18或者ResNet-50的conv3,conv4,conv5层。

标签的生成

概率图的标签产成法类似PSENet

PSENet标签生成

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8eRCYRHv-1610966579234)(C:\F\notebook\DB\20200923193744225.png)]

​ 网络输出多个分割结果(S1,Sn),因此训练时需要有多个GY与其匹配,在本文中,通过收缩原始标签就可以简单高效的生成不同尺度的GT,如图5所示,(b)代表原始的标注结果,也表示最大的分割标签mask,即Sn,利用Vatti裁剪算法获取其他尺度的Mask,如图5(a),将原始多边形pn 缩小di 像素到 pi ,收缩后的pi 转换成0/1的二值mask作为GT,用G1,G2,,,,Gn分别代表不同尺度的GT,用数学方式表示的话,尺度比例为ri

di 的计算方式为:
d i = A r e a ( P n ) ∗ ( 1 − r i 2 ) / P e r i m e t e r ( p n ) d_i=Area(P_n)*(1-r_i^2)/Perimeter(p_n) di=Area(Pn)(1ri2)/Perimeter(pn)
Area(·) 是计算多边形面积的函数, Perimeter(·)是计算多边形周长的函数,生成Gi时的尺度比例ri计算公式为:
r i = 1 − ( 1 − m ) ∗ ( n − i ) / ( n − 1 ) r_i=1-(1-m)*(n-i)/(n-1) ri=1(1m)(ni)/(n1)

m代表最小的尺度比例,取值范围是(0,1],使用上式,通过m和n两个超参数可以计算出r1,r2,…rn,他们随着m变现线性增加到最大值1.

DBNet标签生成

DBNet详解_第8张图片

给定一张图片,文本区域标注的多边形可以描述为:
G = { S k } k = 1 n G=\{S_k\}_{k=1}^{n} G={ Sk}k=1n
n是每隔文本框的标注点总数,在不同数据中可能不同,然后使用vatti裁剪算法,将正样例区域产生通过收缩polygon从G到Gs,补偿公式计算

img

D:offset;L:周长;A:面积;r:收缩比例,设置为0.4;

  1. probability map, 按照pse的方式制作即可,收缩比例设置为0.4
  2. threshold map, 将文本框分别向内向外收缩和扩张d(根据第一步收缩时计算得到)个像素,然后计算收缩框和扩张框之间差集部分里每个像素点到原始图像边界的归一化距离,此处有个问题,两个邻近的文本框,在扩张后会重叠,这种情况下重叠部分像素点的距离使用哪个文本框的?

DBNet详解_第9张图片

损失函数

损失函数为概率map的loss、二值map的loss和阈值map的loss之和。

img

Ls 是概率map的loss,Lb 是二值map的loss,均使用二值交叉熵loss(BCE),为了解决正负样本不均衡问题,使用hard negative mining, α和β分别设置为1.0和10 .

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-udDAMvqX-1610966579237)(C:\F\notebook\DB\2020092220283134.png)]

Sl 设计样本集,其中正阳样本和负样本比例是1:3

Lt计算方式为扩展文本多边形Gd内预测结果和标签之间的L1距离之和:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pGt5zzIG-1610966579239)(C:\F\notebook\DB\20200922203558285.png)]

Rd是在膨胀Gd内像素的索引,y*是阈值map的标签

后处理

(由于threshold map的存在,probability map的边界可以学习的很好,因此可以直接按照收缩的方式(Vatti clipping algorithm)扩张回去 )

在推理时可以采用概率图或近似二值图来生成文本框,为了方便作者选择了概率图,具体步骤如下:

1、使用固定阈值0.2将概率图做二值化得到二值化图;

2、由二值化图得到收缩文字区域;

3、将收缩文字区域按Vatti clipping算法的偏移系数D’通过膨胀再扩展回来。

img

D‘就是扩展补偿,A’是收缩多边形的面积,L‘就是收缩多边形的周长,r’作者设置的是1.5;

注意r‘的值在DBNet工程中不是1.5,而在我自己的数据集上,参数设置为1.3较合适,大家训练的时候可以根据自己模型效果进行调整

文中说明DB算法的主要优势有以下4点:

  • 在五个基准数据集上有良好的表现,其中包括水平、多个方向、弯曲的文本。
  • 比之前的方法要快很多,因为DB可以提供健壮的二值化图,从而大大简化了后处理过程。
  • 使用轻量级的backbone(ResNet18)也有很好的表现。
  • DB模块在推理过程中可以去除,因此不占用额外的内存和时间的消耗。

参考:

论文链接:https://arxiv.org/pdf/1911.08947.pdf

工程链接:https://github.com/MhLiao/DB

​ https://github.com/WenmuZhou/DBNet.pytorch

  • https://blog.csdn.net/qq_22764813/article/details/107785388
  • https://blog.csdn.net/qq_39707285/article/details/108739010
  • https://zhuanlan.zhihu.com/p/94677957
  • https://mp.weixin.qq.com/s/ehbROyE-grp_F3T3YBX9CA

代码阅读

数据预处理

入口

在data/image_dataset.py,数据预处理逻辑非常简单,就是读取图片和gt标注,解析出每张图片poly标注,包括多边形标注、字符内容以及是否是忽略文本,忽略文本一般是比较模糊和小的文本。

具体可以在getitem方法里面插入:

ImageDataset.__getitem__():
	data_process(data)

预处理配置

 processes:
        - class: AugmentDetectionData
          augmenter_args:
              - ['Fliplr', 0.5]
              - {
     'cls': 'Affine', 'rotate': [-10, 10]}
              - ['Resize', [0.5, 3.0]]
          only_resize: False
          keep_ratio: False
        - class: RandomCropData
          size: [640, 640]
          max_tries: 10
        - class: MakeICDARData
        - class: MakeSegDetectionData
        - class: MakeBorderMap
        - class: NormalizeImage
        - class: FilterKeys
          superfluous: ['polygons', 'filename', 'shape', 'ignore_tags', 'is_training']

预处理流程:

AugmentDetectionData(数据增强类)

DB/data/processes/augment_data.py

​ 其目的就是对图片和poly标注进行数据增强,包括翻转、旋转和缩放三个,参数如配置所示。本文采用的增强库是imgaug。可以看出本文训练阶段对数据是不保存比例的resize,然后再进行三种增强。

由于icdar数据,文本区域占比都是非常小的,故不能用直接resize到指定输入大小的数据增强操作,而是使用后续的randcrop操作比较科学。但是如果自己项目的数据文本区域比较大,则可能没必要采用RandomCropData这么复杂的数据增强操作,直接resize算了。

RandomCropData(数据裁剪类)

DB/data/processes/random_crop_data.py

因为数据裁剪涉及到比较复杂的多变形标注后处理,所以单独列出来 。

​ 其目的是对图片进行裁剪到指定的[640, 640]。由于斜框的特点,裁剪增强没那么容易做,本文采用的裁剪策略非常简单: 遍历每一个多边形标注,只要裁剪后有至少有一个poly还在裁剪框内,则认为该次裁剪有效。这个策略主要可以保证一张图片中至少有一个gt,且实现比较简单。

其具体流程是:

  1. 将每张图片的所有poly数据进行水平和垂直方向投影,有标注的地方是1,其余地方是0
  2. 找出没有标注即0值的水平和垂直坐标h_axis和w_axis
  3. 如果全部是1,则表示poly横跨了整图,则直接返回,无法裁剪
  4. 对水平和垂直坐标进行连续0区域分离,其实就是把所有连续0坐标区域切割处理变成List输出h_regions、w_regions
  5. 以w_regions为例,长度为n,先从n个区域随机选择2个区域,然后在这两个区域内部随机选择两个点,构成x方向最大最小坐标,h_regions也是一样处理,此时就得到了xmin, ymin, xmax - xmin, ymax - ymin值
  6. 判断裁剪区域是否过小;以及判断是否裁剪框内部是否至少有一个标注在内部,没有被裁断,如果条件满足则返回上述值,否则重复max_tries次,直到成功。

代码如下:

    def crop_area(self, im, text_polys):
        h, w = im.shape[:2]
        h_array = np.zeros(h, dtype=np.int32)
        w_array = np.zeros(w, dtype=np.int32)
        #将poly数据进行水平和垂直方向投影,有标注的地方是1,其余地方是0
        for points in text_polys:
            points = np.round(points, decimals=0).astype(np.int32)
            minx = np.min(points[:, 0])
            maxx = np.max(points[:, 0])
            w_array[minx:maxx] = 1
            miny = np.min(points[:, 1])
            maxy = np.max(points[:, 1])
            h_array[miny:maxy] = 1
        # ensure the cropped area not across a text
        #找出没有标注的水平和垂直坐标
        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]
        #如果所有位置都有标注,则无法裁剪,直接原图返回
        if len(h_axis) == 0 or len(w_axis) == 0:
            return 0, 0, w, h
        
        #对水平和垂直坐标进行连续区域分离,其实就是把所有连续0坐标区域切割处理
        #后面进行随机裁剪都是在每个连续区域进行,可以最大程度保证不会裁断标注
        h_regions = self.split_regions(h_axis)
        w_regions = self.split_regions(w_axis)

        for i in range(self.max_tries):
            if len(w_regions) > 1:
                #先从n个区域随机选择2个区域,然后在两个区域内部随机选择两个点,构成x方向最大最小坐标
                xmin, xmax = self.region_wise_random_select(w_regions, w)
            else:
                xmin, xmax = self.random_select(w_axis, w)
            if len(h_regions) > 1:
                #h方向也是一样处理
                ymin, ymax = self.region_wise_random_select(h_regions, h)
            else:
                ymin, ymax = self.random_select(h_axis, h)
            #不能裁剪的过小
            if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h:
                # area too small
                continue
            num_poly_in_rect = 0
            for poly in text_polys:
                #如果有一个poly标注没有出界,则直接返回,表示裁剪成功
                if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, ymax - ymin):
                    num_poly_in_rect += 1
                    break

            if num_poly_in_rect > 0:
                return xmin, ymin, xmax - xmin, ymax - ymin

        return 0, 0, w, h

​ 在得到裁剪区域后,就比较简单了。先对裁剪区域图片进行保存长宽比的resize,最长边为网络输入,例如640x640, 然后从上到下pad,得到640x640的图片


# 计算crop区域
crop_x, crop_y, crop_w, crop_h = self.crop_area(im, all_care_polys)
# crop 图片 保持比例填充
scale_w = self.size[0] / crop_w
scale_h = self.size[1] / crop_h
scale = min(scale_w, scale_h)
h = int(crop_h * scale)
w = int(crop_w * scale)

padimg = np.zeros((self.size[1], self.size[0], im.shape[2]), im.dtype)
padimg[:h, :w] = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
img = padimg

如果进行可视化,会显示如下所示:
DBNet详解_第10张图片
DBNet详解_第11张图片

可以看出,这种裁剪策略虽然简单暴力,但是为了拼接成640x640的输出,会带来大量无关全黑像素区域。

MakeICDARData(数据重新组织类)

DB/data/processes/make_icdar_data.py

就是简单的组织数据而已

#Making ICDAE format
#返回值:
OrderedDict(image=data['image'],
                           polygons=polygons,
                           ignore_tags=ignore_tags,
                           shape=shape,
                           filename=filename,
                           is_training=data['is_training'])

MakeSegDetectionData(生成概率图和对应mask类)

DB/data/processes/make_seg_detection_data.py

功能:将多边形数据转化为mask格式即概率图gt,并且标记哪些多边形是忽略区域


#Making binary mask from detection data with ICDAR format
输入:image,polygons,ignore_tags,filename
输出:gt(shape:[1,h,w]),mask (shape:[h,w])(用于后面计算binary loss)

DBNet详解_第12张图片

​ 为了防止标注间相互粘连,不好后处理,区分实例,目前做法都是会进行shrink即沿着多边形标注的每条边进行向内缩减一定像素,得到缩减的gt,然后才进行训练;在测试时候再采用相反的手动还原回来。

​ 缩减做法采用的也是常规的Vatti clipping algorithm,是通过pyclipper库实现的,缩减比例是默认0.4,公式是:
image.png-6.2kB

r=0.4,A是多边形面积,L是多边形周长,通过该公式就可以对每个不同大小的多边形计算得到一个唯一的D,代表每条边的向内缩放像素个数。

        gt = np.zeros((1, h, w), dtype=np.float32)#shrink后得到概率图,包括所有区域
        mask = np.ones((h, w), dtype=np.float32)#指示哪些区域是忽略区域,0就是忽略区域
        for i in range(len(polygons)):
            polygon = polygons[i]
            height = max(polygon[:, 1]) - min(polygon[:, 1])
            width = max(polygon[:, 0]) - min(polygon[:, 0])
            #如果是忽略样本,或者高宽过小,则mask对应位置设置为0即可
            if ignore_tags[i] or min(height, width) < self.min_text_size:
                cv2.fillPoly(mask, polygon.astype(
                    np.int32)[np.newaxis, :, :], 0)
                ignore_tags[i] = True
            else:
                #沿着每条边进行shrink
                polygon_shape = Polygon(polygon)#多边形分析库
                #每条边收缩距离:polygon, D=A(1-r^2)/L
                distance = polygon_shape.area * \
                    (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
                subject = [tuple(l) for l in polygons[i]]
                #实现坐标的偏移
                padding = pyclipper.PyclipperOffset()
                padding.AddPath(subject, pyclipper.JT_ROUND,
                                pyclipper.ET_CLOSEDPOLYGON)
                shrinked = padding.Execute(-distance)#得到缩放后的多边形
                if shrinked == []:
                    cv2.fillPoly(mask, polygon.astype(
                        np.int32)[np.newaxis, :, :], 0)
                    ignore_tags[i] = True
                    continue
                shrinked = np.array(shrinked[0]).reshape(-1, 2)
                cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1)

如果进行可视化,如下所示:
DBNet详解_第13张图片
DBNet详解_第14张图片

​ 概率图内部全白区域就是概率图的label,右图是忽略区域mask,0为忽略区域,到时候该区域是不计算概率图loss的。

MakeBorderMap(生成阈值图和对应Mask类)

DB/data/make_border_map.py

功能:计算阈值图和对应mask。

输入:预处理后的image info: image,polygons,ignore_tags 
输出:thresh_map,thresh_mask (用于后面计算thresh loss)

DBNet详解_第15张图片

​ 仔细看阈值图的标注,首先红线点是poly标注;然后对该多边形先进行shrink操作,得到蓝线; 然后向外反向shrink同样的距离,得到绿色;阈值图就是绿线和蓝色区域,以红线为起点,计算在绿线和蓝线区域内的点距离红线的距离,故为距离图。

其代码的处理逻辑是:

  1. 对每个poly进行向外扩展,参数和向内shrink一样,然后对扩展后多边形内部填充1,得到对应的mask
  2. 为了加快计算速度,对每条poly计算最小包围矩,然后在裁剪后的图片内部,计算每个点到poly上面每条边的距离
  3. 只保留0-1值内的距离值,其余位置不用
  4. 把距离图贴到原图大小的图片上,如果和其余poly有重叠,则取最大值
  5. 为了使得后续阈值图和概率图进行带参数的sigmod操作,得到近似二值图,需要对阈值图的取值范围进行变换,具体是将0-1范围变换到0.3-0.6范围

流程:

canvas = np.zeros(image.shape[:2], dtype=np.float32)
mask = np.zeros(image.shape[:2], dtype=np.float32)

draw_border_map(polygons[i], canvas, mask=mask)
canvas = canvas * (0.7 - 0.3) + 0.3
data['thresh_map'] = canvas
data['thresh_mask'] = mask
        

draw_border_map

    #处理每条poly
    def draw_border_map(self, polygon, canvas, mask):
        polygon = np.array(polygon)
        assert polygon.ndim == 2
        assert polygon.shape[1] == 2
        #向外扩展
        polygon_shape = Polygon(polygon)
        distance = polygon_shape.area * \
            (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
        subject = [tuple(l) for l in polygon]
        padding = pyclipper.PyclipperOffset()
        padding.AddPath(subject, pyclipper.JT_ROUND,
                        pyclipper.ET_CLOSEDPOLYGON)
        padded_polygon = np.array(padding.Execute(distance)[0])#shape:[12,2]扩大和缩减一样的像素
        cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)#内部全部填充1
        #计算最小包围poly矩形
        xmin = padded_polygon[:, 0].min()
        xmax = padded_polygon[:, 0].max()
        ymin = padded_polygon[:, 1].min()
        ymax = padded_polygon[:, 1].max()
        width = xmax - xmin + 1
        height = ymax - ymin + 1
        #裁剪掉无关区域,加快计算速度
        polygon[:, 0] = polygon[:, 0] - xmin
        polygon[:, 1] = polygon[:, 1] - ymin
        #最小包围矩形的所有位置坐标
        xs = np.broadcast_to(
            np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))
        ys = np.broadcast_to(
            np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))

        distance_map = np.zeros(
            (polygon.shape[0], height, width), dtype=np.float32)
        for i in range(polygon.shape[0]):#对每条边进行遍历
            j = (i + 1) % polygon.shape[0]
            #计算图片上所有点到线上面的距离
            absolute_distance = self.distance(xs, ys, polygon[i], polygon[j])
            #仅仅保留0-1之间的位置,得到距离图
            distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
        distance_map = distance_map.min(axis=0)
        #绘制到原图上
        xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
        xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
        ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
        ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
        #如果有多个ploy实例重合,则该区域取最大值
        canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
            1 - distance_map[
                ymin_valid-ymin:ymax_valid-ymax+height,
                xmin_valid-xmin:xmax_valid-xmax+width],
            canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])

可视化如下所示:
DBNet详解_第16张图片
采用matpoltlib绘制距离图会更好看
DBNet详解_第17张图片

NormalizeImage

DB/data/processes/normalize_image.py

图片归一化类

FilterKeys

DB/data/processes/filter_keys.py

字典数据过滤类,具体是把superfluous里面的key和value删掉,不输入网络中

#删除无用的图片信息,只保留信息:
dict"image","gt","mask","thresh_map","thresh_mask"

模型结构

DB/structure/model.py

模型结构配置部分:

        builder: 
            class: Builder
            model: SegDetectorModel
            model_args:
                backbone: deformable_resnet18
                decoder: SegDetector
                decoder_args: 
                    adaptive: True
                    in_channels: [64, 128, 256, 512]
                    k: 50

骨干网络和FPN

​ 骨架网络采用的是resnet18或者resnet50,为了增加网络特征提取能力,在layer2、layer3和layer4模块内部引入了变形卷积dcnv2模块。在resnet输出的4个特征图后面采用标准的FPN网络结构,得到4个增强后输出,然后cat进来,得到1/4的特征图输出fuse。

​ 其中,resnet骨架特征提取代码在backbones/resnet.py里,具体是输出x2, x3, x4, x5,分别是1/4~1/32尺寸。FPN部分代码在decoders/seg_detector.py里面.

head部分(decoder)

DB/decoders/seg_detector.py

​ 输出head在训练时候包括三个分支,分别是probability map、threshold map和经过DB模块计算得到的approximate binary map。三个图通道都是1,输出和输入是一样大的。要想分割精度高,高分辨率输出是必要的。

**输出:**binary、thresh、thresh_binary

fuse = torch.cat((p5, p4, p3, p2), 1)
#推理时,只需返回binary
binary = self.binarize(fuse)
thresh = self.thresh(fuse)
thresh_binary = self.step_function(binary, thresh)

binary

​ 对fuse特征图经过一系列卷积和反卷积,扩大到和原图一样大的输出,然后经过sigmod层得到0-1输出概率图probability map

        self.binarize = nn.Sequential(
            nn.Conv2d(inner_channels, inner_channels //
                      4, 3, padding=1, bias=bias),
            BatchNorm2d(inner_channels//4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(inner_channels//4, inner_channels//4, 2, 2),
            BatchNorm2d(inner_channels//4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(inner_channels//4, 1, 2, 2),
            nn.Sigmoid())
        self.binarize.apply(self.weights_init)

thresh

​ 同时对fuse特征图采用类似上采样操作,经过sigmod层的0-1输出阈值图threshold map

if adaptive:
    self.thresh = self._init_thresh(
    inner_channels, serial=serial, smooth=smooth, bias=bias)
    self.thresh.apply(self.weights_init)
  
  
  
  def _init_thresh(self, inner_channels,
                     serial=False, smooth=False, bias=False):
        in_channels = inner_channels
        if serial:
            in_channels += 1
        self.thresh = nn.Sequential(
            nn.Conv2d(in_channels, inner_channels //
                      4, 3, padding=1, bias=bias),
            BatchNorm2d(inner_channels//4),
            nn.ReLU(inplace=True),
            self._init_upsample(inner_channels // 4, inner_channels//4, smooth=smooth, bias=bias),
            BatchNorm2d(inner_channels//4),
            nn.ReLU(inplace=True),
            self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
            nn.Sigmoid())
        return self.thresh

step_function

​ 将这两个输出图经过DB模块得到approximate binary map

torch.reciprocal(1 + torch.exp(-self.k * (binary - thresh)))

损失函数

DB/decoders/seg_detector_loss.py

​ 输出是单个单通道图,probability map和approximate binary map是典型的分割输出,故其loss就是普通的bce,但是为了平衡正负样本,还额外采用了难负样本采样策略,对背景区域和前景区域采用3:1的设置。对于threshold map,其输出不一定是0-1之间,后面会介绍其值的范围,当前采用的是L1 loss,且仅仅计算扩展后的多边形内部区域,其余区域忽略。

image.png-5.4kB

Ls是概率图,Lt是阈值图,Lb是近似二值化图,

​ 本文整个论文Loss的实现在decoders/seg_detector_loss.py的L1BalanceCELoss类,可以发现其实approximate binary map采用的并不是论文中的bce,而是可以克服正负样本平衡的dice loss。一般在高度不平衡的二值分割任务中,dice loss效果会比纯bce好,但是更好的策略是dice loss +bce loss。

loss = dice_loss + 10 * l1_loss + 5*bce_loss 

DBNet详解_第18张图片

binary loss

bce_loss = self.bce_loss(pred['binary'], batch['gt'], batch['mask'])

bce_loss:

DB/decoders/balance_cross_entropy_loss.py


    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: torch.Tensor,
                return_origin=False):
        '''
        Args:
            pred: shape :math:`(N, 1, H, W)`, the prediction of network
            gt: shape :math:`(N, 1, H, W)`, the target
            mask: shape :math:`(N, H, W)`, the mask indicates positive regions
        '''
        positive = (gt * mask).byte()
        negative = ((1 - gt) * mask).byte()
        positive_count = int(positive.float().sum())
        #负样本个数为positive_count的self.negative_ratio倍数
        negative_count = min(int(negative.float().sum()),
                            int(positive_count * self.negative_ratio))
        loss = nn.functional.binary_cross_entropy(
            pred, gt, reduction='none')[:, 0, :, :]
        positive_loss = loss * positive.float()
        negative_loss = loss * negative.float()
        #按照loss选择topK个
        negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)

        balance_loss = (positive_loss.sum() + negative_loss.sum()) /\
            (positive_count + negative_count + self.eps)

        if return_origin:
            return balance_loss, loss
        return balance_loss

thresh loss

l1_loss, l1_metric = self.l1_loss(pred['thresh'], batch['thresh_map'], batch['thresh_mask'])

l1_loss:

DB/decoders/l1_loss.py

class MaskL1Loss(nn.Module):
    def __init__(self):
        super(MaskL1Loss, self).__init__()

    def forward(self, pred: torch.Tensor, gt, mask):
        mask_sum = mask.sum()
        if mask_sum.item() == 0:
            return mask_sum, dict(l1_loss=mask_sum)
        else:
            loss = (torch.abs(pred[:, 0] - gt) * mask).sum() / mask_sum
            return loss, dict(l1_loss=loss)

thresh_binary loss

dice_loss = self.dice_loss(pred['thresh_binary'], batch['gt'], batch['mask'])

dice_loss:

DB/decoders/dice_loss.py

class DiceLoss(nn.Module):
    '''
    Loss function from https://arxiv.org/abs/1707.03237,
    where iou computation is introduced heatmap manner to measure the
    diversity bwtween tow heatmaps.
    '''
    def __init__(self, eps=1e-6):
        super(DiceLoss, self).__init__()
        self.eps = eps

    def forward(self, pred: torch.Tensor, gt, mask, weights=None):
        '''
        pred: one or two heatmaps of shape (N, 1, H, W),
            the losses of tow heatmaps are added together.
        gt: (N, 1, H, W)
        mask: (N, H, W)
        '''
        assert pred.dim() == 4, pred.dim()
        return self._compute(pred, gt, mask, weights)

    def _compute(self, pred, gt, mask, weights):
        if pred.dim() == 4:
            pred = pred[:, 0, :, :]
            gt = gt[:, 0, :, :]
        assert pred.shape == gt.shape
        assert pred.shape == mask.shape
        if weights is not None:
            assert weights.shape == mask.shape
            mask = weights * mask

        intersection = (pred * gt * mask).sum()
        union = (pred * mask).sum() + (gt * mask).sum() + self.eps
        loss = 1 - 2.0 * intersection / union
        assert loss <= 1
        return loss

binary与thresh_binary的标签都是用的gt

thresh的标签用的thresh_map

逻辑推理

配置如下:

  - name: validate_data
    class: ImageDataset
    data_dir:
        - '/remote_workspace/ocr/public_dataset/icdar2015/'
    data_list:
        - '/remote_workspace/ocr/public_dataset/icdar2015/test_list.txt'
    processes:
        - class: AugmentDetectionData
          augmenter_args:
              - ['Resize', {
     'width': 1280, 'height': 736}]
              # - ['Resize', {'width': 2048, 'height': 1152}]
          only_resize: True
          keep_ratio: False
        - class: MakeICDARData
        - class: MakeSegDetectionData
        - class: NormalizeImage

​ 如果不考虑label,则其处理逻辑和训练逻辑有一点不一样,其把图片统一resize到指定的长度进行预测。

前面说过阈值图分支其实可以相当于辅助分支,可以联合优化各个分支性能。故在测试时候发现概率图预测值已经蛮好了,故在测试阶段实际上把阈值图分支移除了,只需要概率图输出即可。

后处理逻辑在structure/representers/seg_detector_representer.py,本文特色就是后处理比较简单,故流程为:

  1. 对概率图进行固定阈值处理,得到分割图
  2. 对分割图计算轮廓,遍历每个轮廓,去除太小的预测;对每个轮廓计算包围矩形,然后计算该矩形的预测score
  3. 对矩形进行反向shrink操作,得到真实矩形大小;最后还原到原图size就可以了
    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        _bitmap: single map with shape (H, W),
            whose values are binarized as {0, 1}
        '''

        assert len(_bitmap.shape) == 2
        bitmap = _bitmap.cpu().numpy()  # The first channel
        pred = pred.cpu().detach().numpy()
        height, width = bitmap.shape
        contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
        num_contours = min(len(contours), self.max_candidates)
        boxes = np.zeros((num_contours, 4, 2), dtype=np.int16)
        scores = np.zeros((num_contours,), dtype=np.float32)
        #对二值图计算轮廓,每个轮廓就是一个文本实例
        for index in range(num_contours):
            contour = contours[index].squeeze(1)
            #计算最小包围矩,得到points坐标
            points, sside = self.get_mini_boxes(contour)
            if sside < self.min_size:
                continue
            points = np.array(points)
            #利用points内部预测概率值,计算出一个score,作为实例的预测概率
            score = self.box_score_fast(pred, contour)
            if self.box_thresh > score:
                continue
            #shrink反向还原
            box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(-1, 1, 2)
            box, sside = self.get_mini_boxes(box)
            if sside < self.min_size + 2:
                continue
            box = np.array(box)
            if not isinstance(dest_width, int):
                dest_width = dest_width.item()
                dest_height = dest_height.item()
            #还原到原始坐标
            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes[index, :, :] = box.astype(np.int16)
            scores[index] = score
        return boxes, scores

采用作者提供的训练好的权重进行预测,可视化预测结果如下所示:

DBNet详解_第19张图片
DBNet详解_第20张图片

论文中指标结果:

可以看出变形卷积和阈值图对整个性能都有比较大的促进作用。

DBNet详解_第21张图片

测试icdar2015数据结果:
DBNet详解_第22张图片

补充

语义分割中的loss function

cross entropy loss

用于图像语义分割任务的最常用损失函数是像素级别的交叉熵损失,这种损失会逐个检查每个像素,将对每个像素类别的预测结果(概率分布向量)与我们的独热编码标签向量进行比较。

假设我们需要对每个像素的预测类别有5个,则预测的概率分布向量长度为5:

DBNet详解_第23张图片

每个像素对应的损失函数为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OeRSbkGm-1610966579269)(https://www.zhihu.com/equation?tex=%5Ctext+%7Bpixel+loss%7D+%3D±%5Csum_%7Bclasses%7D+y_%7Btrue%7D+log+%28y_%7Bpred%7D%29+%5C%5C)]

整个图像的损失就是对每个像素的损失求平均值。

特别注意的是,binary entropy loss 是针对类别只有两个的情况,简称 bce loss,损失函数公式为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DOMD9pF8-1610966579270)(https://www.zhihu.com/equation?tex=%5Ctext+%7Bbce+loss%7D+%3D±+y_%7Btrue%7D+log+%28y_%7Bpred%7D%29±+%281-y_%7Btrue%7D%29+log+%281-y_%7Bpred%7D%29%5C%5C)]

weighted loss

由于交叉熵损失会分别评估每个像素的类别预测,然后对所有像素的损失进行平均,因此我们实质上是在对图像中的每个像素进行平等地学习。如果多个类在图像中的分布不均衡,那么这可能导致训练过程由像素数量多的类所主导,即模型会主要学习数量多的类别样本的特征,并且学习出来的模型会更偏向将像素预测为该类别。

FCN论文和U-Net论文中针对这个问题,对输出概率分布向量中的每个值进行加权,即希望模型更加关注数量较少的样本,以缓解图像中存在的类别不均衡问题。

比如对于二分类,正负样本比例为1: 99,此时模型将所有样本都预测为负样本,那么准确率仍有99%这么高,但其实该模型没有任何使用价值。

为了平衡这个差距,就对正样本和负样本的损失赋予不同的权重,带权重的二分类损失函数公式如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PtRPVOrh-1610966579271)(https://www.zhihu.com/equation?tex=%5Ctext+%7Bpos_weight%7D+%3D+%5Cfrac%7B%5Ctext+%7Bnum_neg%7D%7D%7B%5Ctext+%7Bnum_pos%7D%7D+%5C%5C+%5Ctext+%7Bloss%7D+%3D±+%5Ctext+%7Bpos_weight%7D+%5Ctimes+y_%7Btrue%7D+log+%28y_%7Bpred%7D%29±+%281-y_%7Btrue%7D%29+log+%281-y_%7Bpred%7D%29%5C%5C)]

要减少假阴性样本的数量,可以增大 pos_weight;要减少假阳性样本的数量,可以减小 pos_weight。

focal loss

上面针对不同类别的像素数量不均衡提出了改进方法,但有时还需要将像素分为难学习和容易学习这两种样本。

容易学习的样本模型可以很轻松地将其预测正确,模型只要将大量容易学习的样本分类正确,loss就可以减小很多,从而导致模型不怎么顾及难学习的样本,所以我们要想办法让模型更加关注难学习的样本。

对于较难学习的样本,将 bce loss 修改为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6eZmpbcM-1610966579272)(https://www.zhihu.com/equation?tex=-+%281-y_%7Bpred%7D%29%5E%5Cgamma+%5Ctimes+y_%7Btrue%7D+log+%28y_%7Bpred%7D%29±+y_%7Bpred%7D%5E%5Cgamma+%5Ctimes+%281-y_%7Btrue%7D%29+log+%281-y_%7Bpred%7D%29+%5C%5C)]

其中的 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pOuXJT0r-1610966579273)(https://www.zhihu.com/equation?tex=%5Cgamma)] 通常设置为2。

举个例子,预测一个正样本,如果预测结果为0.95,这是一个容易学习的样本,有 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-w4wQuNnV-1610966579274)(https://www.zhihu.com/equation?tex=%281-0.95%29%5E2%3D0.0025)] ,损失直接减少为原来的1/400。

而如果预测结果为0.4,这是一个难学习的样本,有 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-QQu9pL2M-1610966579276)(https://www.zhihu.com/equation?tex=%281-0.5%29%5E2%3D0.25)] ,损失减小为原来的1/4,虽然也在减小,但是相对来说,减小的程度小得多。

所以通过这种修改,就可以使模型更加专注于学习难学习的样本。

而将这个修改和对正负样本不均衡的修改合并在一起,就是大名鼎鼎的 focal loss:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-23Dr9Qbl-1610966579278)(https://www.zhihu.com/equation?tex=-+%5Calpha+%281-y_%7Bpred%7D%29%5E%5Cgamma+%5Ctimes+y_%7Btrue%7D+log+%28y_%7Bpred%7D%29±+%281-%5Calpha%29+y_%7Bpred%7D%5E%5Cgamma+%5Ctimes+%281-y_%7Btrue%7D%29+log+%281-y_%7Bpred%7D%29+%5C%5C)]

dice soft loss

Dice系数计算

语义分割任务中常用的还有一个基于 Dice 系数的损失函数,该系数实质上是两个样本之间重叠的度量。此度量范围为 0~1,其中 Dice 系数为1表示完全重叠。Dice 系数最初是用于二进制数据的,可以计算为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SGCRjBn0-1610966579279)(https://www.zhihu.com/equation?tex=Dice+%3D+%5Cfrac+%7B2+%7CA+%5Ccap+B%7C%7D%7B%7CA%7C+%2B+%7CB%7C%7D+%5C%5C)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TDJx3q9u-1610966579281)(https://www.zhihu.com/equation?tex=%7CA+%5Ccap+B%7C)] 代表集合A和B之间的公共元素,并且 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Y7BLfLdh-1610966579284)(https://www.zhihu.com/equation?tex=%7C+A+%7C)] 代表集合A中的元素数量(对于集合B同理)。

对于在预测的分割掩码上评估 Dice 系数,我们可以将 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-wvhJvFjz-1610966579286)(https://www.zhihu.com/equation?tex=%7CA+%5Ccap+B%7C)] 近似为预测掩码和标签掩码之间的逐元素乘法,然后对结果矩阵求和。

DBNet详解_第24张图片

计算 Dice 系数的分子中有一个2,那是因为分母中对两个集合的元素个数求和,两个集合的共同元素被加了两次。

Dice loss

为了设计一个可以最小化的损失函数,可以简单地使用 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-zTwxiWv4-1610966579288)(https://www.zhihu.com/equation?tex=1-Dice+)]。 这种损失函数被称为 soft Dice loss,这是因为我们直接使用预测出的概率,而不是使用阈值将其转换成一个二进制掩码。

Dice loss是针对前景比例太小的问题提出的,dice系数源于二分类,本质上是衡量两个样本的重叠部分。

对于二分类问题,一般预测值分为以下几种:

  • TP: true positive,真阳性,预测是阳性,预测对了,实际也是正例。
  • TN: true negative,真阴性,预测是阴性,预测对了,实际也是负例。
  • FP: false positive,假阳性,预测是阳性,预测错了,实际是负例。
  • FN: false negative,假阴性,预测是阴性,预测错了,实际是正例。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-I6TNYlF0-1610966579290)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0JEcEJoM2ZGMkd2cjhxbHM2eG04Z2JBUURyUHIyT1VIN2ljWGVSWGdDckVjUVJteDBMTXI4bURBLzY0MA.png)]

这里dice coefficient可以写成如下形式:
d i c e = 2 T P 2 T P + F P + F N dice=\frac{2TP}{2TP+FP+FN} dice=2TP+FP+FN2TP

而我们知道:

可见dice coefficient是等同**「F1 score」,直观上dice coefficient是计算 与 的相似性,本质上则同时隐含precision和recall两个指标。可见dice loss是直接优化「F1 score」**。

对于神经网络的输出,分子与我们的预测和标签之间的共同激活有关,而分母分别与每个掩码中的激活数量有关,这具有根据标签掩码的尺寸对损失进行归一化的效果。

DBNet详解_第25张图片

对于每个类别的mask,都计算一个 Dice 损失:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-v2aiaP5D-1610966579294)(https://www.zhihu.com/equation?tex=1-+%5Cfrac+%7B2+%5Csum%5Climits_%7Bpixels%7D+y_%7Btrue%7D+y_%7Bpred%7D%7D%7B%5Csum%5Climits_%7Bpixels%7D+%28y_%7Btrue%7D%5E2+%2B+y_%7Bpred%7D%5E2%29%7D+%5C%5C)]

将每个类的 Dice 损失求和取平均,得到最后的 Dice soft loss。

梯度分析

从dice loss的定义可以看出,dice loss 是一种**「区域相关」**的loss。意味着某像素点的loss以及梯度值不仅和该点的label以及预测值相关,和其他点的label以及预测值也相关,这点和ce (交叉熵cross entropy)  loss 不同。

dice loss 是应用于语义分割而不是分类任务,并且是一个区域相关的loss,因此更适合针对多点的情况进行分析。由于多点输出的情况比较难用曲线呈现,这里使用模拟预测值的形式观察梯度的变化。

下图为原始图片和对应的label:[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-xPLvYa7F-1610966579296)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0JGMWR3blNGU1R5VEY4VFllNHN3SHBrR1FOM3JrWnRQamtYZGhoWjBydWo3RFFyamlibmowZ3lBLzY0MA.png)]

为了便于梯度可视化,这里对梯度求绝对值操作,因为我们关注的是梯度的大小而非方向。另外梯度值都乘以 保证在容易辨认的范围。

首先定义如下热图,值越大,颜色越亮,反之亦然:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-BeCmQcqD-1610966579298)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0I1YXBtUDZNWGJaNDhocklkWmE3dHpGdEZKQmJwSFV6Q0tqTUhWRW5mQ3MyTmh1b2o4TTJTNVEvNjQw.png)]

预测值变化( 值,图上的数字为预测值区间):

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jY3HOS0H-1610966579299)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0JiT3FrQzRMYXZJMThMbVdxQVNXTmE3STdjR2EwMm95cnB6cVhuZTRMNWhwajJDOWRySXUyS2cvNjQw.png)]

dice loss 对应 值的梯度:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pXdKApFF-1610966579301)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0JrMXV2Sjhyem1qanMyMXdraWJzYkRBbktiSlVqTXFjaWFYSUt1VkJSaWFDd213TGZpYTMyanFUaWFuQS82NDA.png)]

ce loss 对应 值的梯度:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-i9VOLACe-1610966579302)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0I1aFdVUWliVk1CaWFtRzFxbXdocnp6ZUVqWTY1dmFhZWtlV05iMGVJcGJBUkNpYkoyMFdHUmliZmJRLzY0MA.png)]

可以看出:

  • 一般情况下,dice loss 正样本的梯度大于背景样本的,尤其是刚开始网络预测接进0.5的时候。说明dice loss 更具有指向性,更加偏向于正样本,保证有较低的FN。
  • 负样本(背景区域)也会产生梯度
  • 极端情况下,网络预测接进0或1时,对应点梯度值极小,dice loss 存在梯度饱和现象。此时预测失败(FN,FP)的情况很难扭转回来。不过该情况出现的概率较低,因为网络初始化输出接近0.5,此时具有较大的梯度值。而网络通过梯度下降的方式更新参数,只会逐渐削弱预测失败的像素点。
  • 对于ce loss,当前的点的梯度仅和当前预测值与label的距离相关,预测越接近label,梯度越小。当网络预测接近0或1时,梯度依然保持该特性。
  • 对比发现,训练前中期,dice loss 下正样本的梯度值相对于ce loss ,颜色更亮,值更大。说明dice loss对挖掘正样本更加有优势。

dice loss为何能够解决正负样本不平衡问题?

因为dice loss 是一个区域相关的loss。区域相关的意思就是,当前像素的loss不光和当前像素的预测值相关,和其他点的值也相关。dice loss的求交的形式可以理解为mask掩码操作,因此不管图片有多大,固定大小的正样本的区域计算的loss是一样的,对网络起到的监督贡献不会随着图片的大小而变化。从上图可视化也发现,训练更倾向于挖掘前景区域,正负样本不平衡的情况就是前景占比较小。而ce loss 会公平处理正负样本,当出现正样本占比较小时,就会被更多的负样本淹没。

dice loss背景区域能否起到监督作用?

可以的,但是会小于前景区域。和直观理解不同的是,随着训练的进行,背景区域也能产生较为可观的梯度。这点和单点的情况分析不同。这里求偏导,当t_i=0 时:

可以看出, 背景区域的梯度是存在的,只有预测值命中的区域极小时, 背景梯度才会很小.

dice loss 为何训练会很不稳定?

在使用dice loss时,一般正样本为小目标时会产生严重的震荡。因为在只有前景和背景的情况下,小目标一旦有部分像素预测错误,那么就会导致loss值大幅度的变动,从而导致梯度变化剧烈。可以假设极端情况,只有一个像素为正样本,如果该像素预测正确了,不管其他像素预测如何,loss 就接近0,预测错误了,loss 接近1。而对于ce loss,loss的值是总体求平均的,更多会依赖负样本的地方。

总结

dice loss 对正负样本严重不平衡的场景有着不错的性能,训练过程中更侧重对前景区域的挖掘。但训练loss容易不稳定,尤其是小目标的情况下。另外极端情况会导致梯度饱和现象。因此有一些改进操作,主要是结合ce loss等改进,比如:  dice+ce loss,dice + focal loss等,

soft IOU loss

前面我们知道计算 Dice 系数的公式,其实也可以表示为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2X1iB7aM-1610966579304)(https://www.zhihu.com/equation?tex=Dice+%3D+%5Cfrac+%7B2+%7CA+%5Ccap+B%7C%7D%7B%7CA%7C+%2B+%7CB%7C%7D+%3D+%5Cfrac+%7B2+TP%7D%7B2+TP+%2B+FP+%2B+FN%7D+%5C%5C)]

其中 TP 为真阳性样本,FP 为假阳性样本,FN 为假阴性样本。分子和分母中的 TP 样本都加了两次。

IoU 的计算公式和这个很像,区别就是 TP 只计算一次:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sN52A4f0-1610966579306)(https://www.zhihu.com/equation?tex=IoU+%3D+%5Cfrac+%7B%7CA+%5Ccap+B%7C%7D%7B%7CA%7C+%2B+%7CB%7C±+%7CA+%5Ccap+B%7C%7D+%3D+%5Cfrac+%7BTP%7D%7BTP+%2B+FP+%2B+FN%7D+%5C%5C)]

和 Dice soft loss 一样,通过 IoU 计算损失也是使用预测的概率值:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-efgoYMzv-1610966579307)(https://www.zhihu.com/equation?tex=loss+%3D±+%5Cfrac+%7B1%7D%7B%7CC%7C%7D+%5Csum%5Climits_c+%5Cfrac+%7B%5Csum%5Climits_%7Bpixels%7D+y_%7Btrue%7D+y_%7Bpred%7D%7D%7B%5Csum%5Climits_%7Bpixels%7D+%28y_%7Btrue%7D+%2B+y_%7Bpred%7D±+y_%7Btrue%7D+y_%7Bpred%7D%29%7D+%5C%5C)]

其中 C 表示总的类别数。

总结

交叉熵损失把每个像素都当作一个独立样本进行预测,而 dice loss 和 iou loss 则以一种更“整体”的方式来看待最终的预测输出。

预测值相关,和其他点的值也相关。dice loss的求交的形式可以理解为mask掩码操作,因此不管图片有多大,固定大小的正样本的区域计算的loss是一样的,对网络起到的监督贡献不会随着图片的大小而变化。从上图可视化也发现,训练更倾向于挖掘前景区域,正负样本不平衡的情况就是前景占比较小。而ce loss 会公平处理正负样本,当出现正样本占比较小时,就会被更多的负样本淹没。

dice loss背景区域能否起到监督作用?

可以的,但是会小于前景区域。和直观理解不同的是,随着训练的进行,背景区域也能产生较为可观的梯度。这点和单点的情况分析不同。这里求偏导,当t_i=0 时:

可以看出, 背景区域的梯度是存在的,只有预测值命中的区域极小时, 背景梯度才会很小.

dice loss 为何训练会很不稳定?

在使用dice loss时,一般正样本为小目标时会产生严重的震荡。因为在只有前景和背景的情况下,小目标一旦有部分像素预测错误,那么就会导致loss值大幅度的变动,从而导致梯度变化剧烈。可以假设极端情况,只有一个像素为正样本,如果该像素预测正确了,不管其他像素预测如何,loss 就接近0,预测错误了,loss 接近1。而对于ce loss,loss的值是总体求平均的,更多会依赖负样本的地方。

总结

dice loss 对正负样本严重不平衡的场景有着不错的性能,训练过程中更侧重对前景区域的挖掘。但训练loss容易不稳定,尤其是小目标的情况下。另外极端情况会导致梯度饱和现象。因此有一些改进操作,主要是结合ce loss等改进,比如:  dice+ce loss,dice + focal loss等,

soft IOU loss

前面我们知道计算 Dice 系数的公式,其实也可以表示为:

[外链图片转存中…(img-2X1iB7aM-1610966579304)]

其中 TP 为真阳性样本,FP 为假阳性样本,FN 为假阴性样本。分子和分母中的 TP 样本都加了两次。

IoU 的计算公式和这个很像,区别就是 TP 只计算一次:

[外链图片转存中…(img-sN52A4f0-1610966579306)]

和 Dice soft loss 一样,通过 IoU 计算损失也是使用预测的概率值:

[外链图片转存中…(img-efgoYMzv-1610966579307)]

其中 C 表示总的类别数。

总结

交叉熵损失把每个像素都当作一个独立样本进行预测,而 dice loss 和 iou loss 则以一种更“整体”的方式来看待最终的预测输出。

这两类损失是针对不同情况,各有优点和缺点,在实际应用中,可以同时使用这两类损失来进行互补。

你可能感兴趣的:(OCR,DBNet)