For the details of the RNN model itself, refer here.
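For reference, the recurrence that nn.RNN computes at each time step (with its default tanh nonlinearity) is

$$h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh})$$

where $x_t$ is the input at step $t$, $h_{t-1}$ is the hidden state from the previous step, and the weight matrices and biases are the layer's learnable parameters.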
The code for predicting the next segment of a sine waveform with an RNN is as follows:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
num_time_steps = 50
input_size = 1
hidden_size = 16
num_layers = 1
output_size = 1
# Define the model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True  # tensors are laid out as [batch, seq, feature]
        )
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden_prev):
        out, hidden_prev = self.rnn(x, hidden_prev)
        out = out.view(-1, hidden_size)  # [1, seq, h] => [seq, h]
        out = self.linear(out)           # [seq, h] => [seq, 1]
        out = out.unsqueeze(dim=0)       # => [1, seq, 1]
        return out, hidden_prev
# Model, loss function, and optimizer
model = Net()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
# Training
hidden_prev = torch.zeros(num_layers, 1, hidden_size)  # initial hidden state: [layers, batch, hidden]
for it in range(6000):
    # sample a random 50-point window of the sine curve
    start = np.random.randint(3, size=1)[0]
    time_steps = np.linspace(start, start + 10, num_time_steps)
    data = np.sin(time_steps)
    data = data.reshape(num_time_steps, 1)
    # x: the first 49 points; y: the same series shifted one step ahead
    x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
    y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
    output, hidden_prev = model(x, hidden_prev)
    hidden_prev = hidden_prev.detach()  # detach so gradients do not flow back across iterations
    loss = criterion(output, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if it % 100 == 99:
        print('Iteration: {0}, Loss: {1}'.format(it + 1, loss.item()))
# Testing: feed the model its own predictions to roll the waveform forward
start = np.random.randint(3, size=1)[0]
time_steps = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_steps)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
predictions = []
inp = x[:, 0, :]  # seed with the first true point ('inp' rather than 'input' to avoid shadowing the builtin)
for _ in range(x.shape[1]):
    inp = inp.view(1, 1, 1)
    pred, hidden_prev = model(inp, hidden_prev)  # the hidden state carried over from training is reused here
    inp = pred  # each prediction becomes the next input
    predictions.append(pred.detach().numpy().ravel()[0])
# Plotting: true input points vs. the model's one-step-ahead predictions
x = x.data.numpy().ravel()
plt.plot(time_steps[:-1], x, color="blue", linewidth=2, marker="o", markersize=8, label='True')
plt.scatter(time_steps[1:], predictions, c='r', label='Predict')
plt.legend(loc='lower left')
plt.xlabel('x')
plt.ylabel('y')
plt.show()
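One caveat before the results: the script does not fix any random seeds, so the loss values below will differ from run to run. If reproducibility matters, seeding both NumPy and PyTorch at the top of the script is enough for this example; a minimal sketch (the seed value 0 is arbitrary):

np.random.seed(0)     # fixes the random window starts drawn by np.random.randint
torch.manual_seed(0)  # fixes the RNN/Linear weight initialization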
Run output:
Iteration: 100, Loss: 0.08230462670326233
Iteration: 200, Loss: 0.03460197150707245
Iteration: 300, Loss: 0.026081932708621025
Iteration: 400, Loss: 0.016143618151545525
Iteration: 500, Loss: 0.0021185846999287605
Iteration: 600, Loss: 0.008243701420724392
Iteration: 700, Loss: 0.005049371160566807
Iteration: 800, Loss: 0.0011325869709253311
Iteration: 900, Loss: 0.0044587296433746815
Iteration: 1000, Loss: 0.003480890765786171
Iteration: 1100, Loss: 0.0004979512887075543
Iteration: 1200, Loss: 0.0015343809500336647
Iteration: 1300, Loss: 0.001767122419551015
Iteration: 1400, Loss: 0.007279172074049711
Iteration: 1500, Loss: 0.006428186781704426
Iteration: 1600, Loss: 0.005187650211155415
Iteration: 1700, Loss: 0.0006386577151715755
Iteration: 1800, Loss: 0.005688886158168316
Iteration: 1900, Loss: 0.0015109286177903414
Iteration: 2000, Loss: 0.007405400741845369
Iteration: 2100, Loss: 0.00040985929081216455
Iteration: 2200, Loss: 0.0008332210127264261
Iteration: 2300, Loss: 0.0018269600113853812
Iteration: 2400, Loss: 0.00014836857735645026
Iteration: 2500, Loss: 0.0007819193997420371
Iteration: 2600, Loss: 0.004233400337398052
Iteration: 2700, Loss: 0.004981044214218855
Iteration: 2800, Loss: 0.0009075272246263921
Iteration: 2900, Loss: 0.0028365375474095345
Iteration: 3000, Loss: 0.0008490863256156445
Iteration: 3100, Loss: 0.0019383304752409458
Iteration: 3200, Loss: 0.0040445877239108086
Iteration: 3300, Loss: 0.003862386103719473
Iteration: 3400, Loss: 0.0005976956454105675
Iteration: 3500, Loss: 0.0004794742853846401
Iteration: 3600, Loss: 0.0030074138194322586
Iteration: 3700, Loss: 0.00015448348131030798
Iteration: 3800, Loss: 0.002118135569617152
Iteration: 3900, Loss: 0.0019309526542201638
Iteration: 4000, Loss: 0.0010463208891451359
Iteration: 4100, Loss: 0.0017843478126451373
Iteration: 4200, Loss: 0.0017833326710388064
Iteration: 4300, Loss: 0.0016676844097673893
Iteration: 4400, Loss: 0.0015999673632904887
Iteration: 4500, Loss: 0.00022673775674775243
Iteration: 4600, Loss: 0.0005361718940548599
Iteration: 4700, Loss: 0.0001656347158132121
Iteration: 4800, Loss: 0.0017306145746260881
Iteration: 4900, Loss: 0.002255344996228814
Iteration: 5000, Loss: 0.00025039984029717743
Iteration: 5100, Loss: 0.0008995591779239476
Iteration: 5200, Loss: 0.0003262818790972233
Iteration: 5300, Loss: 0.0016977142076939344
Iteration: 5400, Loss: 0.0009827131871134043
Iteration: 5500, Loss: 0.0017332301940768957
Iteration: 5600, Loss: 0.0014740482438355684
Iteration: 5700, Loss: 0.00046479827142320573
Iteration: 5800, Loss: 0.0005490729818120599
Iteration: 5900, Loss: 0.00042482122080400586
Iteration: 6000, Loss: 0.00046664406545460224
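As a quick sanity check on the shapes flowing through forward, the model can be probed with random data; a short sketch, assuming the Net class and the hyperparameters defined above:

model = Net()
x = torch.randn(1, num_time_steps - 1, input_size)  # [batch=1, seq=49, feature=1]
h0 = torch.zeros(num_layers, 1, hidden_size)        # [num_layers=1, batch=1, hidden=16]
out, h = model(x, h0)
print(out.shape)  # torch.Size([1, 49, 1]): one prediction per input step
print(h.shape)    # torch.Size([1, 1, 16]): final hidden state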