PAN++ 代码讲解训练和预测

1、训练

首先就是配置数据集,这里我用的是ICDAR2015,配置文件使用pan_pp_r18_ic15_joint_train,数据集下载链接https://rrc.cvc.uab.es/?ch=4&com=downloads,将t配置文件修改为如下:

data = dict(
    batch_size=4,
    train=dict(
        type='PAN_PP_IC15',
        split='train',
        is_transform=True,
        img_size=736,
        short_size=736,
        kernel_scale=0.5,
        read_type='pil',
        with_rec=True
    ),
    test=dict(
        type='PAN_PP_IC15',
        split='test',
        short_size=736,
        read_type='pil',
        with_rec=True
    )
)

然后在trian中main配置如下:

parser.add_argument('--config',   default='config/pan_pp/pan_pp_r18_ic15_joint_train.py' ,help='config file path')

然后就可以快乐的训练了。如果需要识别的字典需要在pan_pp_ic15文件中get_vocabulary方法中修改如下所示:

def get_vocabulary(voc_type, EOS='EOS', PADDING='PAD', UNKNOWN='UNK'):
    if voc_type == 'LOWERCASE':
        voc = list(string.digits + string.ascii_lowercase)
    elif voc_type == 'ALLCASES':
        voc = list(string.digits + string.ascii_letters)
    elif voc_type == 'ALLCASES_SYMBOLS':
        voc = list(string.printable[:-6])
    else:
        raise KeyError('voc_type must be one of "LOWERCASE", '
                       '"ALLCASES", "ALLCASES_SYMBOLS"')

    # update the voc with specifical chars
    voc.append(EOS)
    voc.append(PADDING)
    voc.append(UNKNOWN)

    char2id = dict(zip(voc, range(len(voc))))
    id2char = dict(zip(range(len(voc)), voc))

    return voc, char2id, id2char

在数据处理中,首先,训练中的框不能多余200个,以及字符数量不大于32,如果大于则截断。然后创建gt_instance,gt_kernel,training_mask这三个维度和图片一样,gt_instance对标签中的框进行填充操作,gt_kernel对shrink的后的标签中的框进行填充,training_mask对忽略的进行填充对。要注意前面两个是背景为0,填充为1或者比1大的,而training_mask是1,然后填充0。这样主要是为了方便选择忽略的像素点。

然后对图片进行切割以及填充操作,使其变为736,736的长宽,这里需要注意他选取的是label边界作为随机值,所以会导致部分label会被切割,然后为了保证长宽不变使用copyMakeBorder进行填充处理。然后更新gt_kernel,gt_instance以及training_mask。然后记录label的左上和右下。然后对图片做了亮度饱和度,归一化操作。返回的数据如下:

        data = dict(
            imgs=img,#经过数据增强的图片
            gt_texts=gt_text, #与gt_instance区别是没有分类所有框均表示为1
            gt_kernels=gt_kernels, #缩放的核
            training_masks=training_mask, #标记为忽略的####背景的
            gt_instances=gt_instance,#所有框
            gt_bboxes=gt_bboxes,#每个框的左上和右下
        )

为了便于理解以上的,不太过书面化,所以我将这些数据输出。
原图为:


原图

原图经过操作后的变为如下所示:


数据增强的图片

显示所有的框:
gt_texts.jpg

shrink后的kernals的框:
gt_kernels.jpg

在这里我们明显看出kernals是将每个框作为一类。
gt_instance:


gt_instance.jpg

最为明显的就是training_mask,因为他原本就是以1作为背景,0来填充需要忽略的地方,但在旋转和填充时又将它上下边框变为0了:
training_mask.jpg

具体的还是建议看代码。好了数据处理部分就结束了,下面就可以来到令人开心的模型训练部分了。

首先是抽取通用的特征,使用resnet18和FPEM_v2进行卷积,之所以用FPEM_v2是因为resnet18只能抽取弱特征,包含的信息较少,所以需要FPEM进行上采样和下采样进行特征增强。在resnet18输出为184,92,46,23通道数,在FPEM中对184这个通道数进行上采样,做残差相加。重复两次,然后做特征上采样变为1,512,184,184。这样就得到了通用特征卷积了。然后就可以开始训练检测和识别了。

首先检测模块,将512通道的变为6通道的,然后进行上采样变为736。在这6个通道中,第一个通道表示标签的也就是gt_texts,这里需要注意的就是get_text和training_mask做了与操作,也就是将一些标签为###的部分忽略了,然后在与第一个通道做损失函数,然后通过0.5表示为阈值,大于的表示为正样本,小于的表示为负样本,并且保证正负样本的比例在1:3的样子,然后从负样本中选择置信度最高的。然后按照负样本中置信度最低的开始。这样正负样本的数据的over了。然后通过交并比来计算loss,然后求出均值。第二个通道kernals也是通过相同的方法进行计算,区别只是在于计算的是gt_kernels的loss。
重点来了,如何计算实例向量的loss,这也是对于我来说,卡我最久的。怎么计算实例或者什么叫实例呢?首先我们应该明白实例向量是使用后四层进行计算的。这四层首先通过training_mask和instance,以及kernals继续与操作,这样这四层中像素点表示的就是过滤后的kernals,然后通过view操作,将w和h合并为已为的,举个例子之前是,4,736,736,变化之后就变为4,541696。还记得kernals对每一个框的label都是不同的,这样每一个框就是一个实例向量,通过最简单的==,获取4层每一个框对应位置的像素点,然后求均值。如果这样不太懂,你可以认为每一个框都是一个四维向量,每一列表示一个实例向量,然后计算原图的实例对象通过instance,然后进行loss计算,这就是l_agg,如图下所示:



