DBnet

一 论文阅读

论文地址https://arxiv.org/pdf/1911.08947.pdf,官方代码https://github.com/MhLiao/DB,下述参考代码https://github.com/BADBADBADBOY/pytorchOCR,详细解读可以参考https://blog.csdn.net/u010901792/article/details/112791647

1.1 本文是基于分割做文本检测,以前的做法就是基于概率图,然后使用一个固定的阈值,将图片进行二值化分割成前景背景,然后再通过后处理进行文本检测,但是会发现不同阈值对性能影响很大;该论文提出了自适应阈值,就是额外生成一个thresh_map阈值图,然后通过概率图和阈值图去生成最终的二值化图
1.2 对于传统的二值图计算如下公式,但是不可微分,不能进行端到端训练
在这里插入图片描述
这里采用 Differentiable binarization,计算公式如下,该公式是一个加参数的sigmoid函数在这里插入图片描述
1.3 该文章还探讨了新增的thresh_map,加入监督训练有助于提升性能

二 代码分析

整体代码参考https://github.com/BADBADBADBOY/pytorchOCR,主要是参考作者的多标签训练

1 数据处理

1.1 对于标注的数据,分为标注框polys,标签classes(这里一共分成两类),以及不参与训练的框会有个dontcare去区分;数据增强主要是随机缩放,随机旋转(这里是-10°到10°),随机翻转以及随机裁剪
1.2 对于随机裁剪的计算方式,会选取crop的范围(这里的原则是保证crop的范围内有一个完整的care的框就保留该crop,要不就重复直到满足),然后会把crop的最大边缩放到训练的尺寸大小640,短边padding到640大小;之后计算每个框是否在resize之后的crop范围内(框是否完全被crop包裹),如果不在范围内,就舍弃该框,当然如果前面dontcare为True的框在范围内,依然保留,dontcare依然为True
1.3 这里 MakeSegMap 生成 gt 以及 gt_classes , gt_mask对应的标签图,首先会计算框内的面积,如果框面积太小以及框的高宽小于规定的值,该框则设定为dontcare为True,对应的 gt 图值为[0,1](其中gt图的范围为对应的框范围内缩之后的大小),对应的gt_classes 图的值为[0,1,2](这里对应为2类,下图所示不同颜色是不同类别),对应的gt_mask就是dontcare为True的框对应的区域,值为[1,0]

DBnet_第1张图片 DBnet_第2张图片 DBnet_第3张图片

1.4 这里MakeBorderMap,生成thresh_map 以及thresh_mask,对于dontcare的框直接忽略;其中thresh_map是在原来的poly基础上进行外扩和内缩之后,将外扩的框所有区域作为thresh_mask,然后计算外扩框中间所有像素到poly框的距离,然后距离大于1的取1,之后得到的thresh_map值是用1减去距离值生成的,然后再归化到0.3-0.7之间(这个主要是因为后面模型生产的thresh图,是在sigmoid之后的,然后通过生成的thresh图与thresh_map图去做l1_loss计算)

DBnet_第4张图片 DBnet_第5张图片
1.5 至此数据生成就完成了,主要是label的生成,对应5张图,分别是gt,gt_classes,mask(也是上面的gt_mask),thresh_map,thresh_mask ## 2 模型结构 2.1 这里采用backbone是resnet18,然后对应的head是FPN,输入图片大小是$b*3*640*640$,经过FPN之后输出的是$b*256*160*160$ 2.2 然后经过SegDetectorMul模块,第一步是生成probability map,通过输入$x: b*256*160*160$,经过卷积+BN+ReLU+反卷积+BN+ReLU+Sigmoid,生成最终的probability map,大小是$b*1*640*640$;第二步还是同样的输入$x: b*256*160*160$,经过卷积+BN+ReLU+反卷积+BN+ReLU+反卷积,得到binary_classes图,大小是$b*2*640*640$;在训练的时候会在生成thresh_map,也是输入同上$x: b*256*160*160$,经过卷积+BN+ReLU+反卷积+BN+ReLU+反卷积+Sigmoid,得到最终thresh_map,大小是$b*1*640*640$,然后通过生成的probability map和thresh_map,得到thresh_binary,其实就是加参数的Sigmoid公式(如下图),公式里面的k取值50,主要是提供梯度的扩大

在这里插入图片描述
这里经过模型最终生成了4张图,probability map,binary_classes,thresh_map,thresh_binary

3 Loss计算

3.1 第一个loss是bce_loss(BalanceCrossEntropyLoss),通过probability map 与 gt 以及 gt_mask 进行计算,对于正负样本都要乘以 gt_mask 之后去计算(就是对于dontcare的框内部的loss不计算),这里先选取正负样本比例大致是1:3,计算二分类loss
3.2 第二个loss是针对分类的交叉熵loss,这里通过 binary_classes 与 gt_classes 去计算,计算loss只选取有class的框范围内的像素去计算loss,这个是多分类交叉熵loss
3.3 第三个loss是l1_loss,通过模型生成的 thresh_map 和 生成标签 对应的 thresh_map 以及 thresh_mask 去计算,这里也是对应thresh_mask内部的像素才计算loss
3.4 第四个loss是dice_loss,通过 thresh_binary 与 gt 以及 gt_mask 去计算,这里也是对应 gt_mask 之外的区域去计算loss
3.5 四个loss的比例是1 * dice_loss + 10 * l1_loss + 1 * bce_loss + 1 * class_loss

你可能感兴趣的:(深度学习,神经网络,机器学习)