TrOCR:基于Transformer的OCR介绍和使用

点击上方“AI公园”,关注公众号,选择加“星标“或“置顶”


作者:Sovit Rath

编译:ronghuaiyang

导读

本文介绍了TrOCR的结构和使用方法,手把手从每一行代码教起。

光学字符识别(OCR)在过去几年中出现了一些创新。它对零售、医疗、银行和许多其他行业的影响是巨大的。尽管有着悠久的历史和一些最先进的模型,研究人员仍在不断创新。与深度学习的许多其他领域一样,OCR也看到了transformer 经网络的重要性和影响。今天,我们有像TrOCR(Transformer OCR)这样的模型,它们在准确性方面确实超过了以前的技术。

TrOCR:基于Transformer的OCR介绍和使用_第1张图片

在本文中,我们将介绍TrOCR,并重点讨论四个主题:

  • TrOCR的结构是什么

  • TrOCR系列包括哪些类型

  • 如何对TrOCR模型进行预训练

  • 如何使用TrOCR和Hugging Face进行推理

TrOCR结构

TrOCR由李等人在论文TrOCR:Transformer-based Optical Character Recognition with Pre-trained Models中介绍。

作者提出了一种背离传统的CNN和RNN的方法,他们使用视觉和语言transformer 模型来构建TrOCR架构。

TrOCR模型由两个阶段组成:

  • 编码器阶段由预训练的视觉transformer 模型组成。

  • 解码器阶段由预训练的语言transformer 模型组成。

由于其高效的预训练,基于transformer的模型在下游任务中表现得非常好。因此,作者选择了DeIT作为视觉转换器模型。对于解码器阶段,他们选择了RoBERTa或UniLM模型,这取决于TrOCR变体。

下图显示了使用TrOCR的简单OCR pipeline。

TrOCR:基于Transformer的OCR介绍和使用_第2张图片

图1,TrOCR 结构

在上图中,左块显示视觉transformer 编码器,右块显示语言transformer 解码器。以下是TrOCR推理阶段的简单分解:

  • 首先,我们将图像输入到TrOCR模型,该模型通过图像编码器。

  • 图像被分解成小块,然后通过多头注意力块。前馈块产生图像嵌入。

  • 然后这些嵌入进入语言transformer 模型。

  • 语言transformer 模型的解码器部分产生编码输出。

  • 最后,我们对编码输出进行解码,以获得图像中的文本。

需要注意的一点是,在进入视觉transformer模型之前,图像的大小调整为384×384分辨率。这是因为DeIT模型期望图像具有特定的大小。

TrOCR家族的模型

TrOCR 家族模型包括几个预训练和微调模型。

TrOCR 预训练模型

TrOCR家族中的预训练模型叫做第一阶段模型,这些模型在大量的生成数据上进行训练。数据集中包括百万张打印的文本行图像。

官方的代码仓库中包含了3个不同大小的预训练模型:

  • TrOCR-Small-Stage1

  • TrOCR-Base-Stage1

  • TrOCR-Large-Stage1

越大的模型效果越好,但是越慢。

TrOCR 微调模型

在预训练步骤之后,模型在IAM手写数据文本图像和SROIE打印收据数据集上进行微调。

IAM手写数据集包含了手写文本,在这个数据集上进行微调使得这个模型在手写文本的效果上好于其他模型。

类似的,SROIE数据集包含了几千个收据图像样本,微调之后,在打印文本上的效果会表现很好。

和预训练步骤的模型一样,手写和打印模型也包含了3个不同大小的模型:

  • TrOCR-Small-IAM

  • TrOCR-Base-IAM

  • TrOCR-Large-IAM

  • TrOCR-Small-SROIE

  • TrOCR-Base-SROIE

  • TrOCR-Large-SROIE

使用TrOCR和HuggingFace进行推理

Hugging Face上有TrOCR的所有模型,包括预训练步骤和微调步骤的。

我们会使用2个模型,一个手写微调模型,一个打印微调模型,来进行推理实验。

在Hugging Face上,模型的命名遵循trocr--规则。

举例说明,在IAM手写数据集训练的TrOCR的小模型叫做trocr-small-handwritten

