PyTorch Learning Notes 4: Handwritten Digit Recognition with an LSTM

These notes follow the tutorial by Morvan Zhou (莫烦).

import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
%matplotlib inline
torch.manual_seed(1)

EPOCH = 1        # number of training epochs
BATCH_SIZE = 64  # batch size
TIME_STEP = 28   # RNN time steps / image height
INPUT_SIZE = 28  # RNN input size / image width
LR = 0.01        # learning rate

train_data = dsets.MNIST(
    root='./MNIST/',
    train=True,
    transform=transforms.ToTensor(),  # ToTensor must be instantiated, not passed as a class
    download=False  # set to True to download the data on first run
)
train_loader = DataLoader(train_data, BATCH_SIZE, shuffle=True)
# prepare the test set
test_data = dsets.MNIST(root='./MNIST/', train=False, transform=transforms.ToTensor())
test_x = Variable(test_data.test_data, volatile=True).type(torch.FloatTensor)[:2000]/255.
test_y = test_data.test_labels.numpy().squeeze()[:2000]

class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=64,
            num_layers=1,
            batch_first=True
        )
        self.out = nn.Linear(64, 10)
    def forward(self, x):
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, hidden_size)
        # h_n shape (n_layers, batch, hidden_size)
        # h_c shape (n_layers, batch, hidden_size)
        r_out, (h_n, h_c) = self.rnn(x, None)  # None initializes the hidden state to zeros
        o = self.out(r_out[:, -1, :])  # take r_out at the last time step as the output
        return o
rnn = RNN()
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
# train the model
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):
        b_x = Variable(x.view(-1, 28, 28))
        b_y = Variable(y)
        output = rnn(b_x)
        loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 50 == 0:
            test_output = rnn(test_x)
            pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
            acc = (pred_y == test_y).sum() / float(test_y.size)  # test_y is a NumPy array, so .size is an attribute
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| test acc: %.2f' % acc)

The training loop is similar to the CNN case, so here I focus on the structure of the LSTM.
[Figure 1: a single LSTM cell]
Working through this example, the input $X_t$ to the network has size (64, 28, 28), corresponding to (batch_size, time_step, input_size): there are 28 time steps, i.e. 28 copies of the cell in Figure 1, and each batch holds 64 images. Consider a single cell first.
Within one cell at time step t, the input $X_t$ has size (64, 1, 28), where 28 is input_size. The hidden state from the previous step, $H_{t-1}$, has size (64, 1, 64), corresponding to (batch_size, 1, hidden_size), and the memory cell at time t-1 has the same size. The forget gate output is
$$f = \mathrm{sigmoid}(W_{f1} X_t + W_{f2} H_{t-1})$$
The input gate output is
$$i = \mathrm{sigmoid}(W_{i1} X_t + W_{i2} H_{t-1})$$
the candidate memory cell is
$$z = \tanh(W_{z1} X_t + W_{z2} H_{t-1})$$
and combining them gives the new cell state $c_t$,
$$c_t = f \cdot c_{t-1} + i \cdot z$$
The output gate is
$$o = \mathrm{sigmoid}(W_{o1} X_t + W_{o2} H_{t-1})$$
and the hidden state at time t, which at the final time step is the output of the whole LSTM, is
$$H_t = o \cdot \tanh(c_t)$$
After these computations, $c_t$ and $H_t$ still have size (64, 1, 64). Once step t finishes, step t+1 is computed, and so on until every time step is done, so the final output has size (64, 28, 64).
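The gate equations above can be traced step by step in code. The sketch below uses randomly initialized weight matrices (biases omitted, matching the equations) and works with 2-D tensors of shape (batch, features), dropping the singleton middle dimension:

```python
import torch

batch_size, input_size, hidden_size = 64, 28, 64
x_t = torch.randn(batch_size, input_size)      # X_t at one time step
h_prev = torch.zeros(batch_size, hidden_size)  # H_{t-1}, initialized to zeros
c_prev = torch.zeros(batch_size, hidden_size)  # c_{t-1}, initialized to zeros

# one weight pair per gate: W_*1 maps the input, W_*2 maps the hidden state
W_f1, W_f2 = torch.randn(input_size, hidden_size), torch.randn(hidden_size, hidden_size)
W_i1, W_i2 = torch.randn(input_size, hidden_size), torch.randn(hidden_size, hidden_size)
W_z1, W_z2 = torch.randn(input_size, hidden_size), torch.randn(hidden_size, hidden_size)
W_o1, W_o2 = torch.randn(input_size, hidden_size), torch.randn(hidden_size, hidden_size)

f = torch.sigmoid(x_t @ W_f1 + h_prev @ W_f2)  # forget gate
i = torch.sigmoid(x_t @ W_i1 + h_prev @ W_i2)  # input gate
z = torch.tanh(x_t @ W_z1 + h_prev @ W_z2)     # candidate memory cell
c_t = f * c_prev + i * z                       # new cell state
o = torch.sigmoid(x_t @ W_o1 + h_prev @ W_o2)  # output gate
h_t = o * torch.tanh(c_t)                      # new hidden state

print(c_t.shape, h_t.shape)  # both (64, 64) = (batch_size, hidden_size)
```

Both $c_t$ and $H_t$ come out with hidden_size features per sample, ready to feed into the cell at step t+1.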
[Figure 2: the LSTM unrolled over all time steps]
Figure 2 shows the LSTM's computation unrolled over all time steps. This program sets num_layers=1, i.e. the depth in Figure 2 is 1.
A fully connected layer then maps the hidden state at the final time step to the 10 digit classes.
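To double-check the shape bookkeeping end to end, a throwaway `nn.LSTM` and `nn.Linear` with the same hyperparameters as the model can be run on random input:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=28, hidden_size=64, num_layers=1, batch_first=True)
fc = nn.Linear(64, 10)

x = torch.randn(64, 28, 28)   # (batch, time_step, input_size)
r_out, (h_n, c_n) = lstm(x)   # omitting the hidden state initializes it to zeros
print(r_out.shape)            # one 64-dim hidden vector per time step
logits = fc(r_out[:, -1, :])  # keep only the last time step
print(logits.shape)           # one score per digit class
```

r_out comes out as (64, 28, 64) and logits as (64, 10), matching the sizes derived above.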