然后计算每一个类别向量之间的距离,l_reg就是通过上述两个实例向量进行线性计算得出的。这样检测部分的训练就结束了。代码如下:



识别就开始了,识别训练第一步就是通过标签进行裁剪这块和两阶段的识别很像,然后通过F.interpolate将其高宽变为8,32的,并获取对应的标签。这样的裁剪部分就结束了。识别这块就是直接使用多头注意力机制,编码解码,这块推荐transfomer这篇论文。然后在loss计算就是使用的cross_entropy,和直接使用==来计算准确率。
然后取均值就可以了。然后就是记录相关信息,保存权重就over了。

2、预测

重点注意在作业的配置文件中也就是pan_pp_r18_ic15_joint_train.py中对忽视的框或者识别的框的置信度都特别的高,我将他们改小了,具体的如下所示:
test_cfg = dict(
    min_score=0.8,
    min_area=260,
    min_kernel_area=2.6,
    scale=4,
    bbox_type='rect',
    result_path='outputs/submit_ic15_rec.zip',
    rec_post_process=dict(
        len_thres=3,
        # score_thres=0.95,
        score_thres=0.1,
        unalpha_score_thres=0.9,
        # unalpha_score_thres=0.9,
        # ignore_score_thres=0.93,
        ignore_score_thres=0.1,
        editDist_thres=2,
        voc_path=None #'./data/ICDAR2015/Challenge4/GenericVocabulary.txt'
    ),
)

配置config配置文件和权重路径

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--config',default='config/pan_pp/pan_pp_r18_ic15_joint_train.py' ,help='config file path')
    parser.add_argument('--checkpoint', default='./checkpoints/checkpoint.pth.tar', nargs='?', type=str)
    parser.add_argument('--report_speed',default=None, action='store_true')
    args = parser.parse_args()
    main(args)

然后test文件会在控制台打印相关model数据,然后会对recognition_head这个模块中添加字典的相关信息,然后加载预训练权重。然后使用fuse_module对参数进行一些
归一化操作,并使用model_structure方法打印模型的相关参数。(PS:个人不太懂为什么要对已经训练好的模型权重参数进行归一化操作)。

在构建Dataset,这里我选的是PAN_PP_IC15类型的数据类型。他会在 getitem方法中判断是训练还是测试,使用对应的数据处理。在测试中比较简单在这里我选择使用’pil‘这种方式读取img路径,img_meta这个变量会将原图片的大小保存。
scale_aligned_short方法会将最短边变为736,并且保证长宽都能被32整除,并保存变换后的图片的长宽,对其进行归一化处理,并将归一化后的图片保存在img_meta变量中,将img_meta返回。
在将样本数据输入网络模型中,由于图片数据只有一边是736所以另一边是可变的,而在训练的时候长宽都是736,所以也可以使用resize修改数据操作,我将图片resize成了长宽都为736的图片后,发现其实比起不变的在性能上改变不大,这个最好自己测一下。
在dataloader得到数据后传入model模型中,首先进入resnet18,输出四个数据,通道数分别为64,128,256,512,由于只有一边是736,一边是可变的,当然也可以像我这样将图片直接resize为736,高宽就变为184,92,46,23。到这里resnet18的工作就结束了。然后将这四个数据分别进行归一化和relu的操作。
FPEM_v2的工作就开始了,他主要是增强reesnet18返回的数据,他将f4也就是1,512,23,23维度的进行上采样,分别变为其他三个通道的的维度,然后又进行下采样并且将高宽通过卷积减半,然后将原始的四个数据和经过上述处理后的数据相加,也就是类似残差相加,输出维度和之前一样,然后在进行一次。这样FPEM_v2的工作也就结束了。现在的数据的维度还是和resnet18输出的一样。
然后通过上采样,将四个数据的维度都变为f1,也就是1,128,184,184,然后进行堆叠,变为1,512,184,184,这样对于检测和识别通用的特征卷据就完成了。

首先是检测,通过det_head模块将通道数变为6。
这6个通道数中,第一个通道数表示为置信度,通过torch.sigmoid表示socre。第二层表示为text_mask通过对0的二值化,而kernals是用第二层通道数二值化并且乘以text_mask,然后float(俺也很懵逼,不知道为啥怎么写)。后面四层表示为实例向量,并且与text_mask相乘。然后通过pa这个方法就是对后面四层中的框进行膨胀操作尽量预测正确的label。然后将score,label上采样变为736。
然后获取label中的最大数也就是,当前预测了几个框。然后计算每个框的面积使其不小于要求的最小面积,小于则直接忽略。然后通过score获取置信度,如果小于也直接忽略。然后剩下的就是没问题的框了,然后计算左上和右下两个像素点。然后判断你选择的是rect还是多poly,然后分别进行操作并还原图片的大小。然后返回预测框bboxes_h表示左上右下的像素点位置,instances表示第几个框,以及label,score和bboxes。这样检测的任务就结束了,代码如下:

        results['bboxes'] = bboxes
        results['scores'] = scores
        if with_rec:
            results['label'] = label
            results['bboxes_h'] = bboxes_h
            results['instances'] = instances
        return results

然后将数据送入识别模块中。首先提取特征,将f变为736,对bboxes添加padding,获取框的个数。然后开始迭代,通过instance变量获取在识别模块中检测出来label特征图按照右上左下获取图片中对应label的像素点位置,然后在f中选出像素点。这样就是通过识别后的label来应用在f通用特征卷积上,这样就获取到了识别需要的像素点了,你可以理解为训练的裁剪步骤。然后将他的w和h变为8和32,然后就通过多头注意力机制的编码和解码得到结果。预测就结束了。

你可能感兴趣的:(PAN++ 代码讲解训练和预测)