一、总体介绍
对于sin函数的输入样本进行训练,得到sin函数相关的RNN网络架构,预测sin函数的趋势
二、相关参数设置
HIDDEN_SIZE = 30 #隐藏层层数
NUM_LAYERS = 2 #RNN层数
TIMESTEPS = 10 #RNN的截断长度 利用10个点的信息 预测第11个点的信息
TRAIN_STEP = 10000 #训练轮数
TRAIN_EXAMPLES = 10000 #训练数据的个数
BATCH_SIZE = 32 #训练的batch
TEST_EXAMPLES = 10000 #测试数据的个数
SAMPLE_GAP = 0.01
LEARNNING_RATE = 0.1 #学习率
三、模拟数据产生函数
def generate_data(seq):
x = []
y = []
for i in range(len(seq) - TIMESTEPS):
x.append([seq[i:i+TIMESTEPS]])
y.append([seq[i+TIMESTEPS]])
return np.array(x,dtype = np.float32), np.array(y,dtype = np.float32)
四、构建基础LSTM单元(单层)函数
def LstmCell():
lstm_cell = rnn.BasicLSTMCell(HIDDEN_SIZE,state_is_tuple=True)
return lstm_cell
五、构建RNN–多层LSTM模型函数
5.1 构建多层LSTM模型
cell = rnn.MultiRNNCell([LstmCell() for _ in range(NUM_LAYERS)])
5.2 训练rnn,output为输出的结果,_ 返回的是最终的状态
output,_ = tf.nn.dynamic_rnn(cell,x,dtype=tf.float32)
5.3将output 重塑成 n×HIDDEN_SIZE 的矩阵,即每行属于同一层
output = tf.reshape(output,[-1, HIDDEN_SIZE])
5.4 创建一个全连接层,1 表示输出的维度为1,即做的是 n×HIDDEN_SIZE 的矩阵 和 HIDDEN_SIZE×1的矩阵相乘。None指的是不使用激活函数。
predictions = tf.contrib.layers.fully_connected(output, 1, None)
5.5取出最后一个数 即最终的输出结果
labels = tf.reshape(y, [-1]) #期望值
predictions = tf.reshape(predictions, [-1]) #预测值
5.6 得到均方损失
loss = tf.losses.mean_squared_error(predictions, labels)
5.7 优化函数
train_op = tf.contrib.layers.optimize_loss(
loss,
tf.contrib.framework.get_global_step(),
optimizer = "Adagrad",
learning_rate = LEARNNING_RATE)
六、搭建总体框架
6.1 建立多层RNN
regressor = learn.Estimator(model_fn = lstm_model)
这里利用到了learn的Estimator函数 传入参数为lstm_model构建RNN–多层LSTM模型的函数
6.2 产生数据
test_start = TRAIN_EXAMPLES * SAMPLE_GAP
test_end = (TRAIN_EXAMPLES + TEST_EXAMPLES)* SAMPLE_GAP
#产生仿真数据x y 用于训练
train_x,train_y = generate_data(np.sin(np.linspace(0, test_start,
TRAIN_EXAMPLES,dtype = np.float32)))
#产生仿真数据x y 用于测试
test_x,test_y = generate_data(np.sin(np.linspace(test_start,test_end,
TEST_EXAMPLES, dtype = np.float32)))
6.3 调用fit函数训练模型
regressor.fit(train_x,train_y,batch_size = BATCH_SIZE,steps = TRAIN_STEP)
6.4 利用训练好的模型进行预测 计算均方差
predicted = [[pred] for pred in regressor.predict(test_x)]
rmse = np.sqrt((predicted-test_y)**2).mean(axis=0)
6.5 绘图
fig = plt.figure()
plot_predicted = plt.plot(predicted,label = "predicted",color='red')
plot_test = plt.plot(test_y,label = "real_sin",color='blue')
plt.legend = ([plot_predicted,plot_test],['predicted','real_sin'])
plt.show()
七、结果
最近新开的公众号,文章正在一篇篇的更新,
公众号名称:玩转电子世界。
各位朋友有什么问题了可以直接在上面提问,我会一一进行解答的。
跟着阳光非宅男,一步步走进电子的世界。
关注之后回复 资料下载 关键词可以获得免费的视频学习资料下载~~~~!!
源码下载:http://download.csdn.net/download/yunge812/10269571