RNN代码解读之char-RNN with TensorFlow(model.py)

此工程解读链接(建议按顺序阅读):
RNN代码解读之char-RNN with TensorFlow(model.py)
RNN代码解读之char-RNN with TensorFlow(train.py)
RNN代码解读之char-RNN with TensorFlow(util.py)
RNN代码解读之char-RNN with TensorFlow(sample.py)

最近一直在学习RNN的相关知识,个人认为相比于CNN各种模型在detection/classification/segmentation等方面超人的表现,RNN还有很长的一段路要走,毕竟现在的nlp模型单从output质量上来看只是差强人意,要和人相比还有一段距离。CNN+RNN的任务比如image caption更是有很多有待研究和提高的地方。

关于对CNN和RNN相关内容的学习和探讨,我将会在近期更新对一些经典论文的解读以及自己的看法,届时欢迎大家给予指导。

当然,CS231n中有一句名言“Don’t think too hard, just cross your fingers.” 想法还是要落地才可以看到成果,那么我们今天就一起来看一下大牛Adrew Karparthy的char-RNN模型,AK使用lua基于torch写的,git上已经有人及时的复现了TensorFlow with Python版本(https://github.com/sherjilozair/char-rnn-tensorflow)。

网上已经有很多相关的解析了,但大部分只是针对model进行解释,这对于整体模型的宏观理解以及TensorFlow的学习都是很不利的。因此,这里我会给出自己对所有代码的理解,若有错误欢迎及时指正。

这一个版本的代码共分为四个模块:model.py,train.py, util.py以及sample.py,我们将按照这个顺序,分四篇博文对四个模块进行梳理。我在代码中对所有我认为重要的地方都写了注释,有的部分甚至每一行都有明确的注释,但难免有的基本方法会让人产生疑惑。面对这种问题,我强烈建议大家一边debug一步一步的执行看结果,一边百度或者google。这样梳理一遍代码一定会全身舒畅,豁然开朗,感觉打开了新世界的大门,对于RNN模型的TensorFlow实现也会更有把握。

当然理解这一个工程并不是我们的终极目的,针对后面跟新的paper中提到的有创新的方法,我们也会再此模型的基础上进一步实现,走上我们的科研之路。

废话说太多了,下面我们先开始看最重点的model.py
注意:这里注释解释的只是训练过程中的理解,在infer过程中batch=1,sequence=1,大体理解没有差别,但是具体思想还需要大家到时候再推敲推敲。此外,此class中的sample方法这一节不讨论,到第四节sample.py的时候一并讨论。

#-*-coding:utf-8-*-
import tensorflow as tf
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import seq2seq

import numpy as np

class Model():
    def __init__(self, args, infer=False):
        self.args = args
        #在测试状态下(inference)才用如下选项
        if infer:
            args.batch_size = 1
            args.seq_length = 1
        #几种备选的rnn类型
        if args.model == 'rnn':
            cell_fn = rnn_cell.BasicRNNCell
        elif args.model == 'gru':
            cell_fn = rnn_cell.GRUCell
        elif args.model == 'lstm':
            cell_fn = rnn_cell.BasicLSTMCell
        else:
            raise Exception("model type not supported: {}".format(args.model))
        #固定格式是例:cell = rnn_cell.GRUCelll(rnn_size)
        #rnn_size指的是每个rnn单元中的神经元个数(虽然RNN途中只有一个圆圈代表,但这个圆圈代表了rnn_size个神经元)
        #这里state_is_tuple根据官网解释,每个cell返回的h和c状态是储存在一个list里还是两个tuple里,官网建议设置为true
        cell = cell_fn(args.rnn_size, state_is_tuple=True)
        #固定格式,有几层rnn
        self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)
        #input_data&target(标签)格式:[batch_size, seq_length]
        self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
        self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
        #cell的初始状态设为0,因为在前面设置cell时,cell_size已经设置好了,因此这里只需给出batch_size即可
        #(一个batch内有batch_size个sequence的输入)
        self.initial_state = cell.zero_state(args.batch_size, tf.float32)

        #rnnlm = recurrent neural network language model
        #variable_scope就是变量的作用域
        with tf.variable_scope('rnnlm'):
            #softmax层的参数
            softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
            softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
            with tf.device("/cpu:0"):
                #推荐使用tf.get_variable而不是tf.variable
                #embedding矩阵是将输入转换到了cell_size,因此这样的大小设置
                embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
                #关于tf.nn.embedding_lookup(embedding, self.input_data):
                #   调用tf.nn.embedding_lookup,索引与train_dataset对应的向量,相当于用train_dataset作为一个id,去检索矩阵中与这个id对应的embedding
                #将第三个参数,在第1维度,切成seq_length长短的片段
                #embeddinglookup得到的look_up尺寸是[batch_size, seq_length, rnn_size],这里是[50,50,128]
                look_up = tf.nn.embedding_lookup(embedding, self.input_data)
                #将上面的[50,50,128]切开,得到50个[50,1,128]的inputs
                inputs = tf.split(1, args.seq_length, look_up)
                #之后将 1 squeeze掉,50个[50,128]
                inputs = [tf.squeeze(input_, [1]) for input_ in inputs]

        #在infer的时候方便查看
        def loop(prev, _):
            prev = tf.matmul(prev, softmax_w) + softmax_b
            prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
            return tf.nn.embedding_lookup(embedding, prev_symbol)

        #seq2seq.rnn_decoder基于schedule sampling实现,相当于一个黑盒子,可以直接调用
        #得到的两个参数shape均为50个50*128的张量,和输入是一样的
        outputs, last_state = seq2seq.rnn_decoder(inputs,
                                                  self.initial_state, cell,
                                                  loop_function=loop if infer else None,
                                                  scope='rnnlm')
        #将outputsreshape在一起,形成[2500,128]的张量
        output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
        #logits和probs的大小都是[2500,65]([2500,128]*[128,65])
        self.logits = tf.matmul(output, softmax_w) + softmax_b
        self.probs = tf.nn.softmax(self.logits)
        #得到length为2500的loss(即每一个batch的sequence中的每一个单词输入,都会最终产生一个loss,50*50=2500)
        loss = seq2seq.sequence_loss_by_example([self.logits],
                [tf.reshape(self.targets, [-1])],
                [tf.ones([args.batch_size * args.seq_length])],
                args.vocab_size)
        #得到一个batch的cost后面用于求梯度
        self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length
        #将state转换一下,便于下一次继续训练
        self.final_state = last_state
        #因为学习率不需要BPTT更新,因此trainable=False
        #具体的learning_rate是由train.py中args参数传过来的,这里只是初始化设了一个0
        self.lr = tf.Variable(0.0, trainable=False)
        #返回了包括前面的softmax_w/softmax_b/embedding等所有变量
        tvars = tf.trainable_variables()
        #求grads要使用clip避免梯度爆炸,这里设置的阈值是5(见args)
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
                args.grad_clip)
        #使用adam优化方法
        optimizer = tf.train.AdamOptimizer(self.lr)
        #参考tensorflow手册,
        # 将计算出的梯度应用到变量上,是函数minimize()的第二部分,返回一个应用指定的梯度的操作Operation,对global_step做自增操作
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))

以上就是对于model.py的代码分析,总体来说就是“模型定义+参数设置+优化”的思路,如果有哪里出错还望大家多多指教啦~!

参考资料:
http://blog.csdn.net/mydear_11000/article/details/52776295
https://github.com/sherjilozair/char-rnn-tensorflow
http://www.tensorfly.cn/tfdoc/api_docs/python/constant_op.html#truncated_normal

你可能感兴趣的:(RNN,TensorFlow)