PyTorch LSTM regression

Passing the LSTM state (h, c) between batches

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(900)

class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=2, num_layers=1, batch_first=True)
        self.out = nn.Linear(2, 1)
    def forward(self, x, state):
        # x: (batch, seq_len, input_size); state: (h, c), each (num_layers, batch, hidden_size)
        r_out, state = self.lstm(x, state)  # (batch, seq_len, input_size=1) -> (batch, seq_len, hidden_size=2)
        outs = self.out(r_out)              # (batch, seq_len, hidden_size=2) -> (batch, seq_len, 1)
        return outs, state
rnn = RNN()

optimizer = torch.optim.Adam(rnn.parameters(), lr=0.02)  # updates the LSTM weights (W_ih, W_hh) and the linear layer together
mse = nn.MSELoss()
state = (torch.randn(1, 1, 2), torch.randn(1, 1, 2))  # initial (h_0, c_0), each (num_layers, batch, hidden_size)
for step in range(100):
    # build the data: predict cos(t) from sin(t) over a sliding window of the time axis
    start, end = step * np.pi, (step + 1) * np.pi
    steps = torch.linspace(start, end, 10)
    x = torch.sin(steps).unsqueeze(0).unsqueeze(2)  # (batch=1, seq_len=10, input_size=1)
    y = torch.cos(steps).unsqueeze(0).unsqueeze(2)
    # training step
    prediction, state = rnn(x, state)
    # detach so the state values carry over, but gradients do not flow back into earlier batches
    state = (state[0].detach(), state[1].detach())
    loss = mse(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # visualize the result
    plt.plot(steps, y.numpy().flatten(), 'r-')
    plt.plot(steps, prediction.detach().numpy().flatten(), 'b-')
    plt.draw(); plt.pause(0.05)
plt.show()
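
The two detach() calls are what make the stateful version work: the hidden values carry over from one window to the next, but the autograd graph is cut, so loss.backward() only backpropagates through the current 10-step window (truncated BPTT). Without detaching, the second iteration would try to backpropagate through the previous batch's freed graph and fail. A minimal sketch of the same idea as a reusable helper (the name detach_state is mine, not from the original code):

def detach_state(state):
    # detach (h, c) from the autograd graph while keeping their values
    h, c = state
    return (h.detach(), c.detach())

# inside the training loop above:
# prediction, state = rnn(x, state)
# state = detach_state(state)
# loss.backward()  # only backpropagates through the current window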

Not passing the state between batches

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1000)

class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=2, num_layers=1, batch_first=True)
        self.out = nn.Linear(2, 1)
    def forward(self, x):
        # x: (batch, seq_len, input_size); r_out: (batch, seq_len, hidden_size)
        r_out, _ = self.lstm(x)  # no state passed: h_0 and c_0 default to zeros
        outs = self.out(r_out)   # (batch, seq_len, hidden_size=2) -> (batch, seq_len, 1)
        return outs
rnn = RNN()

optimizer = torch.optim.Adam(rnn.parameters(), lr=0.02)  # updates the LSTM weights (W_ih, W_hh) and the linear layer together
mse = nn.MSELoss()
for step in range(100):
    # build the data
    start, end = step * np.pi, (step+1)*np.pi
    steps = torch.linspace(start, end, 10)
    x = torch.sin(steps).unsqueeze(0).unsqueeze(2)
    y = torch.cos(steps).unsqueeze(0).unsqueeze(2)
    # training step
    prediction = rnn(x)
    loss = mse(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # visualize the result
    plt.plot(steps, y.numpy().flatten(), 'r-')
    plt.plot(steps, prediction.detach().numpy().flatten(), 'b-')
    plt.draw(); plt.pause(0.05)
plt.show()
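
When no initial state is passed, nn.LSTM initializes both h_0 and c_0 to zeros on every call, so each 10-step window is processed independently and nothing is remembered across batches. A small standalone sketch (not part of the original post) illustrating that omitting the state is equivalent to passing zeros explicitly:

import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=1, hidden_size=2, num_layers=1, batch_first=True)
x = torch.randn(1, 10, 1)  # (batch, seq_len, input_size)

out_default, _ = lstm(x)  # no state given: defaults to zeros
zeros = (torch.zeros(1, 1, 2), torch.zeros(1, 1, 2))  # (num_layers, batch, hidden_size)
out_zeros, _ = lstm(x, zeros)
print(torch.allclose(out_default, out_zeros))  # True

The stateful version above can exploit continuity across windows; this version treats every window as an independent sequence.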

References:
https://zhuanlan.zhihu.com/p/94757947
https://discuss.pytorch.org/t/lstm-how-to-remember-hidden-and-cell-states-across-different-batches/11957
https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html
