tensorflow学习之LSTMCell详解(Class tf.contrib.rnn.LSTMCell与Class tf.contrib.rnn.BasicLSTMCell的区别)

转载:https://blog.csdn.net/u013230189/article/details/82811066

Class tf.contrib.rnn.LSTMCell

继承自:LayerRNNCell

Aliases:

Class tf.contrib.rnn.LSTMCell
Class tf.nn.rnn_cell.LSTMCell
长短时记忆单元循环网络单元。默认的non-peephole是基于http://www.bioinf.jku.at/publications/older/2604.pdf

S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.

实现的,peephole的实现基于https://research.google.com/pubs/archive/43905.pdf

Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.

该类使用可选的peephole连接,可选的单元裁剪(cell clipping),可选的投影层。

LSTMCell与BasicLSTMCell的区别是该类增加了可选的peephole连接、单元裁剪(cell clipping)、投影层。

LSTMCell构造函数:

__init__(

    num_units,

    use_peepholes=False,

    cell_clip=None,

    initializer=None,

    num_proj=None,

    proj_clip=None,

    num_unit_shards=None,

    num_proj_shards=None,

    forget_bias=1.0,

    state_is_tuple=True,

    activation=None,

    reuse=None,

    name=None,

    dtype=None

)

参数说明:

num_units:LSTM cell中的单元数量,即隐藏层神经元数量。
use_peepholes:布尔类型,设置为True则能够使用peephole连接
cell_clip:可选参数,float类型,如果提供,则在单元输出激活之前,通过该值裁剪单元状态。
Initializer:可选参数,用于权重和投影矩阵的初始化器。
num_proj:可选参数,int类型,投影矩阵的输出维数,如果为None,则不执行投影。
pro_clip:可选参数,float型,如果提供了num_proj>0和proj_clip,则投影值将元素裁剪到[-proj_clip,proj_clip]范围。
num_unit_shards:弃用。
num_proj_shards:弃用。
forget_bias:float类型,偏置增加了忘记门。从CudnnLSTM训练的检查点(checkpoin)恢复时,必须手动设置为0.0。
state_is_tuple:如果为True,则接受和返回的状态是c_state和m_state的2-tuple;如果为False,则他们沿着列轴连接。后一种即将被弃用。
activation:内部状态的激活函数。默认为tanh
reuse:布尔类型,描述是否在现有范围中重用变量。如果不为True,并且现有范围已经具有给定变量,则会引发错误。
name:String类型,层的名称。具有相同名称的层将共享权重,但为了避免错误,在这种情况下需要reuse=True.
dtype:该层默认的数据类型。默认值为None表示使用第一个输入的类型。在call之前build被调用则需要该参数。
BasicLSTMCell构造函数:

__init__(

    num_units,

    forget_bias=1.0,

    state_is_tuple=True,

    activation=None,

    reuse=None,

    name=None,

    dtype=None

)

LSTMCell源码:

# i = input_gate, j = new_input, f = forget_gate, o = output_gate

    lstm_matrix = self._linear1([inputs, m_prev])

    i, j, f, o = array_ops.split(

        value=lstm_matrix, num_or_size_splits=4, axis=1)

    # Diagonal connections

    if self._use_peepholes and not self._w_f_diag:

      scope = vs.get_variable_scope()

      with vs.variable_scope(

          scope, initializer=self._initializer) as unit_scope:

        with vs.variable_scope(unit_scope):

          self._w_f_diag = vs.get_variable(

              "w_f_diag", shape=[self._num_units], dtype=dtype)

          self._w_i_diag = vs.get_variable(

              "w_i_diag", shape=[self._num_units], dtype=dtype)

          self._w_o_diag = vs.get_variable(

              "w_o_diag", shape=[self._num_units], dtype=dtype)

 

    if self._use_peepholes:

      c = (sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +

           sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))

    else:

      c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *

           self._activation(j))

 

    if self._cell_clip is not None:

      # pylint: disable=invalid-unary-operand-type

      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)

      # pylint: enable=invalid-unary-operand-type

    if self._use_peepholes:

      m = sigmoid(o + self._w_o_diag * c) * self._activation(c)

    else:

      m = sigmoid(o) * self._activation(c)

 

    if self._num_proj is not None:

      if self._linear2 is None:

        scope = vs.get_variable_scope()

        with vs.variable_scope(scope, initializer=self._initializer):

          with vs.variable_scope("projection") as proj_scope:

            if self._num_proj_shards is not None:

              proj_scope.set_partitioner(

                  partitioned_variables.fixed_size_partitioner(

                      self._num_proj_shards))

            self._linear2 = _Linear(m, self._num_proj, False)

      m = self._linear2(m)

 

      if self._proj_clip is not None:

        # pylint: disable=invalid-unary-operand-type

        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)

        # pylint: enable=invalid-unary-operand-type

 

    new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else

                 array_ops.concat([c, m], 1))

    return m, new_state

从上面构造函数和源码可以看出LSTMCell和BasicLSTMCell的区别:

增加了use_peepholes, bool值,为True时增加peephole。


增加了cell_clip, 浮点值,把cell的值限制在 ±cell_clip内
c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
增加了num_proj(int)和proj_clip(float), 相对于BasicLSTMCell,在输出m计算完之后增加了一层线性变换,并限制了输出的值
m = _linear(m, self._num_proj, bias=False, scope=scope)
 
m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
代码实例:

import tensorflow as tf
 
batch_size=10
 
embedding_dim=300
 
inputs=tf.Variable(tf.random_normal([batch_size,embedding_dim]))
 
previous_state=(tf.Variable(tf.random_normal([batch_size,128])),tf.Variable(tf.random_normal([batch_size,128])))
 
lstmcell=tf.nn.rnn_cell.LSTMCell(128)
 
outputs,(h_state,c_state)=lstmcell(inputs,previous_state)
 
 
 
print(outputs.shape) #(10, 128)
 
print(h_state.shape) #(10, 128)
 
print(c_state.shape) #(10, 128)
 

你可能感兴趣的:(tensorflow学习之LSTMCell详解(Class tf.contrib.rnn.LSTMCell与Class tf.contrib.rnn.BasicLSTMCell的区别))