CTPN网络理解

本文主要对常用的文本检测模型算法进行总结及分析,有的模型笔者切实run过,有的是通过论文及相关代码的分析,如有错误,请不吝指正。

一下进行各个模型的详细解析

CTPN 详解

代码链接:https://github.com/xiaofengShi/CHINESE-OCR

CTPN是目前应用非常广泛的印刷体文本检测模型算法。

CTPN由fasterrcnn改进而来,可以看下二者的异同

网络结构 FasterRcnn CTPN
basenet Vgg16 ,Vgg19,resnet Vgg16,也可以使用其他CNN结构
RPN预测 basenet的predict layer使用CNN生成 basenet之后使用双向RNN使用FC生成
ROI 模型适用于目标检测,为多分类任务,包含ROI及类别损失和BOX回归 文本提取为二分类任务,不包含ROI及类别损失,只在RPN层计算目标损失及BOX回归
Anchor 一共9种anchor尺寸,3比例,3尺寸 固定anchor宽度,高度为10种
batch 每次只能训练一个样本 每次只能训练一个样本

根据ctpn的网络设计,可以看到看到ctpn一般使用预训练的vggnet,并且只用来检测水平文本,一般可以用来进行标准格式印刷体的检测,在目标框回归预测时,加上回归框的角度信息,就可以用来检测旋转文本,比如EAST模型。

代码分析

网络模型

直接看CTPN的网络代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
copy
class VGGnet_train(Network):
# 继承自NetWork,关与NetWork可以看这里:https://github.com/xiaofengShi/CHINESE-OCR/blob/master/ctpn/lib/networks/network.py
def __init__(self, trainable=True):
self.inputs = []
self.data = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='data')
self.im_info = tf.placeholder(tf.float32, shape=[None, 3], name='im_info')
self.gt_boxes = tf.placeholder(tf.float32, shape=[None, 5], name='gt_boxes')
self.gt_ishard = tf.placeholder(tf.int32, shape=[None], name='gt_ishard')
self.dontcare_areas = tf.placeholder(tf.float32, shape=[None, 4], name='dontcare_areas')
self.keep_prob = tf.placeholder(tf.float32)
self.layers = dict({'data': self.data, 'im_info': self.im_info, 'gt_boxes': self.gt_boxes,'gt_ishard': self.gt_ishard, 'dontcare_areas': self.dontcare_areas})
self.trainable = trainable
self.setup()

def setup(self):
# 对于文本提议来说,类别为2,一类为为文字部分,另一类为背景
n_classes = cfg.NCLASSES
# anchor的初始尺寸,论文中使用的是16
anchor_scales = cfg.ANCHOR_SCALES
_feat_stride = [16, ]

# base net is vgg16
# 内部使用的函数
(self.feed('data')
.conv(3, 3, 64, 1, 1, name='conv1_1')
.conv(3, 3, 64, 1, 1, name='conv1_2')
.max_pool(2, 2, 2, 2, padding='VALID', name='pool1')
.conv(3, 3, 128, 1, 1, name='conv2_1')
.conv(3, 3, 128, 1, 1, name='conv2_2')
.max_pool(2, 2, 2, 2, padding='VALID', name='pool2')
.conv(3, 3, 256, 1, 1, name='conv3_1')
.conv(3, 3, 256, 1, 1, name='conv3_2')
.conv(3, 3, 256, 1, 1, name='conv3_3')
.max_pool(2, 2, 2, 2, padding='VALID', name='pool3')
.conv(3, 3, 512, 1, 1, name='conv4_1')
.conv(3, 3, 512, 1, 1, name='conv4_2')
.conv(3, 3, 512, 1, 1, name='conv4_3')
.max_pool(2, 2, 2, 2, padding='VALID', name='pool4')
.conv(3, 3, 512, 1, 1, name='conv5_1')
.conv(3, 3, 512, 1, 1, name='conv5_2')
.conv(3, 3, 512, 1, 1, name='conv5_3'))
# RPN
# 该层对上层的feature map进行卷积,生成512通道的的feature map
(self.feed('conv5_3').conv(3, 3, 512, 1, 1, name='rpn_conv/3x3'))
# 卷积最后一层的的feature_map尺寸为batch*h*w*512

# 原来的单层双向LSTM
(self.feed('rpn_conv/3x3').Bilstm(512, 128, 512, name='lstm_o'))
# bilstm之后输出的尺寸为(N, H, W, 512)

"""
和faster—rcnn相似,在ctpn的rpn网络中,使用双向lstm和全连接得到预测的
目标概率和回归框,在faster-rcnn中使用的是卷积的方式从basenet的最后一层生成
使用LSTM的输出来计算位置偏移和类别概率(判断是否是物体,不判断类别的种类)
输入尺寸为(N, H, W, 512) 输出尺寸(N, H, W, int(d_o))
可以将这一层当做目标检测中的最后一层feature_map
rpn_bbox_pred--对于h*w的尺寸上,每一anchor上生成4个位置偏移量
rpn_cls_score--对于h*w的尺寸上,每一anchor上生成2个置信度得分,判断是否为物体

"""
(self.feed('lstm_o').lstm_fc(512, len(anchor_scales) * 10 * 4, name='rpn_bbox_pred'))
(self.feed('lstm_o').lstm_fc(512, len(anchor_scales) * 10 * 2, name='rpn_cls_score'))

你可能感兴趣的:(CTPN网络理解)