Usage of tensorflow.nn.bidirectional_dynamic_rnn()

Please credit the source when reposting: http://blog.csdn.net/wuzqchom/article/details/75453327

While taking part in the Zhihu competition, I felt that CNN hyperparameter tuning had nearly hit its ceiling, so I tried a bidirectional RNN. With tensorflow.nn.bidirectional_dynamic_rnn() it is very easy to build a bidirectional LSTM, and the code stays concise.
First, let's look at the function signature:

def bidirectional_dynamic_rnn(
    cell_fw,                # forward RNN cell
    cell_bw,                # backward RNN cell
    inputs,                 # input tensor
    sequence_length=None,   # actual lengths of the input sequences
                            # (optional; defaults to the maximum length)
    initial_state_fw=None,  # initial state of the forward RNN (optional)
    initial_state_bw=None,  # initial state of the backward RNN (optional)
    dtype=None,             # data type of the initial state and output (optional)
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,       # layout of the input/output tensors:
                            # if True, they must be shaped [max_time, batch_size, depth];
                            # if False, they must be shaped [batch_size, max_time, depth]
    scope=None
)
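
If your data is batch-major but you want time_major=True (which can be slightly more efficient, since it avoids internal transposes), you can transpose it yourself first. A minimal sketch; embedded_chars is the batch-major input tensor defined later in this post:

# [batch_size, max_time, depth] -> [max_time, batch_size, depth]
inputs_time_major = tf.transpose(embedded_chars, perm=[1, 0, 2])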


The definitions of cell_fw and cell_bw are exactly the same. If both cells are LSTM cells, the whole structure is a bidirectional LSTM:

# forward-direction LSTM cell
lstm_fw_cell = tf.nn.rnn_cell.BasicLSTMCell(embedding_size, forget_bias=1.0)
# backward-direction LSTM cell
lstm_bw_cell = tf.nn.rnn_cell.BasicLSTMCell(embedding_size, forget_bias=1.0)
  • Return value:
    A tuple (outputs, output_states), where:
    1. outputs is a tuple (output_fw, output_bw) containing the output tensors of the
    forward cell and the backward cell. Assuming time_major=False, each tensor has shape
    [batch_size, max_time, depth]. In my experiments I concatenated them with
    tf.concat(outputs, 2) (see the sketch after this list).
    2. output_states is a tuple (output_state_fw, output_state_bw) containing the final
    hidden states of the forward and backward passes. Both output_state_fw and
    output_state_bw are of type LSTMStateTuple. An LSTMStateTuple is a pair (c, h),
    where c is the memory cell state and h is the hidden state.
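
A minimal sketch of that concatenation (variable names follow the sample code later in this post):

output_fw, output_bw = outputs
# each output is [batch_size, max_time, num_units]; join them along the feature axis
rnn_outputs = tf.concat(outputs, 2)  # [batch_size, max_time, 2 * num_units]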


But wait: the two cells we pass in look identical, so where does the backward direction come from?
Internally, bidirectional_dynamic_rnn uses the array_ops.reverse_sequence function to reverse the input sequence before feeding it to the backward cell (and reverses the backward outputs again afterwards so they align in time), which is what produces the backward pass.
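
To see what reverse_sequence does, here is a tiny standalone sketch (the tensor and lengths are made up for illustration). Note that only the valid prefix of each sequence is reversed; the padding stays in place:

import tensorflow as tf

# 2 sequences of max length 4; the second is padded after step 2
x = tf.constant([[[1.], [2.], [3.], [4.]],
                 [[5.], [6.], [0.], [0.]]])

rev = tf.reverse_sequence(x, seq_lengths=[4, 2], seq_dim=1, batch_dim=0)

with tf.Session() as sess:
    print(sess.run(rev))
    # [[[4.] [3.] [2.] [1.]]
    #  [[6.] [5.] [0.] [0.]]]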
In practice, all we need to do is pass in the two cells:

(outputs, output_states) = tf.nn.bidirectional_dynamic_rnn(
    lstm_fw_cell, lstm_bw_cell, embedded_chars, dtype=tf.float32)

embedded_chars is the input tensor, of shape [batch_size, max_time, depth]. batch_size is the model's batch size; applied to text, max_time can be the sentence length (usually the length of the longest sentence, with shorter sentences padded), and depth is the dimensionality of the input word vectors.
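
For completeness, a hedged sketch of how such an embedded_chars tensor is typically produced from padded token ids (vocab_size, embedding_size, and max_time are assumed values, not from the original post):

import tensorflow as tf

vocab_size, embedding_size, max_time = 5000, 128, 30  # assumed for illustration
input_ids = tf.placeholder(tf.int32, [None, max_time])  # padded token ids
embedding = tf.get_variable("embedding", [vocab_size, embedding_size])
embedded_chars = tf.nn.embedding_lookup(embedding, input_ids)
# embedded_chars: [batch_size, max_time, embedding_size]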


Of course, you could also write the recurrence yourself as a loop over the sentence length, feeding each step the previous time step's hidden state (i.e., the previous word's, if words are the basic unit) together with the current input tensor. But that is comparatively tedious to write; bidirectional_dynamic_rnn() is much cleaner.
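For contrast, a rough sketch of what such a hand-written (forward-only) loop looks like in TF1-style code; batch_size and max_time are assumed to be known Python ints here:

outputs = []
state = lstm_fw_cell.zero_state(batch_size, tf.float32)
for t in range(max_time):
    if t > 0:
        tf.get_variable_scope().reuse_variables()  # share weights across time steps
    output, state = lstm_fw_cell(embedded_chars[:, t, :], state)
    outputs.append(output)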

This post only introduces bidirectional_dynamic_rnn at the API level and does not explore its internals. There are also related engineering optimizations in this area, such as the bucketing mechanism.
For details, see the Zhihu question "Why does the seq2seq example in TensorFlow need buckets?" (answer by Yangqing Jia).
For the difference between dynamic_rnn and the plain rnn, see another Zhihu question: "What is the difference between dynamic_rnn and rnn in TensorFlow?"

To wrap up: when building a bi-directional LSTM, the key TensorFlow function is bidirectional_dynamic_rnn. Below is a summary of its input and output behavior.

Sample code:

import numpy as np
import tensorflow as tf

# Create input data: batch of 2 examples, 10 time steps, 8 features
X = np.random.randn(2, 10, 8)

# The second example is of length 6; zero out the padding steps
X[1, 6:] = 0
X_lengths = [10, 6]

cell = tf.nn.rnn_cell.LSTMCell(num_units=20, state_is_tuple=True)

outputs, states = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=cell, cell_bw=cell,
    dtype=tf.float64, sequence_length=X_lengths, inputs=X)

output_fw, output_bw = outputs
states_fw, states_bw = states
  1. Providing the sequence_length parameter. This is a vector with one entry per example in the batch, giving that example's actual length along the first (time) dimension of the input. In the example above there are two examples, each of dimension (10, 8), so the maximum length is 10; since the second example only runs for 6 steps, we pass [10, 6]. This parameter matters: if it is wrong, the results will be wrong.


  2. LSTMStateTuple. LSTMStateTuple is the format TensorFlow uses to represent LSTM state: state_size, zero_state, and the output state are all expressed as LSTMStateTuples. An LSTMStateTuple is a pair (c, h), where c is the cell (memory) state and h is the hidden/output state.
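
As a small illustration (not from the original post), an LSTMStateTuple can also be constructed directly, e.g. to pass a custom initial state:

c = tf.zeros([2, 20])  # cell (memory) state
h = tf.zeros([2, 20])  # hidden/output state
init_state = tf.nn.rnn_cell.LSTMStateTuple(c, h)
# could be passed as initial_state_fw to bidirectional_dynamic_rnn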

Using the example above, let's print the c and h of states_fw (the forward cell's state):

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    states_shape = tf.shape(states)
    print(states_shape.eval())
    c, h = states_fw
    print('print c')
    print(sess.run(c))
    print('print c[1]')
    print(sess.run(c[1]))

Output of c:

[[ 0.10020639  0.04123133  0.1607345   0.04516676  0.14993827  0.04191888
   0.37377536  0.155445    0.0250121   0.19496255  0.14872166 -0.27649689
  -0.04177214 -0.64201794  0.33243907  0.19359499 -0.36642676  0.21112514
  -0.22745071 -0.26275169]
 [-0.11904546  0.31178861 -0.07443244  0.09425232  0.49845703 -0.03854894
  -0.19182923  0.21352997 -0.46194925  0.16779708 -0.33193142 -0.19566778
  -0.06612427 -0.11908099 -0.00726331  0.33335469  0.01462403 -0.58337865
  -0.07321394 -0.05640467]]

Output of h:

[[ 0.15347487 -0.11674097 -0.09588729  0.00890338  0.12735097  0.09076786
   0.06652527  0.00207616  0.0500825   0.12345199 -0.01998148  0.08035006
   0.13353144 -0.01102215 -0.09175959  0.1455235   0.11431857  0.15356262
  -0.0725327  -0.03418285]
 [ 0.15409209  0.09243608 -0.05506054 -0.07781891  0.0971408   0.06453969
   0.03290611 -0.09163908  0.01231368  0.15045137  0.06517371  0.1267243
   0.02084229 -0.16214882 -0.20116859  0.24669899  0.04287307  0.04801212
  -0.02658258 -0.10132215]]

Both c and h have shape (2, 20): 2 because the batch contains two examples, and 20 because the cell's unit size is 20. So for each example in the batch, the LSTM produces a forward state consisting of c (the cell/memory state) and h (the hidden/output state), each a vector of the unit size. The backward state is organized the same way.

  3. Output
    The output is likewise split into a forward-cell part and a backward-cell part; let's look at the forward cell's output:
[[[-1.06669103e-01 -3.88756185e-02  2.60744107e-02  2.43657585e-03
    9.88481327e-02 -1.15583728e-01 -1.56563918e-01  1.58782306e-01
    4.85588806e-02 -5.77454750e-02  3.36419355e-02 -1.98287352e-02
   -1.10931715e-01  3.82711717e-02 -9.20855234e-02 -9.49620258e-02
    6.52089598e-02 -3.53805118e-02  1.30414250e-03  6.59167148e-02]
  [ 1.59046099e-02  4.41911904e-02 -1.85633770e-02 -1.02646872e-01
    5.67889560e-02 -1.80578232e-01  1.16289962e-02  3.43832302e-02
   -9.97417632e-02  8.52767259e-03  4.19908236e-02 -2.92100595e-02
   -7.40990631e-03 -5.57406104e-02 -9.63055482e-03  6.64506140e-02
   -1.60809539e-01 -1.13207020e-01  6.67947362e-02 -1.01494835e-01]
  [ 1.29653547e-01  1.64123612e-01  1.14319183e-01 -3.13978750e-01
    3.21875813e-02 -1.20610558e-01  1.65674751e-01  9.67606455e-02
   -7.53250048e-02  1.64018976e-01  8.14044036e-02 -2.66330649e-01
   -8.32540716e-02 -2.30083245e-01 -3.66429645e-02  2.69508256e-01
   -3.16302908e-01 -7.56739776e-02 -7.03734257e-02 -7.27750202e-02]
  [ 2.10778175e-01  1.29119904e-01 -5.28395789e-02 -1.56958078e-01
   -5.97345897e-02  1.20157188e-01  2.17347619e-01  5.99727875e-02
   -1.64570565e-01  8.08612044e-02  2.07909278e-02 -2.07703283e-01
   -1.63712849e-02 -2.01749788e-01  1.14587708e-01  2.06933175e-01
   -2.55334961e-01 -6.03161794e-02 -1.97578049e-01 -1.72242306e-01]
  [ 3.00083183e-01  1.51446013e-01  4.82394269e-02 -3.39348109e-01
   -3.38322196e-01 -2.43085569e-01  1.82949930e-01  2.40813002e-01
   -2.80735872e-01  2.41043685e-01 -9.54472693e-02 -2.93748317e-01
    5.88980390e-02 -1.35271864e-01  2.89699583e-01  8.75721403e-02
   -3.55877522e-01 -1.33211501e-01 -7.57881317e-02 -2.64836250e-01]
  [ 2.41987682e-01  1.47695318e-02 -5.03866150e-02 -1.44482469e-01
   -1.44640164e-02 -6.75506747e-02  2.32746312e-01  1.47519780e-01
   -6.56321553e-02  1.60696093e-01  5.45594683e-03 -1.88477690e-01
    5.45150185e-02 -1.77628408e-01  1.05268972e-01  1.41610215e-01
   -1.60580096e-01 -7.24836242e-03 -1.00759851e-01 -1.47514166e-01]
  [ 2.12808506e-01 -3.27166227e-02  4.31225953e-03 -1.02816763e-01
    3.25901490e-03  3.66293909e-02  1.54212310e-01  1.69800784e-01
    1.10284434e-02  1.74149175e-01  3.41514708e-02 -1.91123912e-01
    3.02006378e-02 -1.99656590e-01  2.26262571e-02  2.52412919e-01
   -7.68398775e-02  5.94911484e-02 -1.31153846e-01 -1.20808211e-01]
  [ 3.81763504e-01  2.68575596e-02 -1.08793781e-01 -8.00019483e-02
   -3.68294635e-02  1.71446728e-01  1.18992211e-01 -2.13071169e-02
   -4.61473814e-02  1.82351966e-01 -8.44481138e-02  1.19407754e-02
    4.08584125e-02 -1.80411471e-01 -7.74698125e-03  1.93662041e-01
   -1.20557645e-01 -2.10084183e-02 -2.27600119e-01 -1.63846952e-01]
  [ 1.64898924e-01  6.37549641e-02  1.66957306e-02 -1.59360332e-01
   -1.51426048e-01  1.28056643e-01  2.85171791e-01  3.50425360e-04
   -1.99342025e-02  1.89266634e-01 -6.53912049e-02 -6.65443778e-02
    7.44334992e-02  4.85016031e-02  9.49079271e-02  3.24700401e-01
   -1.00750007e-01  2.18841138e-02 -1.61316038e-01  1.52122726e-02]
  [ 1.53474875e-01 -1.16740969e-01 -9.58872865e-02  8.90338491e-03
    1.27350972e-01  9.07678551e-02  6.65252728e-02  2.07615713e-03
    5.00825032e-02  1.23451987e-01 -1.99814751e-02  8.03500562e-02
    1.33531444e-01 -1.10221475e-02 -9.17595887e-02  1.45523503e-01
    1.14318569e-01  1.53562622e-01 -7.25326998e-02 -3.41828502e-02]]

 [[ 2.24818909e-01  2.21235678e-02  1.01460432e-01 -1.41914365e-01
   -1.21404939e-01  7.36078879e-02  1.38471242e-01 -1.17533437e-01
   -5.21530141e-03  1.67706170e-01 -7.17727515e-02 -1.06750419e-01
    7.13189845e-02 -9.07184818e-02  2.11111214e-02  1.81368716e-01
   -1.46839530e-01  2.14554598e-02 -7.90004557e-02  8.87259097e-02]
  [ 2.11738134e-01  2.06868083e-02 -1.42999066e-01 -5.44685789e-02
   -6.31460261e-02  1.86872216e-01  1.22599483e-01 -1.82293974e-01
    8.76017957e-02  7.64068221e-02 -5.74839315e-02  8.27909362e-02
    5.49907143e-02 -8.06081683e-02 -4.65603130e-02  1.07644840e-01
   -8.45501653e-02 -4.02021538e-02 -8.38841808e-02  3.63420987e-02]
  [ 2.65380968e-01 -5.11942699e-03 -1.10961564e-02 -1.96348422e-01
   -1.11433399e-01  1.53275799e-02  6.00570999e-02 -2.05297778e-01
    8.52545915e-02  1.97091206e-01 -7.42037228e-02  1.25797496e-01
    1.08283714e-01 -1.29675158e-01 -1.32684022e-01  1.19353210e-01
   -1.22913400e-01 -1.09450277e-01 -1.97762286e-02  2.60753532e-02]
  [ 2.29800970e-01  6.64311636e-02 -3.45172340e-02 -1.56474836e-01
   -4.81899131e-02  9.00044045e-02  8.26513916e-02 -1.33626283e-01
    1.37496640e-01  1.72760619e-01  9.74954132e-03  2.40818003e-02
    1.28755599e-02 -2.39148477e-01 -2.11945339e-01  1.92631382e-01
   -1.23300797e-01 -1.74345945e-02 -7.96618285e-02 -6.94683079e-03]
  [ 1.95140846e-01  1.06431901e-01 -9.20244228e-02 -2.09311995e-01
   -5.64830252e-03  9.53098517e-02  4.49136154e-02 -1.55642596e-01
   -4.00256764e-02  1.03820451e-01 -9.09035922e-02  1.30894101e-01
   -1.71891357e-02 -8.76164608e-02 -8.98778574e-02  7.59155122e-02
   -7.54771617e-02 -1.34889843e-01 -4.58820217e-02 -7.81068266e-02]
  [ 1.54092089e-01  9.24360827e-02 -5.50605385e-02 -7.78189060e-02
    9.71408042e-02  6.45396884e-02  3.29061070e-02 -9.16390804e-02
    1.23136831e-02  1.50451370e-01  6.51737096e-02  1.26724300e-01
    2.08422868e-02 -1.62148824e-01 -2.01168589e-01  2.46698990e-01
    4.28730731e-02  4.80121224e-02 -2.65825827e-02 -1.01322148e-01]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]]

output_fw has shape (2, 10, 20). So for every input vector of each example in the batch (each example here is 10 × 8, i.e. 10 vectors), the LSTM emits a 20-dimensional vector (20 is the unit size).
Notice that the h we saw in the state is exactly the last valid vector in the output! The second example's h is not zero even though its trailing outputs are, because of the sequence_length parameter (the LSTM stops after 6 steps, so the state is taken from step 6).

So if you only need the LSTM's final output and don't care about the intermediate steps (for example, when building a classifier rather than a seq2seq model), you can simply take the state output.
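
For example, a hedged sketch of doing exactly that for a classifier, reusing the names from the sample code above (num_classes is an assumed parameter, and tf.layers.dense is just one way to add the final projection):

# concatenate the last hidden states of both directions into one feature vector
final_h = tf.concat([states_fw.h, states_bw.h], axis=1)  # [batch_size, 2 * num_units]
logits = tf.layers.dense(final_h, num_classes)  # num_classes assumed for illustration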


Author: stepsma
Link: https://www.jianshu.com/p/3540c6711d4f
Source: Jianshu (简书)
Copyright belongs to the author. For commercial reposting, please contact the author for authorization; for non-commercial reposting, please credit the source.
