RNN LSTM 网络参数问题

问题一:在NLP任务中,词向量维度(embedding size)是否一定要等于LSTM隐藏层节点数(hidden size)?

词向量(Word Embedding)可以说是自然语言处理任务的基石,运用深度学习的自然语言处理任务更是离不开词向量的构造工作。在文本分类,问答系统,机器翻译等任务中,LSTM的展开步数(num_step)为输入语句的长度,而每一个LSTM单元的输入则是语句中对应单词或词组的词向量。

对于embedding size是否一定要等于LSTM的hidden size 这样一个问题,我们可以通过了解单个LSTM单元的原理来进行回答。

我们输入LSTM的 input vector,也就是每个单词的word embedding这里称为vector A,LSTM的三个gate的控制是通过vector A来控制的,具体方法是通过乘以权重矩阵(weight),再加上偏置值(bias)形成新的一个vector,这个vector我们可以理解成gate的控制信号。而控制三个gate就需要三组不同的weight 和bias。LSTM传入神经网络输入层输入的vector(称为vector B)跟产生的三个控制信号的方法一样,也是通过vector A乘以一组weight 和bias产生。

这里放几张图片再来解释说明一下。(图片源于台大李弘毅老师的PPT)
RNN LSTM 网络参数问题_第1张图片

问题二:在多层LSTM中,词向量维度(embedding size)是否一定要等于LSTM隐藏层节点数(hidden size)?

根据问题一的解释,多层LSTM的embedding size实际上也不必等于hidden size的,可以通过设置不同shape的weight和bias来实现(第一层与后续层数设置为不同)。

但在Tensorflow实现多层LSTM时,使用的函数tf.contrib.rnn.MultiRNNcell()会自动将累加的LSTM的参数设为相同的shape,或者说,是模块化的直接累加LSTM层数。这样得来的多层LSTM网络,其参数weight和bias的shape都是相同的,所以当设置不同的embedding size和hidden size时会报错,更改为相同值时error消失。

在NLP任务中通常需要使用预先训练好的词向量来加快训练速度,而LSTM的hidden size也是在训练调参时需要进行调整的重要参数,所以还在寻找如何解决Tensorflow中多层lstm的hidden size和embedding size不相等的问题。

问题三:LSTM中,关于cell state 和 hidden state

TensorFlow中使用tf.contrib.rnn.BasicLSTMCell()的方法来搭建LSTM网络,其中有一项参数为state_is_tuple,官方建议设置为True,这个参数的实际就是使LSTM的state以tuple的形式输出,shape为batch size和num_step。也就是min batch的大小和LSTM展开的步数,而state分为c与h,c为cell state,即memory的state,h为hidden state,也就是lstm的最终输出。

LSTM单元根据输入的序列长度进行展开,当前时刻得到的state又作为参数传入下一时刻。代码如下:

        with tf.variable_scope("LSTM_layer"):
            for time_step in range(num_step):
                if time_step>0: tf.get_variable_scope().reuse_variables()
                (cell_output,state)=cell(inputs[:,time_step,:],state)
                out_put.append(cell_output)

在网络训练时,通常需要在每个batch后对LSTM的cell state行进重置和归零,代码如下所示。

        state = session.run(model._initial_state)
        for i , (c,h) in enumerate(model._initial_state):
            feed_dict[c]=state[i].c
            feed_dict[h]=state[i].h

https://blog.csdn.net/ZJRN1027/article/details/80301039

你可能感兴趣的:(RNN LSTM 网络参数问题)