软件工程应用与实践(五):Paddle OCR文字识别器策略三

2021SC@SDUSC

 

目录

一、前情回顾

1.PP-OCR文字识别策略

2.本文策略简介

为什么要数据增强

数据增强测试流程

图像领域的数据增强

PP—OCR数据增强策略简介

二、BDA和TLA介绍

1.数据增强方法的总览介绍 

2.TLA相关

传统数据增强方法的不足

TLA算法的优点

TLA算法介绍

3.TLA的实现与结果

总体框架

增强模块

联合训练方案

三、主要代码分析

1.图像文字处理

2.文字识别处理

总结



一、前情回顾

1.PP-OCR文字识别策略

 策略的选用主要是用来增强模型能力和减少模型大小。下面是PP-OCR文字识别器所采用的九种策略:

  • 轻主干,选用采用 MobileNetV3 large x0.5 来权衡精度和效率;
  • 数据增强,BDA (Base Dataaugmented)和TIA (Luo et al. 2020);
  • 余弦学习率衰减,有效提高模型的文本识别能力;
  • 特征图辨析,适应多语言识别,进行向下采样 feature map的步幅修改;
  • 正则化参数,权值衰减避免过拟合;
  • 学习率预热,同样有效;
  • 轻头部,采用全连接层将序列特征编码为预测字符,减小模型大小;
  • 预训练模型,是在 ImageNet 这样的大数据集上训练的,可以达到更快的收敛和更好的精度;
  • PACT量化,略过 LSTM 层;

2.本文策略简介

  • 数据增强,BDA (Base Data Augmented)和TIA (Luo et al. 2020)

为什么要数据增强

  • 数据是机器学习的原材料,而大部分机器学习任务都是有监督任务,所以非常依赖训练数据,比如某个数据属于某一类是由于某种特征,通过这个过程,最终收获一个能预测一些规律的模型,从而使用这个模型去做一些预测。因此想要让模型有更好的效果,需要更大、质量更好的数据,当只有少类样本的时候,就需要数据增强来提高数据量。
  • 单纯使用人工标注数据费时费力,而且当需要标注大量数据时,很多时候不可避免就会因为各种人为因素导致标注错误,从而使数据质量降低
  • 单纯使用人工标注数据费时费力,而且当需要标注大量数据时,很多时候不可避免就会因为各种人为因素导致标注错误,从而使数据质量降低

数据增强测试流程

软件工程应用与实践(五):Paddle OCR文字识别器策略三_第1张图片

图像领域的数据增强

  通过对原始图片进行平移、旋转、裁剪、遮挡、反转、放缩、灰度等处理,保证原始图片类别不变的前提下,生成大量数据。除了这些简单的方法,还有很多机器学习的方法,比如对抗网络GAN的方式生成很多仿真的图片。

PP—OCR数据增强策略简介

       图像文字识别分为两种不同场景,文档文本和场景文本。场景文本是指如图所示的自然场景中的文本,它通常会因为一些因素而发生巨大的变化,如视角、 缩放、弯曲、混乱、字体、多语言、模糊、光照等。文档文本在实际应用中更常见。但存在高密度、长文本等不同的问题需要解决,文档图像文本识别往往需要对结果进行结构化处理。同时,手写文本识别因为手写风格过多、手写标注文本图像采集成本高等,对图像文本识别同样是很大的挑战 。

  Paddle OCR为了实现更加精确的识别,采用数据增强和强化网络相结合的手法。除了传统的旋转缩放和透视等增强方法,PP-OCR在数据增强和网络优化的孤立过程之间架起了桥梁,为识别网络生成更加适合训练的训练样本。除了本文将从技术介绍和代码解读两个方面来介绍BDA和TLA。



二、BDA和TLA介绍

1.数据增强方法的总览介绍 

       在图像分类任务中,数据增强是一种常用的正则化方法,同时在文字识别等方面已成为提升模型性能的必须步骤。从AlexNet到EfficientNet都可以看到数据增强的身影。数据增强的方法也由传统的裁剪、旋转、镜像等方式逐渐过渡到当前火热的AutoAug、RandAug等基于NAS搜索的方法。

       下面将数据增强方法进行简单的分类:

  • 标准数据增广:泛指深度学习前期或更早期的一些常用数据增广方法; 
  • 图像变换类:泛指基于NAS搜索到的一组变换组合,包含AutoAugment、RandAugment、Fast AutoAugment、Faster AutoAugment、Greedy Augment等;

  • 图像裁剪类:泛指深度学习时代提出的一些类似dropout的数据增广方法,包含CutOut、RandErasing、HideAndSeek、GridMask等;

  • 图像混叠类:泛指在batch层面进行的操作,包含Mixup、Cutmix、Fmix等

       现有的几何数据增强包含平移、旋转、裁剪、遮挡、反转、放缩、灰度等处理。PP—OCR主要采用了Distort特效、光学拉伸成像、perspective属性3D变形等来对输入的图像进行处理以达到数据增强的效果。 具体包含的处理方式有双线性插值(Bilinear interpolation)、添加高斯噪声、图像数据增强防抖 jitter、等分切割(crop()函数的运用)等。