我们使用trocr-small-printedtrocr-base-handwritten来进行推理。

安装依赖,导入,设置计算设备

首先要安装一些库:Hugging Face transformers, sentencepiece tokenizer.

!pip install -q transformers
!pip install -q -U sentencepiece

然后是下面的导入语句:

from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
from tqdm.auto import tqdm
from urllib.request import urlretrieve
from zipfile import ZipFile
 
 
import numpy as np
import matplotlib.pyplot as plt
import torch
import os
import glob

我们需要用到urllibzipfile 来解压推理数据。

前向过程使用GPU和CPU都可以。

device = torch.device('cuda:0' if torch.cuda.is_available else 'cpu')

帮助函数Helper Functions

下面的函数是如何下载和解压数据集。

def download_and_unzip(url, save_path):
    print(f"Downloading and extracting assets....", end="")
 
 
    # Downloading zip file using urllib package.
    urlretrieve(url, save_path)
 
 
    try:
        # Extracting zip file using the zipfile package.
        with ZipFile(save_path) as z:
            # Extract ZIP file contents in the same directory.
            z.extractall(os.path.split(save_path)[0])
 
 
        print("Done")
 
 
    except Exception as e:
        print("\nInvalid file.", e)
 
URL = r"https://www.dropbox.com/scl/fi/jz74me0vc118akmv5nuzy/images.zip?rlkey=54flzvhh9xxh45czb1c8n3fp3&dl=1"
asset_zip_path = os.path.join(os.getcwd(), "images.zip")
# Download if assest ZIP does not exists.
if not os.path.exists(asset_zip_path):
    download_and_unzip(URL, asset_zip_path)

上面的代码下载的图像包括:

  • 来自旧报纸的打印文本数据,用来跑打印模型。

  • 手写文本图像用来跑手写文本微调模型。

  • 将文本图像进行扭曲,用来测试 TrOCR model的极限情况。

接下来,我们用一个简单的函数来读取图像。

def read_image(image_path):
    """
    :param image_path: String, path to the input image.
 
 
    Returns:
        image: PIL Image.
    """
    image = Image.open(image_path).convert('RGB')
    return image

read_image() 函数的参数为图像路径,返回RGB格式的图像。

我们还写了一个函数来实现OCR的pipeline。

def ocr(image, processor, model):
    """
    :param image: PIL Image.
    :param processor: Huggingface OCR processor.
    :param model: Huggingface OCR model.
 
 
    Returns:
        generated_text: the OCR'd text string.
    """
    # We can directly perform OCR on cropped images.
    pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text

ocr()这个函数需要下面几个参数:

  • image: RGB格式的PIL image数据。

  • processor: Hugging Face OCR pipeline需要先将图像转换为需要的格式。

  • model: 这个是Hugging Face OCR模型,接受预处理之后的图,给出解码输出。

在返回语句之前,有一个batch_decode() 函数,这个实际上就是将生成的IDs转换为输出文本,skip_special_tokens=True表示我们在输出中不需要特殊tokens,比如结束符和开始符。

最后的这个函数用来运行推理新的图像,包括了之前的函数,并显示了输出结果。

def eval_new_data(data_path=None, num_samples=4, model=None):
    image_paths = glob.glob(data_path)
    for i, image_path in tqdm(enumerate(image_paths), total=len(image_paths)):
        if i == num_samples:
            break
        image = read_image(image_path)
        text = ocr(image, processor, model)
        plt.figure(figsize=(7, 4))
        plt.imshow(image)
        plt.title(text)
        plt.axis('off')
        plt.show()

eval_new_data() 这个函数的参数为文件夹路径,样本数量,以及模型。

在打印文本上进行推理

我们加载TrOCR processor和模型来进行打印文本识别。

processor = TrOCRProcessor.from_pretrained('microsoft/trocr-small-printed')
model = VisionEncoderDecoderModel.from_pretrained(
    'microsoft/trocr-small-printed'
).to(device)

要加载TrOCR processor,我们需要使用from_pretrained模块,该模块接收HuggingFace的仓库路径,包含特定的模块。

TrOCR Processor做了哪些事情?

