一、代码
import torch
from torch import nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
# Training hyper-parameters for the script below.
EPOCH = 1          # number of full passes over the training DataLoader
BATCH_SIZE = 64    # samples per mini-batch handed to the DataLoader
TIME_STEP = 28     # sequence length per sample (presumably one step per image row -- confirm)
INPUT_SIZE = 28    # features per time step; passed to nn.GRU as input_size
LR = 0.001         # learning rate given to the Adam optimizer
# Download (if not already present under ./mnist) and load both MNIST splits.
# ToTensor() is applied lazily when samples are fetched through the dataset
# (e.g. by the DataLoader); it does not affect the raw `.data` tensor.
train_data = datasets.MNIST(
    root='./mnist',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_data = datasets.MNIST(
    root='./mnist',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Fixed evaluation subset: the first 2000 test images and their labels.
# `.data` / `.targets` replace the deprecated `test_data` / `test_labels`
# attributes.  `.data` bypasses `transform`, so the pixel values are scaled
# by hand here.  Slicing before the cast avoids converting the entire split.
test_x = test_data.data[:2000].type(torch.FloatTensor) / 255.
test_y = test_data.targets.numpy()[:2000]

# Sanity check: report the size of the training images and labels.
print(train_data.data.size())
print(train_data.targets.size())

# Show the first training image (blocks until the window is closed when an
# interactive matplotlib backend is in use).
plt.imshow(train_data.data[0].numpy(), cmap='gray')
plt.show()

# Shuffled mini-batches for training.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
class RNN(nn.Module):
    """Single-layer GRU followed by a linear classification head.

    The input is a batch-first sequence tensor of shape
    ``(batch, seq_len, input_size)``; the output of the last time step is
    projected to ``num_classes`` scores (raw logits, no softmax).

    Args:
        input_size: Number of features per time step.  ``None`` (the
            default) falls back to the module-level ``INPUT_SIZE`` constant,
            which is what the original zero-argument constructor used.
        hidden_size: Width of the GRU hidden state.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, input_size=None, hidden_size=64, num_classes=10):
        super().__init__()
        if input_size is None:
            # Resolved at call time so that RNN() keeps its original behavior.
            input_size = INPUT_SIZE
        self.rnn = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        # The head must match the GRU width; deriving both from one argument
        # keeps them from drifting apart.
        self.out = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)`` for ``x``."""
        # r_out holds the GRU output at every time step; the final hidden
        # state returned alongside it is not needed.
        r_out, _ = self.rnn(x)
        # Classify from the output of the last time step only.
        return self.out(r_out[:, -1, :])
# Use the GPU when one is available and fall back to the CPU otherwise, so
# the script no longer crashes on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

rnn = RNN().to(device)
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()

# Move the evaluation images to the device once instead of re-copying them
# at every evaluation step.
test_x = test_x.to(device)

for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):
        # Drop the channel dimension: each image becomes a sequence of
        # TIME_STEP vectors with INPUT_SIZE features each.
        b_x = b_x.view(-1, TIME_STEP, INPUT_SIZE).to(device)
        b_y = b_y.to(device)

        output = rnn(b_x)
        loss = loss_func(output, b_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            # Evaluation needs no gradients; skipping the autograd graph
            # saves memory and time on the 2000-image test batch.
            with torch.no_grad():
                test_output = rnn(test_x)
            # Index of the highest score along dim 1 is the predicted class.
            pred_y = torch.max(test_output, 1)[1].cpu().numpy()
            accuracy = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)
            # loss.item() prints a plain Python float rather than a tensor repr.
            print('Epoch: {}, Step: {}, loss: {}, accuracy: {}'.format(epoch, step, loss.item(), accuracy))

# Final spot check: compare predictions with the labels of the first ten
# test images.
with torch.no_grad():
    test_output = rnn(test_x[:10].view(-1, TIME_STEP, INPUT_SIZE))
pred_y = torch.max(test_output, 1)[1].cpu().numpy()
print('预测数字', pred_y)
print('实际数字', test_y[:10])
二、运行结果
Epoch: 0, Step: 0, loss: 2.3083088397979736, accuracy: 0.086
Epoch: 0, Step: 100, loss: 1.330346703529358, accuracy: 0.5135
Epoch: 0, Step: 200, loss: 0.956856369972229, accuracy: 0.6725
Epoch: 0, Step: 300, loss: 0.715281069278717, accuracy: 0.76
Epoch: 0, Step: 400, loss: 0.3689120411872864, accuracy: 0.792
Epoch: 0, Step: 500, loss: 0.4506811201572418, accuracy: 0.832
Epoch: 0, Step: 600, loss: 0.3698571026325226, accuracy: 0.855
Epoch: 0, Step: 700, loss: 0.4824417531490326, accuracy: 0.8655
Epoch: 0, Step: 800, loss: 0.5076820850372314, accuracy: 0.8745
Epoch: 0, Step: 900, loss: 0.2233055979013443, accuracy: 0.885
预测数字 [7 2 1 0 4 1 4 9 5 9]
实际数字 [7 2 1 0 4 1 4 9 5 9]