多层RNN的定义与理解

代码:


import tensorflow as tf
import numpy as np

def get_a_cell():
    ### 128 是 状态矢量的长度
    return tf.nn.rnn_cell.BasicRNNCell(num_units=128)
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])

print(cell.state_size)

## 32 是 batch_size ,100 是 inputs 矢量的长度
inputs = tf.placeholder(np.float32,shape=(32,100))
h0 = cell.zero_state(32,np.float32) ## 通过zero_state得到一个全0的初始状态(只需给出状态的矢量长度即可,因为状态肯定是矢量)

output,h1 = cell(inputs,h0)
print(output)
print(h1)

 

输出:

(128, 128, 128)
Tensor("multi_rnn_cell/cell_2/basic_rnn_cell/Tanh:0", shape=(32, 128), dtype=float32)
(
, 
, 

)


 

 

你可能感兴趣的:(tensorflow)