PSENet笔记(一)

之前关注过文本检测,只了解到CTPN,现在开始看PSENet(Shape Robust Text Detection with Progressive Scale Expansion Network)

参考博客:https://mp.weixin.qq.com/s/-zMVO47AL1iKFmF16KsfOw

PSENet文本检测算法来自论文《Shape Robust Text Detection with Progressive Scale Expansion Network》,2018年7月发表于arxiv,已被CVPR 2019 接收。

参考博客:tensorflow_PSENet训练和测试

解决一些使用python2与python3不同造成的函数或模块找不到问题,以及模型路径问题。

参考:tensorflow版PSENet训练自己的数据及测试进行OCR文本检测,Linux和Windows详细复现过程

解决windows上pse.cpp的编译问题。

这样就可以在windows上跑通 tensorflow-PSENet的测试了。

我整理好了一份放在百度云盘上。  提取码:f55e

关于训练,有时间再详细笔记。

关于训练和测试,由于我比较懒就写了一个脚本,每次想执行的时候双击一下就可以执行了。

测试脚本:

python eval.py --test_data_path=data\images\ --gpu_list=0 --checkpoint_path=model\ --output_dir=data\result
pause

训练脚本:

python train.py --gpu_list=0 --input_size=512 --batch_size_per_gpu=8 --checkpoint_path=./train_model/ --training_data_path=./data/icdar2015/
pause

源码解读:

复现的论文神经网络部分使用的是tensorflow,广度优先搜索部分使用C++实现。
先从train.py开始看,103行定义了损失函数,在tower_loss函数中,构建了模型。模型的输出seg_maps是一个6通道的tensor,对应了论文中segmentation result。在train.py中没有引用到pse,pse在训练的过程中没有用到。
预测的过程在eval.py中。

def detect(seg_maps, timer, image_w, image_h, min_area_thresh=10, seg_map_thresh=0.9, ratio = 1):

其中,min_area_thresh是一个连通分量中至少有10个像素,seg_map_thresh是在返回的seg_map中,0.9以下的被变成0,以上的变为1,以此将seg_map变成二值图。在论文中出现的kernel就是图中变成1的部分。在detect函数中,调用了pse,这部分是使用C++实现的。python中调用C++使用了pybind11。

编译pse

pse文件夹中包含了pse的实现。其中include目录是pybind11的开源代码。广度优先的过程都在pse.cpp中。
在pse/__init__py中队pse.cpp进行了编译。
pybind11中,规定PYBIND!!_MODULE作为一个接口,写在C++文件中,编译的时候会将函数与python中的函数绑定。

PYBIND11_MODULE(pse, m){
    m.def("pse_cpp", &pse::pse, " re-implementation pse algorithm(cpp)", py::arg("label_map"), py::arg("Sn"), py::arg("c")=6);
}

第一个pse_cpp是python中绑定的函数名,第二个&pse::pse是在C++文件中待绑定的函数py::arg声明了参数以及默认值。在pse::pse中实现了一个广度优先搜索。

__init__.py中,对pse进行了一次封装。

label_num, label = cv2.connectedComponents(kernals[kernal_num - 1].astype(np.uint8), connectivity=4)

cv2.connectedComponents对模型求出来的最后一个kernel求了一次连通分量。label_num是图中连通分量的个数,label是带有标签的图,如果在连通分量里面,那个像素的值就是对应的连通分量编号,否则就是0。
接下来的for循环将小于10个像素的连通分量删除。

加载数据

在train.py中,数据生成使用的是:

data_generator = data_provider.get_batch(num_workers=FLAGS.num_readers,
                                                 input_size=FLAGS.input_size,
                                                 batch_size=FLAGS.batch_size_per_gpu * len(gpus))

get_batch中,类GeneratorEnqueuer使用的数据生成器是generator(**kargs),生成器的返回值有images, image_fns, seg_maps, training_masks,其中有用的是images, seg_maps, training_masks,image_fns是文件名,所以是没用的。对应的就是下面有data[0],data[2],data[3]没有data[1]。

 ml, tl, _ = sess.run([model_loss, total_loss, train_op], feed_dict={input_images: data[0],
                                                                                input_seg_maps: data[2],
                                                                                input_training_masks: data[3]})