2.TLA相关

传统数据增强方法的不足

  传统的数据增强方法,每个样本的随机增广策略相同,忽略了样本之间差异和网络的优化过程,在人工控制的静态样本下,增广可能产生很多对训练无用的样本,很难满足动态优化的要求。TLA是一种可学习的增强方法,能够自适应任务。

  常见的几何标记,如反转、旋转、缩放、透视等,通常对单个物体识别有用,然而文本图像包含多个字符,对文本多样性没有显著贡献。

TLA算法的优点

  • 是第一个专门为序列类字符设计的增强方法;
  • 联合优化数据增强和识别模型的框架,增强样本是通过自动学习产生的,对训练模型更加有效,提出的框架是端到端的,无需任何微调;

  各种场景文本和手写文本的基准上进行的大量实验表明。增强和联合学习方法显著提高了识别器性能。

TLA算法介绍

场景文本识别技术部分:

  场景文本图像中有多个字符,比单个字符识别更困难。场景文本识别方法可以分为以下两种类型:基于定位和无分割:
  前者尝试定位字符的位置,识别它们并将所有字符分组为文本字符串。 后者得益于深度神经网络的成功,将文本识别建模为序列识别问题。在卷积神经网络(CNNs)的基础上应用了递归神经网络(RNNs)来处理序列类对象的空间依赖性。此外,通过注意机制解决了序列到序列的映射问题。

  对不规则文本的识别,TLA提出了校正网络来消除失真,降低识别难度。迭代去除透视失真和文本线曲率。通过对每个字符使用更多的几何约束和监督,给出了准确的文本形状描述。但不规则场景文本识别仍然是一个具有挑战性的问题。

手写文本识别技术部分:

  早期的方法使用混合隐马尔可夫模型,将单词图像和文本字符串嵌入到公共向量子空间中,将识别任务转换为最近邻问题。
  在深度学习时代,先使用cnn再使用rmn提取特征,取得了较好的效果。 通过提出个序列到序列的域自适应网络来解决笔迹风格的多样性问题。对抗地扭曲了中间特征空间,以缓解一些稀疏训练数据集缺乏变化的问题。因为书写风格的多样性,手写文本的识别仍然是具有挑战性的问题。

总体:

  首先, 可学习代理预测移动状态的分布,以创建更难的训练样本。然后,增强模块根据随机运动状态和预测运动状态分别生成增强样本。对样本的识别难度由识别网络来衡量。最后,agent将难度增加的移动状态作为引导并更新自身。如下图所示:

软件工程应用与实践(五):Paddle OCR文字识别器策略三_第2张图片

 

3.TLA的实现与结果

总体框架

  如上图所示,所提出的框架由三个主要模块组成:agent网络、增强模块和识别网络。

  首先,我们在图像上初始化一组自定义基准点。 将智能体网络预测的运动状态和随机生成的运动状态反馈给增强模块。移动状态表示一组自定义基准点的移动。然后增强模块以图像为输入,分别进行基于运动状态的变换。识别器预测增强图像上的文本字符串。最后,测量识别度在编辑距离的度量下,增强图像的难度。智能体从增加难度的移动状态中学习,探索识别器的弱点。
 


增强模块

  文本增加过程

  将图像分成3个patch (N = 3),移动半径限制在10(R = 10)。红点表示控制点

软件工程应用与实践(五):Paddle OCR文字识别器策略三_第3张图片

 

 

  弹性变换

  软件工程应用与实践(五):Paddle OCR文字识别器策略三_第4张图片

 

  弹性(相似)和刚性转换的比较。所有图像上基准点的运动是相同的。刚性转换保持相对形状(对于一般对象是真实的),但是文本图像增强需要对每个字符进行更灵活的变形。因此,弹性(相似度)变换更适合于文本图像增强。

