手写中文文章识别(1)——问题描述
https://blog.csdn.net/foreseerwang/article/details/80833749
手写中文文章识别(2)——样本集构建
https://blog.csdn.net/foreseerwang/article/details/80842498
前文提到,样本集构建会形成文章文件(.code/.char/.len)和手写中文图片文件,本文使用这些文件形成模型训练及validation所需的输入数据,即dataset文件。关于tensorflow Dataset使用方法,可参见之前的两篇文章:
https://blog.csdn.net/foreseerwang/article/details/80170210
https://blog.csdn.net/foreseerwang/article/details/80572182
使用Dataset的data feeding代码如下:
import os
import random
import numpy as np
import tensorflow as tf
import io
try:
import cPickle as pickle
except ImportError:
import pickle
# 图片augment设置参数
tf.app.flags.DEFINE_boolean('random_flip_up_down', False, "Whether to random flip up down")
tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to random constrast")
# 图片及训练相关参数
tf.app.flags.DEFINE_integer('image_size', 64, "手写中文图片边长,方形。")
tf.app.flags.DEFINE_integer('image_channel', 1, "手写中文图片通道数,1代表黑白图片")
tf.app.flags.DEFINE_boolean('gray', True, "是否修改为灰度")
tf.app.flags.DEFINE_integer('shuffle_size', 100, '数据集的shuffle缓存大小')
tf.app.flags.DEFINE_integer('sent_len_max', 10,
"每句话的最长尺寸,超过这个长度时,逐段形成数据集;同时,训练集中的每句话也都需要padded到这个长度")
tf.app.flags.DEFINE_integer('batch_size', 3, '生成的数据集batch size')
tf.app.flags.DEFINE_integer('eval_steps', 2, 'validation间隔,即每eval_steps次训练batch进行一次validation')
# 是否使用短字典及相关配置。完整字库长度7356,短字库长度4000,已可以涵盖约95%以上的常见字
# 本程序使用完整字典,可不用关心此处的配置
# 抱歉,受限于版权问题,无法上传手写中文字库
tf.app.flags.DEFINE_boolean('short_dict', False, 'whether to use short dict')
tf.app.flags.DEFINE_integer('charset_size_long', 7356, "Long character dictionary size")
tf.app.flags.DEFINE_integer('charset_size_short', 4000, "Short character dictionary size")
tf.app.flags.DEFINE_string('char_dict_long', './article_recog/char_dict_gbk_rvs20180518',
'The reversed long character dictionary: code-->char')
tf.app.flags.DEFINE_string('char_dict_short', './article_recog/char_dict_4000_rvs',
'The reversed short character dictionary: code-->char')
# 数据储存目录
# 抱歉,受限于版权问题,这部分数据不能上传
tf.app.flags.DEFINE_string('sample_dir', './sample', '样本集存储目录')
tf.app.flags.DEFINE_string('train_hwdb_dir', './hwdb/hwdb_by_char_Train_gbk', '训练用hwdb手写中文图库存储目录')
tf.app.flags.DEFINE_string('test_hwdb_dir', './hwdb/hwdb_by_char_Test_gbk', '测试用hwdb手写中文图库存储目录')
FLAGS = tf.app.flags.FLAGS
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# 根据是否使用短字典,选择不同的字典文件和尺寸
if FLAGS.short_dict:
charset_size = FLAGS.charset_size_short
char_dict_file = FLAGS.char_dict_short
else:
charset_size = FLAGS.charset_size_long
char_dict_file = FLAGS.char_dict_long
class DataIterator:
def __init__(self, filenames, istrain=True):
self.filenames = filenames
if istrain:
self.hwdb_dir=FLAGS.train_hwdb_dir
else:
self.hwdb_dir=FLAGS.test_hwdb_dir
# 图片augment处理子程序
# 只需要针对训练集做augment
# 应用中发现,这里的augment还远远不够...
@staticmethod
def data_augmentation(images, labels, lengths, masks):
if FLAGS.random_flip_up_down:
images = tf.image.random_flip_up_down(images)
if FLAGS.random_brightness:
images = tf.image.random_brightness(images, max_delta=0.3)
if FLAGS.random_contrast:
images = tf.image.random_contrast(images, 0.8, 1.2)
return images, labels, lengths, masks
# 如果使用短字典,需要对输出的数据集做编码转换。
@staticmethod
def dataset_convert(char_vec,char_code,sent_len,mask):
if FLAGS.short_dict:
char_dict_dir = './article_recog'
char_dict_map_filename = 'char_dict_map' # long num --> short num
char_dict_map_fullfilename = os.path.join(char_dict_dir, char_dict_map_filename)
fr = open(char_dict_map_fullfilename, 'rb')
char_dict_map = pickle.load(fr)
fr.close()
converted_char_code = []
for ii in range(sent_len):
code_long = char_code[ii]
try:
code = char_dict_map[code_long]
except KeyError:
code = charset_size - 1
converted_char_code.append(code)
return (np.asarray(char_vec, dtype='float32'),
np.asarray(converted_char_code, dtype='int32'),
sent_len,
np.asarray(mask, dtype='bool'))
else:
return (np.asarray(char_vec, dtype='float32'),
np.asarray(char_code, dtype='int32'),
sent_len,
np.asarray(mask, dtype='bool'))
# 读取hwdb手写中文图片库子程序
# hwdb手写中文图片库按照文字保存,所有书写人写的同一个中文文字,按顺序保存到同一个文件中
# 读取时需要随机选择其中一个书写人的图片数据输出
@staticmethod
def read_hwdb(filename):
with open(filename, 'rb') as f:
char_file = np.fromfile(f, dtype='uint8')
img_byte = FLAGS.image_size*FLAGS.image_size*FLAGS.image_channel
char_file_len = len(char_file)
if (char_file_len % (img_byte)) != 0:
raise ValueError("Characters file %s error" % filename)
char_num = char_file_len // (img_byte)
char_idx = np.random.randint(char_num)
char_mat_uint8 = char_file[char_idx*img_byte:(char_idx+1)*img_byte]
char_mat = char_mat_uint8.astype(np.float32)
depth_major = char_mat.reshape([FLAGS.image_channel, FLAGS.image_size, FLAGS.image_size])
image = depth_major.transpose([1, 2, 0])
return image
# dataset generator子程
# 全python代码,因此可以非常灵活的进行各种文档读取即数据转换处理
# 读取字符(.char)、编码(.code)、长度(.len)文件,处理并输出dataset所需的image/label/length/mask数组
# image:本句话里所有文字对应的图片,[length, image_size, image_size, image_channel]
# label:本句话里所有的图片对应的字符编码,[length]
# length:本句话的字符长度,标量,非矩阵
# mask: 本句话长度所对应的True数组,[length],所有元素均为True
# 后来发现mask数据可以不用,可以使用tf.sequence_mask函数随时从len生成mask
def file_readline(self):
for filename in self.filenames:
char_file = os.path.join(FLAGS.sample_dir, filename + '.char')
code_file = os.path.join(FLAGS.sample_dir, filename + '.code')
len_file = os.path.join(FLAGS.sample_dir, filename + '.len')
frchar = io.open(char_file, 'r', encoding='utf-8')
frcode = io.open(code_file, 'r', encoding='utf-8')
frlen = io.open(len_file, 'r', encoding='utf-8')
try:
while True:
chars = frchar.readline()
codes = frcode.readline()
len_in_char = frlen.readline()
sent_len = int(len_in_char)
char_list = chars[:-1]
code_list = codes.split()
if len(char_list) != sent_len or len(code_list) != sent_len:
print(sent_len)
print(len(char_list))
print(len(code_list))
raise ValueError("Characters or labels length error")
if sent_len>FLAGS.sent_len_max:
while sent_len > FLAGS.sent_len_max:
char_code = []
char_vec = []
for code in code_list[-sent_len:-sent_len+FLAGS.sent_len_max]:
char_code.append(int(code))
char_filename = os.path.join(self.hwdb_dir, code + '.char')
char_vec.append(self.read_hwdb(char_filename))
mask = np.ones(FLAGS.sent_len_max, dtype='bool')
sent_len -= FLAGS.sent_len_max
yield self.dataset_convert(char_vec,char_code,FLAGS.sent_len_max,mask)
char_code = []
char_vec = []
for code in code_list[-sent_len:]:
char_code.append(int(code))
char_filename = os.path.join(self.hwdb_dir, code + '.char')
char_vec.append(self.read_hwdb(char_filename))
mask = np.ones(sent_len, dtype='bool')
yield self.dataset_convert(char_vec,char_code,sent_len,mask)
else:
char_code = []
char_vec = []
for code in code_list:
char_code.append(int(code))
char_filename = os.path.join(self.hwdb_dir, code+'.char')
char_vec.append(self.read_hwdb(char_filename))
mask = np.ones(sent_len, dtype='bool')
yield self.dataset_convert(char_vec,char_code,sent_len,mask)
except ValueError:
pass
frchar.close()
frcode.close()
frlen.close()
# dataset生成函数
# 需要注意的是,在输出dataset之前,需要进行padding,把image/label/mask长度pad到FLAGS.sent_len_max
def input_pipeline(self, batch_size, num_epochs=None, aug=False, shuffle=False):
char_dataset = tf.data.Dataset.from_generator(self.file_readline,
(tf.float32,tf.int32,tf.int32,tf.bool),
(tf.TensorShape([None,FLAGS.image_size,FLAGS.image_size,FLAGS.image_channel]),
tf.TensorShape([None]),tf.TensorShape([]),tf.TensorShape([None])))
if aug:
char_dataset = char_dataset.map(self.data_augmentation)
char_dataset = char_dataset.repeat(num_epochs)
if shuffle:
char_dataset = char_dataset.shuffle(FLAGS.shuffle_size)
char_dataset = char_dataset.padded_batch(
batch_size,
padded_shapes=(tf.TensorShape([FLAGS.sent_len_max,FLAGS.image_size,
FLAGS.image_size,FLAGS.image_channel]),
tf.TensorShape([FLAGS.sent_len_max]),
tf.TensorShape([]),
tf.TensorShape([FLAGS.sent_len_max])),
padding_values=(0.,0, 0, False))
iterator = char_dataset.make_one_shot_iterator()
databatch = iterator.get_next()
return databatch
# datafeeding测试程序
# 读取相应文件内容,形成batch输出并打印
def datafeeding_test(train_files, valid_files):
train_feeder = DataIterator(train_files, istrain=True)
valid_feeder = DataIterator(valid_files, istrain=False)
# 这里得到的batch,可以直接输入到模型中,而不用使用placehoder
# 训练集
trn_dataset = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True)
train_image_batch = trn_dataset[0]
train_label_batch = trn_dataset[1]
train_len_batch = trn_dataset[2]
train_mask_batch = trn_dataset[3]
# validation集
val_dataset = valid_feeder.input_pipeline(batch_size=FLAGS.batch_size)
valid_image_batch = val_dataset[0]
valid_label_batch = val_dataset[1]
valid_len_batch = val_dataset[2]
valid_mask_batch = val_dataset[3]
fr = open(char_dict_file, 'rb')
char_dict = pickle.load(fr)
fr.close()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
try:
training_steps = 0
while True:
# 实际程序中,此处放置每个batch的训练代码,当前放置打印代码,验证dataset输出正确
trn_images, trn_labels, trn_lens, trn_masks = sess.run(
[train_image_batch, train_label_batch,
train_len_batch, train_mask_batch])
print('!!! Train batch #%d !!!' % training_steps)
print('--Shape of train images batch: (%d,%d,%d,%d,%d)' % trn_images.shape)
print('--Corresponding sentence lengths of the batch:')
print(trn_lens)
print('--Corresponding labels of the batch:')
for ii in range(FLAGS.batch_size):
for jj in range(trn_lens[ii]):
print('%5d' % trn_labels[ii,jj]),
print('')
print('--Corresponding characters of the batch:')
for ii in range(FLAGS.batch_size):
for jj in range(trn_lens[ii]):
print(char_dict[trn_labels[ii,jj]]),
print('')
print('')
training_steps += 1
# 每FLAGS.eval_steps个训练batch后进行一次validation
if training_steps%FLAGS.eval_steps == 0:
val_images, val_labels, val_lens, val_masks = sess.run(
[valid_image_batch, valid_label_batch,
valid_len_batch, valid_mask_batch])
print('### Validation batch #%d ###' % (training_steps//FLAGS.eval_steps-1))
print('--Shape of validation images batch: (%d,%d,%d,%d,%d)' % val_images.shape)
print('--Corresponding sentence lengths of the batch:')
print(val_lens)
print('--Corresponding labels of the batch:')
for ii in range(FLAGS.batch_size):
for jj in range(val_lens[ii]):
print('%5d' % val_labels[ii,jj]),
print('')
print('--Corresponding characters of the batch:')
for ii in range(FLAGS.batch_size):
for jj in range(val_lens[ii]):
print(char_dict[val_labels[ii,jj]]),
print('')
print('')
# 仅用于测试,因此输出3个测试batch后即终止
if training_steps == 3:
break
except tf.errors.OutOfRangeError:
print('==================Finished================')
def main(_):
train_filelist = ['trn1', 'trn2']
valid_filelist = ['val']
datafeeding_test(train_filelist, valid_filelist)
if __name__ == "__main__":
tf.app.run()
上述代码中用到的trn1.char、trn2.char和val.char文件内容如下(相应的.code和.len文件内容就不贴出来了,都很简单,用于示例):
trn1.char
钻石闪烁的光芒照射着世间人们日益懦弱的心灵。
那成功让钻石增值的一刀似乎在昭示着一个发人深省的道理:
trn2.char
人生需要勇气,
需要平常心。
val.char
手就不会发抖,
心也更会坚定。
因此,
抛开杂念,
用勇气成就人生吧!
输出结果(很抱歉,由于实际手写图片数据不能上传,读者无法直接运行,但可参照文件内容及如下输出理解上述data feeding代码):
!!! Train batch #0 !!!
--Shape of train images batch: (3,10,64,64,1)
--Corresponding sentence lengths of the batch:
[10 10 2]
--Corresponding labels of the batch:
6493 4187 6674 3559 4061 528 5115 3607 1631 4131
185 6683 278 312 2586 4082 2178 1920 4061 1976
3526 158
--Corresponding characters of the batch:
钻 石 闪 烁 的 光 芒 照 射 着
世 间 人 们 日 益 懦 弱 的 心
灵 。
!!! Train batch #1 !!!
--Shape of train images batch: (3,10,64,64,1)
--Corresponding sentence lengths of the batch:
[10 10 7]
--Corresponding labels of the batch:
6301 2189 685 5771 6493 4187 1341 470 4061 169
616 353 222 1197 2622 4295 4131 169 198 815
278 3342 4106 4061 6273 3822 24
--Corresponding characters of the batch:
那 成 功 让 钻 石 增 值 的 一
刀 似 乎 在 昭 示 着 一 个 发
人 深 省 的 道 理 :
### Validation batch #0 ###
--Shape of validation images batch: (3,10,64,64,1)
--Corresponding sentence lengths of the batch:
[7 7 3]
--Corresponding labels of the batch:
2220 1649 178 334 815 2258 10
1976 234 2683 334 1221 1581 158
1173 3028 10
--Corresponding characters of the batch:
手 就 不 会 发 抖 ,
心 也 更 会 坚 定 。
因 此 ,
!!! Train batch #2 !!!
--Shape of train images batch: (3,10,64,64,1)
--Corresponding sentence lengths of the batch:
[ 7 6 10]
--Corresponding labels of the batch:
278 3908 6810 5719 702 3090 10
6810 5719 1836 1820 1976 158
6493 4187 6674 3559 4061 528 5115 3607 1631 4131
--Corresponding characters of the batch:
人 生 需 要 勇 气 ,
需 要 平 常 心 。
钻 石 闪 烁 的 光 芒 照 射 着