Passing state between batches
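Here every batch is the next segment of one continuous sine wave, so the final (h, c) of each step is fed back in as the initial state of the next step. The state is detached first, so gradients are truncated at the batch boundary while the forward pass keeps its memory.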
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(900)
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=2, num_layers=1, batch_first=True)
        self.out = nn.Linear(2, 1)

    def forward(self, x, state):
        # x: (batch, seq_len, input_size)
        # state: (h, c), each (num_layers, batch, hidden_size)
        # r_out: (batch, seq_len, hidden_size)
        r_out, state = self.lstm(x, state)  # (batch, seq_len, 1) -> (batch, seq_len, 2)
        outs = self.out(r_out)              # (batch, seq_len, 2) -> (batch, seq_len, 1)
        return outs, state
rnn = RNN()
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.02)  # updates W_hh and W_ih together
mse = nn.MSELoss()
state = (torch.randn(1, 1, 2), torch.randn(1, 1, 2))  # random initial (h_0, c_0), each (num_layers, batch, hidden_size)
for step in range(100):
    # build the data: the next half-period of the sine curve
    start, end = step * np.pi, (step + 1) * np.pi
    steps = torch.linspace(start, end, 10)
    x = torch.sin(steps).unsqueeze(0).unsqueeze(2)  # (1, 10, 1)
    y = torch.cos(steps).unsqueeze(0).unsqueeze(2)
    # train
    prediction, state = rnn(x, state)
    # detach so backprop stops at the batch boundary (truncated BPTT);
    # the state values still carry over to the next batch
    state = (state[0].detach(), state[1].detach())
    loss = mse(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # visualize the result
    plt.plot(steps, y.numpy().flatten(), 'r-')
    plt.plot(steps, prediction.detach().numpy().flatten(), 'b-')
    plt.draw(); plt.pause(0.05)
plt.show()
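The detach step is what makes this pattern work. A minimal sketch (not from the original post) of what goes wrong without it: the state returned at step 0 still hangs onto step 0's computation graph, so the backward() at step 1 walks back into a graph that PyTorch has already freed and raises a RuntimeError ("Trying to backward through the graph a second time").

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=1, hidden_size=2, batch_first=True)
state = None  # first call falls back to a zero state
for step in range(2):
    x = torch.randn(1, 10, 1)
    out, state = lstm(x, state)  # state carried over WITHOUT .detach()
    out.sum().backward()         # step 0 is fine; step 1 raises RuntimeError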
Not passing state between batches
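The same model and data, but self.lstm(x) is called without an initial state, so every batch starts from the default zero state and no information crosses batch boundaries.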
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1000)
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=2, num_layers=1, batch_first=True)
        self.out = nn.Linear(2, 1)

    def forward(self, x):
        # x: (batch, seq_len, input_size); r_out: (batch, seq_len, hidden_size)
        r_out, _ = self.lstm(x)  # no state passed in: defaults to zeros on each call
        outs = self.out(r_out)   # (batch, seq_len, 2) -> (batch, seq_len, 1)
        return outs
rnn = RNN()
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.02)  # updates W_hh and W_ih together
mse = nn.MSELoss()
for step in range(100):
    # build the data: the next half-period of the sine curve
    start, end = step * np.pi, (step + 1) * np.pi
    steps = torch.linspace(start, end, 10)
    x = torch.sin(steps).unsqueeze(0).unsqueeze(2)  # (1, 10, 1)
    y = torch.cos(steps).unsqueeze(0).unsqueeze(2)
    # train
    prediction = rnn(x)
    loss = mse(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # visualize the result
    plt.plot(steps, y.numpy().flatten(), 'r-')
    plt.plot(steps, prediction.detach().numpy().flatten(), 'b-')
    plt.draw(); plt.pause(0.05)
plt.show()
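Which variant to use depends on the data: carry the (detached) state when consecutive batches are consecutive slices of one long sequence, as with this sine wave; rely on the default zero state when batches are independent samples. Here the stateful version enters each batch already knowing where it is on the curve, while the stateless version must infer the phase from the ten points of the current batch alone.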
https://zhuanlan.zhihu.com/p/94757947
https://discuss.pytorch.org/t/lstm-how-to-remember-hidden-and-cell-states-across-different-batches/11957
https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html