CRNN学习笔记

 

最近学习了CRNN网络,大体训练流程如下:

1、准备输入数据和标签,标签为稀疏矩阵

inputs = tf.placeholder(tf.float32, [batch_size, input_height, input_width, 1])

label = tf.sparse_placeholder(tf.int32, name='label')

seq_len = tf.placeholder(tf.int32, [None], name='seq_len') 

2、通过CNN网络提取特征

cnn_out = self._cnn(inputs)

3、通过2次双向RNN,得到神经网络输出结果

crnn_model = self._rnn(cnn_out, self._seq_len)

4、根据最终字符的类别得到最终的输出

logits = tf.reshape(crnn_model, [-1, 512])

W = tf.Variable(tf.truncated_normal([512, self._class_num], stddev=0.1), name="W")
b = tf.Variable(tf.constant(0., shape=[self._class_num]), name="b")

logits = tf.matmul(logits, W) + b

logits = tf.reshape(logits, [self._batch_size, -1, self._class_num])

        # 网络层输出
net_output = tf.transpose(logits, (1, 0, 2))

5、解析网络输出,其中decoded[0]是一个稀疏张量,类型和label一样

decoded, log_prob = tf.nn.ctc_greedy_decoder(net_output, self._seq_len)

6、损失函数loss

with tf.name_scope('loss'):
    loss = tf.nn.ctc_loss(self._label, self._net_output, self._seq_len)
    loss = tf.reduce_mean(loss)

7、优化器optimizer

with tf.name_scope('optimizer'):
    train_op = tf.train.AdamOptimizer(self._learning_rate).minimize(loss)

8、准确率accuracy

with tf.name_scope('accuracy'):
    accuracy = 1 - tf.reduce_mean(tf.edit_distance(tf.cast(self._decoded[0], tf.int32), self._label))
    accuracy_broad = tf.summary.scalar("accuracy", accuracy)

9、喂数据进行训练

feed_dict = {self._inputs: batch_data,self._label: batch_label, \
self._seq_len: [self._max_char_count] * self.batch_size}
sess.run(train_op, feed_dict=feed_dict)

 

你可能感兴趣的:(Tensorflow,深度学习)