循环神经网络让神经网络有了记忆, 对于序列型的数据,循环神经网络能达到更好的效果.接着我将实战分析手写数字的 RNN分类
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1)
TIME_STEP = 10
INPUT_SIZE = 1
LR = 0.02
我们要用到的数据就是这样的一些数据, 用 sin 的曲线预测出 cos 的曲线,也即用sin拟合cos
steps = np.linspace(0, np.pi*2, 100, dtype=np.float32) # float32 用于之后转化为torch FloatTensor
x_np = np.sin(steps)
y_np = np.cos(steps)
plt.plot(steps, y_np, 'g-', label='target (cos)')
plt.plot(steps, x_np, 'b-', label='input (sin)')
plt.legend(loc='best')
plt.show()
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__()
self.rnn = nn.RNN(
input_size=INPUT_SIZE,
hidden_size=32,
num_layers=1,
batch_first=True,
)
self.out = nn.Linear(32, 1)
def forward(self, x, h_state):
# x (batch, time_step, input_size)
# h_state (n_layers, batch, hidden_size)
# r_out (batch, time_step, hidden_size)
r_out, h_state = self.rnn(x, h_state) #传入当前输出和上一个隐状态
print(r_out.size(1))
outs = [] # 用于保存所有的预测
for time_step in range(r_out.size(1)):
outs.append(self.out(r_out[:, time_step, :]))
return torch.stack(outs, dim=1), h_state
rnn = RNN()
print(rnn)
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
loss_func = nn.MSELoss()
h_state = None # 初始化第一个time step的上一个状态
plt.figure(1, figsize=(12, 5))
plt.ion()
for step in range(100):
start, end = step * np.pi, (step+1)*np.pi
steps = np.linspace(start, end, TIME_STEP, dtype=np.float32, endpoint=False)
x_np = np.sin(steps)
y_np = np.cos(steps)
x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])
y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
prediction, h_state = rnn(x, h_state)
h_state = h_state.data # repack the hidden state, break the connection from last iteration
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 作图
plt.plot(steps, y_np.flatten(), 'g-')
plt.plot(steps, prediction.data.numpy().flatten(), 'b-')
plt.draw()
plt.pause(0.05)
plt.ioff()
plt.show()
效果:
由于图比较多:这里仅贴上开始的2个迭代和最后的2个迭代图片。
开始的2个迭代
最后的2个迭代
可以看出, 我们使用 x 作为输入的 sin 值, 然后 y 作为想要拟合的输出, cos 值. 因为他们两条曲线是存在某种关系的, 所以我们就能用 sin 来预测 cos. rnn 会理解他们的关系, 并用里面的参数分析出来这个时刻 sin 曲线上的点如何对应上 cos 曲线上的点.