tensorflow系列:搭建LSTM

LSTM网络:

tensorflow系列:搭建LSTM_第1张图片

LSTM网络的权重设置:

tensorflow系列:搭建LSTM_第2张图片

LSTM函数参数的含义:

从图片和公式可知,LSTM单元有单个输入(Ct-1,ht-1,xt),三个输出(Ct,ht,ht)。

  1. 参数说明:
    num_units:int类型,LSTM单元中的神经元数量,即输出神经元数量 forget_bias:float类型,偏置增加了忘记门。从CudnnLSTM训练的检查点(checkpoin)恢复时,必须手动设置为0.0。
    state_is_tuple:如果为True,则接受和返回的状态是c_state和m_state的2-tuple;如果为False,则他们沿着列轴连接。后一种即将被弃用。
    activation:内部状态的激活函数。默认为tanh
    reuse:布尔类型,描述是否在现有范围中重用变量。如果不为True,并且现有范围已经具有给定变量,则会引发错误。
    name:String类型,层的名称。具有相同名称的层将共享权重,但为了避免错误,在这种情况下需要reuse=True.
    dtype:该层默认的数据类型。默认值为None表示使用第一个输入的类型。在call之前build被调用则需要该参数。

LSTM代码实例

import tensorflow as tf

output_dim=128
 
lstm=tf.nn.rnn_cell.BasicLSTMCell(output_dim)
 
batch_size=10 #批处理大小

embedding_dim=300 #词向量维度
 
inputs=tf.Variable(tf.random_normal([batch_size,embedding_dim]))
 
previous_state = (tf.random_normal(shape=(batch_size, output_dim)), tf.random_normal(shape=(batch_size, output_dim)))
 
output,(new_h, new_state)=lstm(inputs,previous_state)
 
print(output.shape) #(10, 128)
 
print(new_h.shape) #(10, 128)
 
print(new_state.shape) #(10, 128)

你可能感兴趣的:(tensorflow)