项目信息Program Info
项目要求
项目方案
在自然场景下进行手写汉字识别主要分为两个步骤:手写汉字检测和手写汉字识别。如下图所示,其中,手写汉字检测主要目标是从自然场景的输入图像中寻找到手写汉字区域,并将手写汉字区域从原始图像中分离出来;手写汉字识别的主要目标是从分离出来的图像中准确地识别出该手写汉字的含义。
对于手写汉字检测,考虑采用CTPN算法,CTPN是在ECCV 2016中论文 Detecting Text in Natural Image with Connectionist Text Proposal Network (https://arxiv.org/abs/1609.03605)
中提出的一种文字检测算法。CTPN是在Faster RCNN的基础上结合CNN与LSTM深度网络,能有效的检测出复杂场景的横向分布的文字。对于手写汉字识别考虑使用CNN+RNN+CTC(CRNN+CTC)方法进行识别。CNN用于提取图像特征,RNN使用的是双向的LSTM网络(BiLSTM),用于在卷积特征的基础上继续提取文字序列特征。使用CTCLoss可以解决输出和label长度不一致的问 题,而不用手动去严格对齐。
进度安排
准备工作
由于之前深度学习框架只接触Pytorch和Tensorflow,而项目需要使用MindSpore进行搭建,因此我首先对其进行了解。
MindSpore简介
MindSpore是一种适用于端边云场景的新型开源深度学习训练/推理框架。MindSpore提供了友好的设计和高效的执行,旨在提升数据科学家和算法工程师的开发体验,并为Ascend AI处理器提供原生支持,以及软硬件协同优化。具有易开发、高效执行、全场景覆盖三大特性,其中易开发表现为API友好、调试难度低,高效执行包括计算效率、数据预处理效率和分布式训练效率,全场景则指框架同时支持云、边缘以及端侧场景。
同时,MindSpore作为全球AI开源社区,致力于进一步开发和丰富AI软硬件应用生态。
此外对于使用其他深度学习框架的学习者,官方文档中也给出了Pytorch和Tensorflow与Mindspore的算子映射表
(https://www.mindspore.cn/docs/migration_guide/zh-CN/master/api_mapping.html ),
我们可以极其便捷的进行算法迁移。
数据集处理
由于我们任务分为文本检测和文本识别两个部分,因此我们需要数据集必须能同时满足两种任务需求,经过考察后我们选择由中科院自动化所模式识别国家重点实验室搭建的
(CASIA-HWDB)汉字识别数据集
(http://www.nlpr.ia.ac.cn/databases/handwriting/Home.html)手写样本由1020名作者在纸上书写,主要包括独立的字符和连续汉字。离线数据集由6个子数据集组成,3个为手写的独立字符(DB1.0– 1.2),3个为手写汉字(DB2.0–2.2),独立字符中包含3.9M个样本,分为7356类,其中有7185个汉字和171个符号,手写文本共有5090页和1.35M个汉字。手写汉字样本如下图所示。
脱机文本的数据格式为 .dgrl
需要对该文本进行解析,转换为训练需要的图片格式和label格式。此外,由于该数据集保存的是每一行的文本图片,为了进行文本检测任务我们需要将每一行拼接成一页的文本图片。
因此数据集处理过程可以分为两步:转换数据格式获得文本行数据,拼接文本行获得文本页数据
转换数据格式
.dgrl按照如下图所示进行存储,每一张图对应一个DGRL文件,大部分内容都有固定的长度,部分内容长度不固定 但是也能通过其他数据推导出来,我们可以通过访问文件特定位置的数据得到我们需要的内容:行文本标注,行图像。获取到重要的文本和图像信息即可。
其中通过文件头部分,可以得到文件头部长度和单个字符的长度,因为文件头部中任意长度的信息,可以通过文件头 部长度直接跳转到图像信息部分,通过单个字符长度可以读取文本信息。图像信息中,可以得到图像的高度、宽度和 行数量,均是固定长度。之后在行文本信息中获取行内字符数量,行文本标注,行图像高度、宽度和图像像素。
使用二进制方式打开文件进读取.dgrl 文件
f = open(dgrl, 'rb')
之后使用numpy进行依次读取,注意是一个一个Byte依次读取,需要指定读取的格式和数量
import numpy as np
np.fromfile(f, dtype='uint8', count=4)
一般 dtype 都选择 uint8,count 需要根据上图结构中的长度 Length 做相应变化。
要注意的一个地方是:行文本标注读取出来以后,是一个 int 列表,要把它还原成汉字,一个汉字占用两个字节(具体由 code_length 决定),使用 struct 将其还原:
struct.pack('I', i).decode('gbk', 'ignore')[0]
上面的 i
就是提取出来的汉字编码,解码格式为 gbk
,有些行文本会有空格,解码可能会出错,使用 ignore
忽略。
得到的坐标为每个矩形框的坐标。保存文件为 .jpg
的图像格式
详细代码如下所示
import struct
import os
import cv2 as cv
import numpy as np
def read_from_dgrl(dgrl):
if not os.path.exists(dgrl):
print('DGRL not exis!')
return
dir_name,base_name = os.path.split(dgrl)
label_dir = dir_name+'_label'
image_dir = dir_name+'_images'
if not os.path.exists(label_dir):
os.makedirs(label_dir)
if not os.path.exists(image_dir):
os.makedirs(image_dir)
with open(dgrl, 'rb') as f:
# 读取表头尺寸
header_size = np.fromfile(f, dtype='uint8', count=4)
header_size = sum([j<<(i*8) for i,j in enumerate(header_size)])
# print(header_size)
# 读取表头剩下内容,提取 code_length
header = np.fromfile(f, dtype='uint8', count=header_size-4)
code_length = sum([j<<(i*8) for i,j in enumerate(header[-4:-2])])
# print(code_length)
# 读取图像尺寸信息,提取图像中行数量
image_record = np.fromfile(f, dtype='uint8', count=12)
height = sum([j<<(i*8) for i,j in enumerate(image_record[:4])])
width = sum([j<<(i*8) for i,j in enumerate(image_record[4:8])])
line_num = sum([j<<(i*8) for i,j in enumerate(image_record[8:])])
print('图像尺寸:')
print(height, width, line_num)
# 读取每一行的信息
for k in range(line_num):
print(k+1)
# 读取该行的字符数量
char_num = np.fromfile(f, dtype='uint8', count=4)
char_num = sum([j<<(i*8) for i,j in enumerate(char_num)])
print('字符数量:', char_num)
# 读取该行的标注信息
label = np.fromfile(f, dtype='uint8', count=code_length*char_num)
label = [label[i]<<(8*(i%code_length)) for i in range(code_length*char_num)]
label = [sum(label[i*code_length:(i+1)*code_length]) for i in range(char_num)]
label = [struct.pack('I', i).decode('gbk', 'ignore')[0] for i in label]
print('合并前:', label)
label = ''.join(label)
label = ''.join(label.split(b'\x00'.decode())) # 去掉不可见字符 \x00,这一步不加的话后面保存的内容会出现看不见的问题
print('合并后:', label)
# 读取该行的位置和尺寸
pos_size = np.fromfile(f, dtype='uint8', count=16)
y = sum([j<<(i*8) for i,j in enumerate(pos_size[:4])])
x = sum([j<<(i*8) for i,j in enumerate(pos_size[4:8])])
h = sum([j<<(i*8) for i,j in enumerate(pos_size[8:12])])
w = sum([j<<(i*8) for i,j in enumerate(pos_size[12:])])
# print(x, y, w, h)
# 读取该行的图片
bitmap = np.fromfile(f, dtype='uint8', count=h*w)
bitmap = np.array(bitmap).reshape(h, w)
# 保存信息
label_file = os.path.join(label_dir, base_name.replace('.dgrl', '_'+str(k)+'.txt'))
with open(label_file, 'w') as f1:
f1.write(label)
bitmap_file = os.path.join(image_dir, base_name.replace('.dgrl', '_'+str(k)+'.jpg'))
cv.imwrite(bitmap_file, bitmap)
结果如下图所示
可以发现每张图仅为一行文本数据,这样无法进行文本识别,因此需要将文件进行拼接为整页的文本格式。
拼接文本行数据
根据上一步得到的数据可以发现,每一完整页的汉字前缀为 006-P16_*
的形式,下划线后表示每行,如果需要拼接为整页只要将相同前缀按照顺序拼接即可。下面对于图片和label的拼接和生成进行分别说明。
首先对图片拼接进行说明。上一步得到的每个行图片height和width不同,在拼接时需要进行调整。对width而言,由于是从上到下拼接,width需要保持一致,因此,取每个行图片width的最大值,其他小于max_width的图片到扩充到最大值,均扩充为白色。对height而言,由于段首和段位的长度明显要小于段间的长度,如果都pad到行图片的前端或后端显然不合适,这时候做一个简单的判断,如果是开头就pad到行图片的前端,如果是结尾或段中就pad到行图片的后端。最后将pad成整页的图片在外围在pad上白色。
对于label进行生成,由行级别的bbox坐标和字符两个部分组成。先生成bbox的坐标再将每个行图片的label读取写入新的page level的label中。bbox 的坐标为每个矩形框四个点的坐标。最终生成的结果如下所示
手写汉字图片
标签
628,1000,2519,1000,2519,1085,628,1085,2006年8月,国际天文学联合会大会正式通过决议,将冥王星降级,
500,1085,2500,1085,2500,1175,500,1175,与其他类似的一些星体统一定义为“矮行星”。当时,天文学家认为冥
500,1175,2519,1175,2519,1269,500,1269,王星应该是矮行星中的“老大”。而最新的天文观测证实,冥王星的“老
500,1269,1010,1269,1010,1343,500,1343,大”头衔也将不保。
650,1343,2519,1343,2519,1445,650,1445,美国加利福尼亚理工学院天文学家迈克尔·布朗等人定于15日出版的
500,1445,2508,1445,2508,1557,500,1557,美国《科学》杂志上报告说,他们在研究矮行星厄里斯的卫星“迪丝诺美
500,1557,2516,1557,2516,1666,500,1666,亚”时,利用设在美国夏威夷的凯克大型望远镜和太空中的哈勃太空望
500,1666,2509,1666,2509,1766,500,1766,远镜,计算出了这颗卫星的运动轨迹,并借助这一信息,进一步计算得
500,1766,2513,1766,2513,1859,500,1859,到厄里斯的最新密度及轨道数据。结果发现,厄里斯的质量大约
500,1859,1720,1859,1720,1941,500,1941,比冥王星大27%,是目前已知最大的矮行星。
最终处理代码如下:
import numpy as np
import cv2
import os
from glob import glob
import re
from tqdm import tqdm
def get_char_nums(segments):
nums = []
chars = []
for seg in segments:
label_head = seg.split('.')[0]
label_name = label_head + '.txt'
with open(os.path.join(label_root,label_name), 'r', encoding='utf-8') as f:
lines = f.readlines()
nums.append(len(lines[0]))
chars.append(lines[0])
return nums, chars
def addZeros(s_):
head, tail = s_.split('_')
num = ''.join(re.findall(r'\d',tail))
head_num = '0'*(4-len(num)) + num
return head + '_' + head_num + '.jpg'
def strsort(alist):
alist.sort(key=lambda i:addZeros(i))
return alist
def pad(img, headpad, padding):
assert padding>=0
if padding>0:
logi_matrix = np.where(img > 255*0.95, np.ones_like(img), np.zeros_like(img))
ids = np.where(np.sum(logi_matrix, 0) == img.shape[0])
if ids[0].tolist() != []:
pad_array = np.tile(img[:,ids[0].tolist()[-1],:], (1, padding)).reshape((img.shape[0],-1,3))
else:
pad_array = np.tile(np.ones_like(img[:, 0, :]) * 255, (1, padding)).reshape((img.shape[0], -1, 3))
if headpad:
return np.hstack((pad_array, img))
else:
return np.hstack((img, pad_array))
else:
return img
def pad_peripheral(img, pad_size):
assert isinstance(pad_size,tuple)
w, h = pad_size
result = cv2.copyMakeBorder(img, h, h, w, w, cv2.BORDER_CONSTANT, value=[255, 255, 255])
return result
if __name__ == '__main__':
label_roots = ['./labels']
label_dets = ['./fulllabels']
pages_roots = ['./images']
pages_dets = ['./fullimages']
for label_root, label_det, pages_root, pages_det in zip(label_roots, label_dets, pages_roots, pages_dets):
os.makedirs(label_det, exist_ok=True)
os.makedirs(pages_det, exist_ok=True)
pages_for_set = os.listdir(pages_root)
pages_set = set([pfs.split('_')[0] for pfs in pages_for_set])
for ds in tqdm(pages_set):
boxes = []
pages = []
seg_sorted = strsort([d for d in pages_for_set if ds in d])
widths = [cv.imread(os.path.join(pages_root, d)).shape[1] for d in seg_sorted]
heights = [cv.imread(os.path.join(pages_root, d)).shape[0] for d in seg_sorted]
max_width = max(widths)
seg_nums, chars = get_char_nums(seg_sorted)
pad_size = (500, 1000)
w, h = pad_size
label_name = ds + '.txt'
with open(os.path.join(label_det, label_name), 'w') as f:
for i, pg in enumerate(seg_sorted):
headpad = True if i == 0 else True if seg_nums[i] - seg_nums[i - 1] > 5 else False
pg_read = cv.imread(os.path.join(pages_root, pg))
padding = max_width - pg_read.shape[1]
page_new = pad(pg_read, headpad, padding)
pages.append(page_new)
if headpad:
x1 = str(w + padding)
x2 = str(w + max_width)
y1 = str(h + sum(heights[:i + 1]) - heights[i])
y2 = str(h + sum(heights[:i + 1]))
box = np.array([int(x1), int(y1), int(x2), int(y1), int(x2), int(y2), int(x1), int(y2)])
else:
x1 = str(w)
x2 = str(w + max_width - padding)
y1 = str(h + sum(heights[:i + 1]) - heights[i])
y2 = str(h + sum(heights[:i + 1]))
box = np.array([int(x1), int(y1), int(x2), int(y1), int(x2), int(y2), int(x1), int(y2)])
boxes.append(box.reshape((4, 2)))
char = chars[i]
f.writelines(
x1 + ',' + y1 + ',' + x2 + ',' + y1 + ',' + x2 + ',' + y2 + ',' + x1 + ',' + y2 + ',' + char + '\n')
pages_array = np.vstack(pages)
pages_array = pad_peripheral(pages_array, pad_size)
pages_name = ds + '.jpg'
# cv.polylines(pages_array, [box.astype('int32') for box in boxes], True, (0, 0, 255))
cv.imwrite(os.path.join(pages_det, pages_name), pages_array)
做完了以上准备工作,下面开始分别对手写汉字进行文本检测和文本识别的网络进行搭建和训练。
文本检测
对于手写汉字检测,考虑采用CTPN算法,CTPN是在ECCV 2016中论文Detecting Text in Natural Image with Connectionist Text Proposal Network
(https://arxiv.org/abs/1609.03605)中提出的一种文字检测算法。CTPN是在Faster RCNN的基础上结合CNN与LSTM深度网络,能有效的检测出复杂场景的横向分布的文字。CTPN算法只能检测出横向排列的文字,其结构与Faster R-CNN基本类似,但是加入了LSTM层,网络结构如下图所示。
图片来源:Detecting Text in Natural Image with Connectionist Text Proposal Network
(https://arxiv.org/abs/1609.03605)
代码
CTPN网络代码如下
class CTPN(nn.Cell):
"""
Define CTPN network
Args:
input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
captcha images.
batch_size(int): batch size of input data, default is 64
hidden_size(int): the hidden size in LSTM layers, default is 512
"""
def __init__(self, config, is_training=True):
super(CTPN, self).__init__()
self.config = config
self.is_training = is_training
self.num_step = config.num_step
self.input_size = config.input_size
self.batch_size = config.batch_size
self.hidden_size = config.hidden_size
self.vgg16_feature_extractor = VGG16FeatureExtraction()
self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same')
self.rnn = BiLSTM(self.config, is_training=self.is_training).to_float(mstype.float16)
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.cast = P.Cast()
# rpn block
self.rpn_with_loss = RPN(config,
self.batch_size,
config.rpn_in_channels,
config.rpn_feat_channels,
config.num_anchors,
config.rpn_cls_out_channels)
self.anchor_generator = AnchorGenerator(config)
self.featmap_size = config.feature_shapes
self.anchor_list = self.get_anchors(self.featmap_size)
self.proposal_generator_test = Proposal(config,
config.test_batch_size,
config.activate_num_classes,
config.use_sigmoid_cls)
self.proposal_generator_test.set_train_local(config, False)
def construct(self, img_data, gt_bboxes, gt_labels, gt_valids, img_metas=None):
x = self.vgg16_feature_extractor(img_data)
x = self.conv(x)
x = self.cast(x, mstype.float16)
x = self.transpose(x, (0, 2, 1, 3))
x = self.reshape(x, (-1, self.input_size, self.num_step))
x = self.transpose(x, (2, 0, 1))
x = self.rnn(x)
rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss = self.rpn_with_loss(x, gt_valids)
if self.training:
return rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss
proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)
return proposal, proposal_mask
def get_anchors(self, featmap_size):
anchors = self.anchor_generator.grid_anchors(featmap_size)
return Tensor(anchors, mstype.float16)
class CTPN_Infer(nn.Cell):
def __init__(self, config):
super(CTPN_Infer, self).__init__()
self.network = CTPN(config, is_training=False)
self.network.set_train(False)
def construct(self, img_data):
output = self.network(img_data, None, None, None, None)
return output
与一般目标检测框架不同,为了能够检测连续文本,加入LSTM结构。因为CNN学习的是感受野内的空间信息, LSTM学习的是序列特征。对于文本序列检测,显然既需要CNN抽象空间特征,也需要序列特征(毕竟文字是连续的)。
代码实现如下,代码在 ./ctpn.py
中
class BiLSTM(nn.Cell):
"""
Define a BiLSTM network which contains two LSTM layers
Args:
input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
captcha images.
batch_size(int): batch size of input data, default is 64
hidden_size(int): the hidden size in LSTM layers, default is 512
"""
def __init__(self, config, is_training=True):
super(BiLSTM, self).__init__()
self.is_training = is_training
self.batch_size = config.batch_size * config.rnn_batch_size
print("batch size is {} ".format(self.batch_size))
self.input_size = config.input_size
self.hidden_size = config.hidden_size
self.num_step = config.num_step
self.reshape = P.Reshape()
self.cast = P.Cast()
k = (1 / self.hidden_size) ** 0.5
self.rnn1 = P.DynamicRNN(forget_bias=0.0)
self.rnn_bw = P.DynamicRNN(forget_bias=0.0)
self.w1 = Parameter(np.random.uniform(-k, k, \
(self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1")
self.w1_bw = Parameter(np.random.uniform(-k, k, \
(self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1_bw")
self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1")
self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw")
self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.reverse_seq = P.ReverseV2(axis=[0])
self.concat = P.Concat()
self.transpose = P.Transpose()
self.concat1 = P.Concat(axis=2)
self.dropout = nn.Dropout(0.7)
self.use_dropout = config.use_dropout
self.reshape = P.Reshape()
self.transpose = P.Transpose()
def construct(self, x):
if self.use_dropout:
x = self.dropout(x)
x = self.cast(x, mstype.float16)
bw_x = self.reverse_seq(x)
y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
y1_bw, _, _, _, _, _, _, _ = self.rnn_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
y1_bw = self.reverse_seq(y1_bw)
output = self.concat1((y1, y1_bw))
return output
RPN与Faster-RCNN类似,便不再赘述
实验结果
最后finetune的loss结果如下
epoch: 100 step: 1467, rpn_loss: 0.02794, rpn_cls_loss: 0.01963, rpn_reg_loss: 0.01110
某一样本检测结果如图所示,能够通过检测框较为准确的框出文本
训练loss变化如下图所示
文本识别
对于手写汉字识别考虑使用CNN+RNN+CTC(CRNN+CTC)方法进行识别。CNN用于提取图像特征,RNN使用的是 双向的LSTM网络(BiLSTM),用于在卷积特征的基础上继续提取文字序列特征。使用CTCLoss可以解决输出和label 长度不一致的问题,而不用手动去严格对齐。
整个CRNN网络分为三个部分,网络结构如下图所示。
图片来源:CRNN文本识别论文
An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition
https://arxiv.org/pdf/1507.05717.pdf
代码
CRNN部分代码构建如下
"""crnn_ctc network define"""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.initializer import TruncatedNormal
def _bn(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, gamma_init=1, beta_init=0, moving_mean_init=0,
moving_var_init=1)
class Conv(nn.Cell):
def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, use_bn=False, pad_mode='same'):
super(Conv, self).__init__()
self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
padding=0, pad_mode=pad_mode, weight_init=TruncatedNormal(0.02))
self.bn = _bn(out_channel)
self.Relu = nn.ReLU()
self.use_bn = use_bn
def construct(self, x):
out = self.conv(x)
if self.use_bn:
out = self.bn(out)
out = self.Relu(out)
return out
class VGG(nn.Cell):
"""VGG Network structure"""
def __init__(self, is_training=True):
super(VGG, self).__init__()
self.conv1 = Conv(3, 64, use_bn=True)
self.conv2 = Conv(64, 128, use_bn=True)
self.conv3 = Conv(128, 256, use_bn=True)
self.conv4 = Conv(256, 256, use_bn=True)
self.conv5 = Conv(256, 512, use_bn=True)
self.conv6 = Conv(512, 512, use_bn=True)
self.conv7 = Conv(512, 512, kernel_size=2, pad_mode='valid', use_bn=True)
self.maxpool2d1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')
self.maxpool2d2 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1), pad_mode='same')
# self.maxpool2d2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(2, 1), pad_mode='same')
self.bn1 = _bn(512)
def construct(self, x):
x = self.conv1(x)
x = self.maxpool2d1(x)
x = self.conv2(x)
x = self.maxpool2d1(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.maxpool2d2(x)
x = self.conv5(x)
x = self.conv6(x)
x = self.maxpool2d2(x)
x = self.conv7(x)
return x
class CRNN(nn.Cell):
"""
Define a CRNN network which contains Bidirectional LSTM layers and vgg layer.
Args:
input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
text images.
batch_size(int): batch size of input data, default is 64
hidden_size(int): the hidden size in LSTM layers, default is 512
"""
def __init__(self, config):
super(CRNN, self).__init__()
self.batch_size = config.batch_size
self.input_size = config.input_size
self.hidden_size = config.hidden_size
self.num_classes = config.class_num
self.reshape = P.Reshape()
self.cast = P.Cast()
k = (1 / self.hidden_size) ** 0.5
self.rnn1 = P.DynamicRNN(forget_bias=0.0)
self.rnn1_bw = P.DynamicRNN(forget_bias=0.0)
self.rnn2 = P.DynamicRNN(forget_bias=0.0)
self.rnn2_bw = P.DynamicRNN(forget_bias=0.0)
w1 = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
self.w1 = Parameter(w1.astype(np.float32), name="w1")
w2 = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
self.w2 = Parameter(w2.astype(np.float32), name="w2")
w1_bw = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
self.w1_bw = Parameter(w1_bw.astype(np.float32), name="w1_bw")
w2_bw = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
self.w2_bw = Parameter(w2_bw.astype(np.float32), name="w2_bw")
self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1")
self.b2 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b2")
self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw")
self.b2_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b2_bw")
self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.h2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.h2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.c2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.c2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.fc_weight = np.random.random((self.num_classes, self.hidden_size)).astype(np.float32)
self.fc_bias = np.random.random((self.num_classes)).astype(np.float32)
self.fc = nn.Dense(in_channels=self.hidden_size, out_channels=self.num_classes,
weight_init=Tensor(self.fc_weight), bias_init=Tensor(self.fc_bias))
self.fc.to_float(mstype.float32)
self.expand_dims = P.ExpandDims()
self.concat = P.Concat()
self.transpose = P.Transpose()
self.squeeze = P.Squeeze(axis=0)
self.vgg = VGG()
self.reverse_seq1 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.reverse_seq2 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.reverse_seq3 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.reverse_seq4 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.seq_length = Tensor(np.ones((self.batch_size), np.int32) * config.num_step, mstype.int32)
self.concat1 = P.Concat(axis=2)
self.dropout = nn.Dropout(0.5)
self.rnn_dropout = nn.Dropout(0.9)
self.use_dropout = config.use_dropout
def construct(self, x):
x = self.vgg(x)
shape1 = x.shape
x = self.reshape(x, (self.batch_size, self.input_size, -1))
x = self.transpose(x, (2, 0, 1))
bw_x = self.reverse_seq1(x, self.seq_length)
y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
y1_bw, _, _, _, _, _, _, _ = self.rnn1_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
y1_bw = self.reverse_seq2(y1_bw, self.seq_length)
y1_out = self.concat1((y1, y1_bw))
if self.use_dropout:
y1_out = self.rnn_dropout(y1_out)
y2, _, _, _, _, _, _, _ = self.rnn2(y1_out, self.w2, self.b2, None, self.h2, self.c2)
bw_y = self.reverse_seq3(y1_out, self.seq_length)
y2_bw, _, _, _, _, _, _, _ = self.rnn2(bw_y, self.w2_bw, self.b2_bw, None, self.h2_bw, self.c2_bw)
y2_bw = self.reverse_seq4(y2_bw, self.seq_length)
y2_out = self.concat1((y2, y2_bw))
if self.use_dropout:
y2_out = self.dropout(y2_out)
output = ()
for i in range(F.shape(y2_out)[0]):
y2_after_fc = self.fc(self.squeeze(y2[i:i+1:1]))
y2_after_fc = self.expand_dims(y2_after_fc, 0)
output += (y2_after_fc,)
output = self.concat(output)
return output
# return output, shape1, x.shape, y1_out.shape, y2_out.shape, y2_after_fc.shape
def crnn(config, full_precision=False):
"""Create a CRNN network with mixed_precision or full_precision"""
net = CRNN(config)
if not full_precision:
net = net.to_float(mstype.float16)
return net
由于使用CTCLoss,需要加入blank label,用于分隔文本字符,因此识别的文本类别数需要加1。实现代码如下:
'''
Date: 2021-09-05 14:53:34
LastEditors: xgy
LastEditTime: 2021-09-25 22:45:07
FilePath: \code\crnn_ctc\src\loss.py
'''
"""CTC Loss."""
import numpy as np
from mindspore.nn.loss.loss import _Loss
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
class CTCLoss(_Loss):
"""
CTCLoss definition
Args:
max_sequence_length(int): max number of sequence length. For text images, the value is equal to image width
max_label_length(int): max number of label length for each input.
batch_size(int): batch size of input logits
"""
def __init__(self, max_sequence_length, max_label_length, batch_size):
super(CTCLoss, self).__init__()
self.sequence_length = Parameter(Tensor(np.array([max_sequence_length] * batch_size), mstype.int32),
name="sequence_length")
labels_indices = []
for i in range(batch_size):
for j in range(max_label_length):
labels_indices.append([i, j])
self.labels_indices = Parameter(Tensor(np.array(labels_indices), mstype.int64), name="labels_indices")
self.reshape = P.Reshape()
self.ctc_loss = P.CTCLoss(ctc_merge_repeated=True)
def construct(self, logit, label):
labels_values = self.reshape(label, (-1,))
loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length)
return loss
实验结果
识别结果如图所示,能基本正确识别图像中的文本( ::
左边为label,右边为预测结果)
训练loss图像如下所示
评估准则分为两种,分别是字层次的精度和句子层次的精度,得到的精度为:
correct num: 8247 , total num: 10449
Accracy in word: 0.879359924968553
Accracy in sentence: 0.7892621303474017
result: {'CRNNAccuracy': 0.7892621303474017}
问题与解决
问题描述
在训练CRNN+CTCLoss框架进行文本识别时,发现多次出现Loss为0的情况,甚至在前100此迭代中都会出现为0的情 况,这显然是极其不合理的。
根据以往Debug经验和深度学习相关知识,整体解决思路主要分为3部分进行。
查看官方文档、手册及论坛,检查对应函数和接口存在什么限制,检查自己编写代码过程中是否已经满足。
查看Github和Google中别人是否存在类似问题,进行参考,看能否解决可能存在的问题。
编写简单的案例进行尝试,考虑所有可能的情况,找出在什么条件下会出现类似的BUG。
对自己代码由浅到深一步步进行检查,可以采用二分法逐步缩小问题范围。
一般只要按照如上思路进行,最终都能解决。
解决过程
我遵循以上思路,一步步排查问题所在。
查看官网手册中MindSporeCTCLoss
(https://www.mindspore.cn/doc/api_python/zh-CN/r1.2/mindspore/ops/mindspore.ops.CTCLoss.html?highlight=ctcloss#mindspore.ops.CTCLoss)
的说明,发现条件均满足。之后,搜索Google、Github和Mindspore论坛,发现并没有出现类似的情况。因此,我觉得自己先编写简单的案例,进行尝试。
由于数据过大,因此我先定义简单的张量进行尝试,测试什么情况下会出现loss为0的情况。
经过多组尝试,发现当 labels_values 异常时会导致loss取值为0,因为CTCLoss在merge_repeated=true的情况下,不可能出现1411这种情况,两个连续的11中间必定有一个blank label ,至少需要5个logits,例如14141才能有1411这种label,而例子中的输出max time为4,这种label是构造不出来的。
这种情况在手册中很难描述出来,我们必须深入理解CTCLoss的内涵才能发现。
找到了出现Loss为0的情况,则可以把问题定位在label 部分,因此我决定对代码的label进行深入检查。
我使用PYNative模型对代码进行调试,将batch_size设置为1,查看当Loss为0时,输入、输出和label各是什么,发现在label在句子中间部分出现了blank label的情况,但在实际情况是不可能的,说明是数据集label构造出了问题。
再检查数据集代码发现,字典集合中缺失了部分汉字,导致无法正确转为字符label,从而出错。
总结与反思
在实验过程中,会出现各种意想不到的问题,我们不应该惊慌失措,冷静下来,按照步骤一步步确定问题所在,最终一定能够成功Debug。其实很多时候的BUG都是由于自己平时的粗心和代码编写不规范导致的,只要自己养成良好的代码编写习惯,仔细分析问题,就能减少BUG出现频率。
致谢
这次活动促进了开源软件的发展和优秀开源软件社区建设,增加开源项目的活跃度,推进开源生态的发展;感谢开源之夏主办方为这次活动提供的平台与机会。大大提高了代码编写能力,真的受益匪浅!
MindSpore官方资料
GitHub : https://github.com/mindspore-ai/mindspore
Gitee : https : //gitee.com/mindspore/mindspore
官方QQ群 : 871543426