本文提出了一种新的识别框架来学习场景文本图像的原始表示。将特征图中的元素建模为无向图的节点。使用池化聚合器和加权聚合器来学习原始表示,并通过图卷积网络转化为高级视觉文本表示。
针对全局特征聚合,提出了一种池化聚合器和一种加权聚合器。
池化聚合器通过两个卷积和一个全局平均池层从输入特征图中学习每个原始表示。通过这种方式,所有样本共享聚合权重,以从各种场景文本实例中学习内在的结构信息。
加权聚合器将输入特征图转换为热力图,用作聚合权重。
可视化文本表示是由图卷积网络(GCNs)根据原始表示生成的。每个可视化文本表示都用于表示要识别的字符。
构造了一个原始表示学习网络(PREN),利用视觉文本表示进行并行解码。
此外,由于视觉文本表示纯粹是从视觉特征中学习的,可以减轻基于注意力的方法的错位问题。将PREN集成到一个基于二维注意的编码器-解码器模型中,构建了一个名为PREN2D的框架。
通过使用坐标空间上的全局特征聚合来学习原始表示。通过这种方式,原始表示包含了输入图像的全局信息。设F∈R_{d0×h×w}是由CNN提取的特征图,其中h、w和d0分别是F的高度、宽度和信道数。
特征图中的元素被视为无向图的节点,即将F转换为特征矩阵X。作为要学习的原始表示的数量,特征聚合过程可以表示为
池化聚合
全局平均池层用于特征聚合,通过这种方式,所有样本共享聚合权重,以利用来自各种场景文本实例的内在结构信息。
加权聚合
动态地从输入特征中学习聚合权重。通过3×3卷积层曲线获得隐藏表示。另一个3×3卷积层和激活函数将输入特征图转换为热力图。聚合权重可以通过扁平化第一个热力图来得到。
由于原始表示包含输入图像的全局信息,因此可以从原始表示中提取文本信息。故通过原始表示的线性组合来生成视觉文本表示,然后生成一个全连接层。这里使用GCN来实现,其中系数矩阵起着与邻接矩阵类似的作用。
由于原语表示没有显式的图拓扑,系数矩阵在训练阶段被随机初始化和学习。长度为L的视觉文本表示y用来表示一个要被识别的字符。对于短于L的文本字符串,Y的多余部分用符号填充
PREN由特征提取模块和原始表示学习模块组成,使用三个池聚合器和三个加权聚合器从多尺度特征映射中学习原始表示。让P1和P2分别表示通过池聚合器和加权聚合器学习的原始表示。视觉文本表示Y1和Y2由两个GCNs获得,直接相加获得融合的视觉文本表示Y。全连接层用于将Y转换为logits以进行并行解码。
使用EffificientNet-B3作为特征提取模块,由七个移动反向瓶颈块(MBConv blocks)组成,输出为卷积图F_i,作为原始表示学习模块的输入。
对于每个卷积块输出的特征映射,同时使用池聚合器和加权聚合器来学习原始表示。两个GCN分别从原始表示P1和P2生成视觉文本表示Y1和Y2。Y1和Y2共同生成融合的视觉文本表示Y。每个字符的概率由Y通过一个全连接的层计算,然后计算softmax。
PREN输出的视觉文本表示也可以集成到基于注意力的编码器-解码器模型中
结合PREN模型和二维注意机制,构建PREN2D。特征提取模块由PREN和改进的Transfomer模型的编码器-解码器模块共享。PREN输出的视觉文本表示用于增强训练阶段地面真实文本的字符嵌入,或推理阶段以前的解码文本,可以在改进的Transfomer模型中为编码器解码器注意力的计算提供全局指导。
Transfomer Decoder用于文本转录。使用一个门控单元来结合视觉文本表示和字符嵌入。
可惜作者并没有给出PREN2D的源码,可惜可惜!!!
PREN和PREN2D都可以用预测和groudtruth之间的交叉熵进行端到端训练。通过在最后一个字符之后添加一个结束符号来生成groudtruth,并使用填充符号扩展到最大长度。
在推理阶段,PREN一步预测整个文本,而PREN2D递归地识别字符。解码结果中存在的第一个结束符号表示解码的结束。
场景文本识别的原始表示学习方法从原始表示中生成的视觉文本表示可以直接用于并行解码,也可以进一步集成到基于二维注意的编码器-解码器框架中,以提高识别性能。
对于EffificientNet-B3,总共是七个卷积块构成的,其中357被用作池化聚合的输入,他们的特征图分辨率更高,有更高的识别精度。作者也做了相应的消融实验。
接下来就是对应的池化聚合权重聚合,其输出分别作为GCN的输入,同时生成两个视觉文本表示YY2。融合的视觉文本表示Y是通过融合两种视觉文本表示Y1和Y2得到的。作者研究了三种融合策略,即求和、连接和门控单元。具有总和融合策略的模型在大多数测试集上都可以达到最佳的识别精度。
from Nets.EfficientNet import EfficientNet
from Nets.Aggregation import GCN, WeightAggregate, PoolAggregate
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self,
net_configs):
super(Model, self).__init__()
d_model = net_configs.d_model
n_r = net_configs.n_r
max_len = net_configs.max_len
n_class = net_configs.n_class
dropout = net_configs.dropout
self.cnn = EfficientNet.from_name('efficientnet-b3')
# pooling aggregators
# f3[b, 48, 8, 32], f5[b, 136, 4, 16], f7[b, 384, 2, 8]
self.agg_p1 = PoolAggregate(n_r, d_in=48, d_out=d_model // 3)
self.agg_p2 = PoolAggregate(n_r, d_in=136, d_out=d_model // 3)
self.agg_p3 = PoolAggregate(n_r, d_in=384, d_out=d_model // 3)
# weighted aggregators
self.agg_w1 = WeightAggregate(n_r=n_r, d_in=48, d_middle=4*48, d_out=d_model//3)
self.agg_w2 = WeightAggregate(n_r=n_r, d_in=136, d_middle=4*136, d_out=d_model//3)
self.agg_w3 = WeightAggregate(n_r=n_r, d_in=384, d_middle=4*384, d_out=d_model//3)
# GCNs
self.gcn_pool = GCN(d_in=d_model, n_in=n_r, d_out=d_model, n_out=max_len, dropout=dropout)
self.gcn_weight = GCN(d_in=d_model, n_in=n_r, d_out=d_model, n_out=max_len, dropout=dropout)
self.linear = nn.Linear(d_model, n_class)
self.max_len = max_len
self.d_model = d_model
def forward(self, input):
'''
:param input: images [b, 3, 64, 256]
:return logits: [b, L, n_class] probs of characters (before softmax)
'''
f3, f5, f7 = self.cnn(input)
rp1 = self.agg_p1(f3) # [b, nr, d / 3]
rp2 = self.agg_p2(f5) # [b, nr, d / 3]
rp3 = self.agg_p3(f7) # [b, nr, d / 3]
rp = torch.cat([rp1, rp2, rp3], dim=2) # [b, nr, d]
rw1 = self.agg_w1(f3)
rw2 = self.agg_w2(f5)
rw3 = self.agg_w3(f7)
rw = torch.cat([rw1, rw2, rw3], dim=2) # [b, nr, d]
y1 = self.gcn_pool(rp)
y2 = self.gcn_weight(rw)
y = 0.5 * (y1 + y2) # [b, L, d]
logits = self.linear(y)
return logits
代码组织如上图所示。