莫烦 PyTorch RNN 回归

我的视频学习笔记

视频地址：https://www.bilibili.com/video/av15997678?p=23

import torch
from torch import nn
import numpy as np
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Hyper parameters
# NOTE(review): EPOCH and BATCH_SIZE are not referenced by the training loop below.
EPOCH = 1
BATCH_SIZE = 64
LR = 0.01  # learning rate handed to the optimizer
TIME_STEP = 28  # number of time points the RNN sees per sequence window
INPUT_SIZE = 1  # number of values fed to the RNN at each time point

class RNN(nn.Module):
    """Elman RNN followed by a linear read-out applied at every time step.

    Maps a batch-first sequence ``(batch, time_step, input_size)`` to
    ``(batch, time_step, output_size)`` and also returns the final hidden
    state, so a caller can carry it across consecutive windows.
    """

    def __init__(self, input_size=None, hidden_size=32, num_layers=1, output_size=1):
        """Build the recurrent layer and the per-step linear head.

        Args:
            input_size: features per time step. When omitted, the module-level
                ``INPUT_SIZE`` constant is used (the original behaviour).
            hidden_size: width of the RNN hidden state; also the Linear input
                width, so the two can no longer drift apart.
            num_layers: number of stacked RNN layers.
            output_size: features predicted per time step.
        """
        super().__init__()
        if input_size is None:
            # Looked up at construction time, not at class definition time,
            # so the class does not depend on INPUT_SIZE existing at import.
            input_size = INPUT_SIZE
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,  # x / r_out are (batch, time_step, features)
        )
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x, h_state):
        """Run the RNN over ``x`` and project every time step's output.

        Args:
            x: input of shape (batch, time_step, input_size).
            h_state: initial hidden state of shape
                (num_layers, batch, hidden_size), or ``None`` to let PyTorch
                start from zeros. ``batch_first`` does not apply to it.

        Returns:
            ``(outputs, h_state)``. ``outputs`` has shape
            (batch, time_step, output_size). ``h_state`` is the final hidden
            state, with shape (num_layers, batch, hidden_size).
        """
        # r_out: (batch, time_step, hidden_size) -- hidden output at each step.
        r_out, h_state = self.rnn(x, h_state)
        # nn.Linear operates on the last dimension, so a single call projects
        # every time step at once. This is equivalent to looping over
        # r_out[:, t, :] and stacking along dim=1, without the Python loop.
        return self.out(r_out), h_state


rnn = RNN()
print(rnn)

optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)  # optimize all RNN parameters
loss_func = nn.MSELoss()

# Hidden state carried from one window to the next; None makes the first
# forward pass start from a zero hidden state.
h_state = None

plt.figure(1, figsize=(12, 5))
plt.ion()  # interactive mode so the figure updates while training runs

for step in range(50):
    # Each step covers the next span of length pi: [step*pi, (step+1)*pi].
    start, end = step * np.pi, (step + 1) * np.pi
    # Use sin(t) as the input sequence to predict cos(t).
    steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
    x_np = np.sin(steps)
    y_np = np.cos(steps)

    # Add batch and feature axes: shape (batch=1, time_step, input_size=1).
    x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])
    y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])

    prediction, h_state = rnn(x, h_state)
    # !!! Important: cut the carried hidden state out of this step's graph.
    # Without this, the next backward() would try to propagate through the
    # previous window's graph, which was already freed. That raises a
    # RuntimeError. detach() is the supported replacement for `.data` here.
    h_state = h_state.detach()

    loss = loss_func(prediction, y)
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()  # back-propagate through this window only
    optimizer.step()

    # plot: target in red, prediction in blue
    plt.plot(steps, y_np.flatten(), 'r-')
    plt.plot(steps, prediction.detach().numpy().flatten(), 'b-')
    plt.draw()
    plt.pause(0.5)

plt.ioff()
plt.show()

你可能感兴趣的:(代码小笔记)