联合训练方案

  agent网络的学习方案如下图所示。

  首先,可学习代理预测一个移动状态分布,旨在创建一个更难的训练样本。随机运动状态也被馈送到增益模组。然后增强模块根据两种运动状态分别生成增强样本。然后, 识别网络将扩增样本作为输入,预测文本字符串。这对样本的难度是通过地面真实值和预测文本字符串之间的编辑距离来衡量的。最后,agent以难度增加的移动状态为导向,进行自我更新。统一的框架是端到端可培训的。

软件工程应用与实践(五):Paddle OCR文字识别器策略三_第5张图片

 

 

三、主要代码分析

1.图像文字处理

代码位置:

主要代码段(三大操作:distort,stretch,perspective):

import numpy as np
from .warp_mls import WarpMLS


def tia_distort(src, segment=4):
    img_h, img_w = src.shape[:2]

    cut = img_w // segment
    thresh = cut // 3

    src_pts = list()
    dst_pts = list()

    src_pts.append([0, 0])
    src_pts.append([img_w, 0])
    src_pts.append([img_w, img_h])
    src_pts.append([0, img_h])

    dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)])
    dst_pts.append(
        [img_w - np.random.randint(thresh), np.random.randint(thresh)])
    dst_pts.append(
        [img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)])
    dst_pts.append(
        [np.random.randint(thresh), img_h - np.random.randint(thresh)])

    half_thresh = thresh * 0.5

    for cut_idx in np.arange(1, segment, 1):
        src_pts.append([cut * cut_idx, 0])
        src_pts.append([cut * cut_idx, img_h])
        dst_pts.append([
            cut * cut_idx + np.random.randint(thresh) - half_thresh,
            np.random.randint(thresh) - half_thresh
        ])
        dst_pts.append([
            cut * cut_idx + np.random.randint(thresh) - half_thresh,
            img_h + np.random.randint(thresh) - half_thresh
        ])

    trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
    dst = trans.generate()

    return dst


def tia_stretch(src, segment=4):
    img_h, img_w = src.shape[:2]

    cut = img_w // segment
    thresh = cut * 4 // 5

    src_pts = list()
    dst_pts = list()

    src_pts.append([0, 0])
    src_pts.append([img_w, 0])
    src_pts.append([img_w, img_h])
    src_pts.append([0, img_h])

    dst_pts.append([0, 0])
    dst_pts.append([img_w, 0])
    dst_pts.append([img_w, img_h])
    dst_pts.append([0, img_h])

    half_thresh = thresh * 0.5

    for cut_idx in np.arange(1, segment, 1):
        move = np.random.randint(thresh) - half_thresh
        src_pts.append([cut * cut_idx, 0])
        src_pts.append([cut * cut_idx, img_h])
        dst_pts.append([cut * cut_idx + move, 0])
        dst_pts.append([cut * cut_idx + move, img_h])

    trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
    dst = trans.generate()

    return dst


def tia_perspective(src):
    img_h, img_w = src.shape[:2]

    thresh = img_h // 2

    src_pts = list()
    dst_pts = list()

    src_pts.append([0, 0])
    src_pts.append([img_w, 0])
    src_pts.append([img_w, img_h])
    src_pts.append([0, img_h])

    dst_pts.append([0, np.random.randint(thresh)])
    dst_pts.append([img_w, np.random.randint(thresh)])
    dst_pts.append([img_w, img_h - np.random.randint(thresh)])
    dst_pts.append([0, img_h - np.random.randint(thresh)])

    trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
    dst = trans.generate()

    return dst

2.文字识别处理

代码位置:

 主要代码段:

class RecAug(object):
    #……
        return data

class ClsResizeImg(object):
    #……
        return data


class RecResizeImg(object):
   #……
        return data


class SRNRecResizeImg(object):
    #……
        return data

#各种图像操作函数
def resize_norm_img(img, image_shape):
  

def resize_norm_img_chinese(img, image_shape):
   

def resize_norm_img_srn(img, image_shape):


def srn_other_inputs(image_shape, num_heads, max_text_length):
 

def flag():
    

def cvtColor(img):


def blur(img):


def jitter(img):


def add_gasuss_noise(image, mean=0, var=0.1):
    


def get_crop(image):


#配置类
class Config:
    """
    Config
    """

#各种配置函数
def rad(x):
   

def get_warpR(config):
    

def get_warpAffine(config):
    

def warp(img, ang, use_tia=True, prob=0.4):
    




总结

  以上是今天PP-OCR文字识别模型的数据增强策略相关介绍。之后将会继续介绍PP-OCR文字识别模型的其他策略

你可能感兴趣的:(ar,paddlepaddle,深度学习)