读取标注的函数是data_provider.py中的load_annotation,在第280行调用了这个函数,这个函数就是读取标记用的,如果要兼容别的数据集需要修改这个函数。返回值text_polys是一个三维数组

PSENet笔记(一)_第1张图片

其中,每一层保存了一个多边形,由于数据集中只支持矩形,所以就只有四个点。第三个维度表示一张图片中存在多个文字块。text_tags是一个布尔型的数组。

text_polys, text_tags = load_annoataion(txt_fn)

在load_anotation之后调用check_and_validate_polys对text_polys和text_tags进行矫正。在这个函数中pyclipper.Area计算多边形内的面积,如果面积小于1,则舍去。使用pyclipper.Orientation使点的方向变成顺时针。
然后对图片进行随机放缩

im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)

标签中的最后数据仅用作是否是文本行的判断依据

            if label == '*' or label == '###' or label == '?':
                text_tags.append(True)
            else:
                text_tags.append(False)

crop_area会随机选择一块区域,如果有文字则为样本
否则为背景。
然后,generate_seg会将文字的矩形区域缩放成6种不同的大小,供金字塔结构使用

seg_map_per_image, training_mask = generate_seg((new_h, new_w), text_polys, text_tags,
                                                                     image_list[i], scale_ratio)

在generage_seg函数中,调用了函数shrink_poly这个函数用来将ground_truth进行不同比例的缩小(论文的3.3节label generation)

# seg map
            shrinked_polys = []
            if poly_idx not in ignore_poly_mark:
                shrinked_polys = shrink_poly(poly.copy(), scale_ratio[i])

模型实现:

在model.py中,首先建立金字塔特征:

feature_pyramid = build_feature_pyramid(end_points, weight_decay=weight_decay)

其中endpoints是resnet中几个特征图。
然后讲feature_pyramid进行concat,由于每一层的feature_pyramid的大小不一定一样,所以需要先进行缩放(unpool函数)
然后经过两个卷积层,得到seg_S_pred。

关于参数传递:

下面是取自 eval.py中的代码

tf.app.flags.DEFINE_string('test_data_path', None, '')
tf.app.flags.DEFINE_string('gpu_list', '0', '')
tf.app.flags.DEFINE_string('checkpoint_path', './', '')
tf.app.flags.DEFINE_string('output_dir', './results/', '')
tf.app.flags.DEFINE_bool('no_write_images', False, 'do not write images')

参考:tensorflow命令行参数

“DEFINE_xxx”函数带3个参数,分别是变量名称,默认值,用法描述

参考上面我自己写的测试脚本,可知,当有对应参数传递时,默认参数就会被覆盖。

 

我参考的是:https://github.com/liuheng92/tensorflow_PSENet

其中还有一些代码不太理解:

标签:

 if label == '*' or label == '###' or label == '?':
                text_tags.append(True)
            else:
                text_tags.append(False)

这里是说只有标签行以 “*” 或者“###”或者“?”为最后一个元素时才标记为True,那么对icdar2015标签数据中很多是以数字或者词组为标签数据的最后一个元素的,这个代码会将这些标记为False。我的认知是,有任何文本都应该标注为True,没有文本标注为False。当然上面这样写肯定是有道理的,毕竟我代码都没有看全,作者勿怪。

后来参考:https://github.com/whai362/PSENet/blob/master/dataset/icdar2015_loader.py

其中关于标签的设置比较符合我的认知。

        if gt[-1][0] == '#':
            tags.append(False)
        else:
            tags.append(True)

即标签行最后一个元素如果是以‘#’开头就认为该行没有文本,否则为文本行,即使这样也会有文本行以‘#’开头的,不过这种情况比较少见。

 

参考:PSENet源码阅读笔记

 

你可能感兴趣的:(文本检测)