PSENet-pytorch源码精读(一)config.py

PSENet-pytorch源码精读

  • PSENet的网络结构

PSENet的网络结构

PSENet是基于FPN的基础上实现的,使用ResNet作为backbone网络。
PSENet-pytorch源码精读(一)config.py_第1张图片
首先,从backbone中获取四个256通道的特征图( P 2 P_2 P2 P 3 P_3 P3 P 4 P_4 P4 P 5 P_5 P5)。为了进一步组合从低到高的语义特征,通过函数C来融合四个特征图,以获取具有1024个通道的特征图F,如下:
F = C ( P 2 , P 3 , P 4 , P 5 ) = P 2 ∣ ∣ U p × 2 ∣ ∣ U p × 4 ∣ ∣ U p × 8 F = C(P_2,P_3,P_4,P_5)= P_2||U_{p\times2}||U_{p\times4}|| U_{p\times8} F=C(P2,P3,P4,P5)=P2Up×2Up×4Up×8

其中“||”表示concatenation, U p × 2 U_{p\times2} Up×2 U p × 4 U_{p\times4} Up×4 U p × 8 U_{p\times8} Up×8分别表示2、4、8倍下采样。随后,将F总到Conv(3,3)-BN-ReLU层,将通道数由原来的1024减少到256通道。接下来,通过n个Conv(1,1)-up-Sigmoid层,产生n个分割结果 S 1 S_1 S1 S 2 S_2 S2、…、 S n S_n Sn。Conv,BN,ReLU和Up分别是指卷积[18],批归一化[15],整流线性单位[6]和上采样。

PSENet-pytorch源码精读(一)config.py_第2张图片
上图是PSE算法的流程图,其主要思想是来源于广度优先搜索(BFS)算法。假设我们有3个分割的结果 S = S 1 , S 2 , S 3 S={S_1,S_2,S_3} S=S1S2S3。首先,基于最小kernel的图S1(Figure 4(a)),4个不同的连接分量 C = c 1 , c 2 , c 3 , c 4 C={c_1,c_2,c_3,c_4} C=c1,c2,c3,c4作为初始化。Figure 4(b)中具有不同颜色的区域分别表示这些不同的连接组件。 到目前为止,我们已经检测到所有文本实例的中心部分(即最小内核)。 然后,我们通过合并 S 2 S_2 S2 S 3 S_3 S3中的像素来逐步扩展检测到的内核。 两次缩放的结果分别显示在Figure4( c c c)和Figure 4(d)中。 最后,我们提取Figure 4(d)中用不同颜色标记的连接组件作为文本实例的。

详细介绍看参考我的另一篇博客:文本检测算法:PSENet:Shape Robust Text Detection with Progressive Scale Expansion Network.

下面,我打算对PSENet的源码进行介绍:PSENet.pytorch.
使用的是win10+python37+pytorch1.x

工程代码结构如下:
PSENet-pytorch源码精读(一)config.py_第3张图片

1、config.py设定一些训练的参数

# config.py
# -*- coding: utf-8 -*-
# @Time    : 2019/1/3 17:40
# @Author  : zhoujun

# data config
# 训练集和测试集
#trainroot = '/dataset/icdar2015/train' #'/data2/dataset/ICD15/train'
trainroot = r'E:\ZhuoZhuangOCR\Paper\Latest\PSENet\PSENet-pytorch\dataset\icdar2015\train'
#testroot =  '/dataset/icdar2015/test' #'/data2/dataset/ICD15/test'
testroot = r'E:\ZhuoZhuangOCR\Paper\Latest\PSENet\PSENet-pytorch\dataset\icdar2015\test'
output_dir = 'output/psenet_icd2015_resnet152_4gpu_author_crop_adam_MultiStepLR_authorloss'
data_shape = 640

# train config
# 训练参数
gpu_id = '0'
workers = 12
start_epoch = 0
epochs = 600

# 8G运行内存,跑不了4个batch_size
train_batch_size = 2#4

# 学习率调节
lr = 1e-4
end_lr = 1e-7
lr_gamma = 0.1
lr_decay_step = [200,400]
weight_decay = 5e-4
warm_up_epoch = 6
warm_up_lr = lr * lr_gamma

display_input_images = True#False
display_output_images = False
display_interval = 10
show_images_interval = 50

pretrained = True
restart_training = True
checkpoint = ''

# net config
backbone = 'resnet152'
Lambda = 0.7
n = 6
m = 0.5
OHEM_ratio = 3
scale = 1
# random seed
seed = 2


def print():
    from pprint import pformat
    tem_d = {}
    for k, v in globals().items():
        if not k.startswith('_') and not callable(v):
            tem_d[k] = v
    return pformat(tem_d)

你可能感兴趣的:(文本检测)