TrOCR模型是一个神经网络,不能直接处理图像,我们需要先将图像处理成合适的格式。TrOCR processor首先将图像缩放到 384×384 的分辨率,然后转换为归一化的tensor格式,然后再进行模型的推理。我们还可以指定tensort的格式,比如,我们转化为pt格式,表示是pytorch的tensor,我们还可以得到TensorFlow的格式。

同样的,我们使用VisionEncoderDecoderModel类加载预训练模型,在上面的代码中,我们加载了trocr-small-printed 模型,并将其加载到设备中。然后,我们调用eval_new_data()函数开始推理。

eval_new_data(
    data_path=os.path.join('images', 'newspaper', '*'),
    num_samples=2,
    model=model
)

运行上面的代码可以得到下面的输出:

c4c30202a0d260095048dbda585f0753.png

图2,TrOCR 在打印文本上的输出,有日期和数字

a83e1d083da31df0f71ab16ee9bef6e0.png

图3,TrCOR在模糊打印文本上的输出

图像上的文本表示了模型的输出,模型在模糊图像上的表现也很好,在第一张图像上,模型可以预测出所有的标点符号,空格,甚至是破折号。

在手写文本上进行推理

对于手写文本的推理,我们使用基础模型(比小模型大),我们首先加载手写TrOCR processor和模型。

processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
model = VisionEncoderDecoderModel.from_pretrained(
    'microsoft/trocr-base-handwritten'
).to(device)

我们的方法和打印文本模型一样,只是把仓库地址该成需要的模型。

在运行推理时,我们需要改变数据路径。

eval_new_data(
    data_path=os.path.join('images', 'handwritten', '*'),
    num_samples=2,
    model=model
)

这里是输出:

TrOCR:基于Transformer的OCR介绍和使用_第3张图片

图4,TrOCR在手写文本上的推理

这个例子很好的表现了TrOCR 在手写文本上的效果,可以正确识别出所有的字符,甚至是连写的字符。

072edd81f788149351cdd5cb5b2327de.png

图5,TrOCR 在手写字符上的推理

对于不同的手写风格,模型的效果也很好。视觉和语言模型的组合的威力显现。

测试TrOCR的极限

TrOCR并不是在所有类型的图像上都能表现很好。举例说明,小模型在弯曲文本上的效果不好,下面是几个例子:

TrOCR:基于Transformer的OCR介绍和使用_第4张图片

图6,TrOCR无法识别出弯曲文本

很明显,模型无法理解和识别出STATES 这个词,输出的是<

这是另外一个例子:

TrOCR:基于Transformer的OCR介绍和使用_第5张图片

图7,在竖向的文本上预测错误

这次,模型能预测出一个词,但是是错误的。

提升TrOCR的表现

从上面可以看到,TrOCR模型可能在某些场景下表现不好,这种限制同时来自视觉transformer和语言transformer的能力限制,需要一个经过弯曲文本图像训练过的视觉模型,以及能理解这种token的语言模型。

最好的方法是在弯曲文本数据集上微调 TrOCR 模型,我们会在SCUT-CTW1500数据集上进行训练。

总结

OCR 用简单的架构已经发展了很长时间,如今,TrOCR 为该领域带来了新的可能性。我们从介绍 TrOCR 开始,深入研究了它的架构。在此之后,我们介绍了不同的 TrOCR 模型及其训练策略。我们通过运行推理和分析结果来完成本文。

一个简单而有效的应用可以是数字化旧文章和报纸,这些文章和报纸很难人工清晰易读。

但是,在处理弯曲文本和自然场景中的文本时,TrOCR 也有其局限性。我们将在下一篇文章中更深入地探讨这一点,我们将在弯曲文本数据集上微调 TrOCR 模型并解锁新功能。

f0f0d9285d2bc45802a8da7dfe18a66e.png

—END—

英文原文:https://learnopencv.com/trocr-getting-started-with-transformer-based-ocr/

TrOCR:基于Transformer的OCR介绍和使用_第6张图片

请长按或扫描二维码关注本公众号

喜欢的话,请给我个在看吧

你可能感兴趣的:(transformer,ocr,深度学习,人工智能)