PyTorch 卷积神经网络实现手写数字识别

与上一篇《全连接神经网络实现手写数字识别》相同，本文记录了卷积神经网络版本的完整代码，并附带一个直观的可视化测试部分（显示一批测试图片及其预测标签）。

import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import cv2
from torch.autograd import Variable

# ---- Hyperparameters -------------------------------------------------------
num_epochs = 5          # number of full passes over the training set
output_size = 10        # number of classes (digits 0-9)
batch_size = 100        # samples per mini-batch
learning_rate = 0.001   # Adam step size

# ---- Device ----------------------------------------------------------------
# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# ---- MNIST datasets ----------------------------------------------------------
# ToTensor() turns each image into a float tensor with values in [0, 1];
# no further normalization is applied. Only the training split requests a download.
train_dataset = torchvision.datasets.MNIST(
    root='../../data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='../../data/',
    train=False,
    transform=transforms.ToTensor(),
)

# ---- Data loaders ------------------------------------------------------------
# Training batches are reshuffled every epoch; the test order stays fixed.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False,
)

# Neural network with two convolutional layers.
class ConvNet(nn.Module):
    """Two-block convolutional classifier for small single-image inputs.

    Each block is Conv2d(5x5, stride 1, padding 2) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2x2, stride 2). The padding keeps height/width unchanged through
    the convolution and the pooling halves them, so after two blocks a 28x28
    input becomes a 32-channel 7x7 feature map, which is what ``self.fc``
    (in_features = 7*7*32) expects.

    Args:
        num_classes: Number of output logits. When ``None`` (the default) the
            module-level ``output_size`` hyperparameter is used, exactly as
            the original parameterless constructor did.
        in_channels: Number of channels in the input images (1 = grayscale).
    """

    def __init__(self, num_classes=None, in_channels=1):
        super(ConvNet, self).__init__()
        if num_classes is None:
            # Backward-compatible default: read the script-level hyperparameter.
            num_classes = output_size
        self.layer1 = nn.Sequential(
            # in_channels -> 16 output channels, i.e. 16 convolution kernels.
            nn.Conv2d(in_channels, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),                     # batch-normalize each of the 16 channels
            nn.ReLU(),                              # activation
            nn.MaxPool2d(kernel_size=2, stride=2))  # 2x2 max pooling with stride 2: downsample by 2
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        ``x`` is assumed to be (batch, in_channels, 28, 28); other spatial
        sizes will not match the 7*7*32 input width of ``self.fc``.
        """
        out = self.layer1(x)
        out = self.layer2(out)
        # Flatten each sample's 32x7x7 feature map into one vector;
        # size(0) is the batch size and -1 lets reshape infer 7*7*32.
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

# Build the network and move its parameters onto the selected device.
model = ConvNet()
model = model.to(device)

# Loss and optimizer: cross-entropy over the raw logits returned by
# ConvNet.forward, minimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Train for num_epochs epochs, logging the current loss every 100 mini-batches.
total_step = len(train_loader)
for epoch in range(num_epochs):
    for step, (images, labels) in enumerate(train_loader, start=1):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass: clear stale gradients, backpropagate, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, step, total_step, loss.item()))

# Evaluate on the test set. eval() makes BatchNorm use its running statistics
# instead of per-batch ones; no_grad() skips building the autograd graph.
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # torch.max along dim 1 returns (max values, argmax indices);
        # the index of the largest logit in each row is the predicted class.
        predicted = torch.max(outputs, 1)[1]
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model: only the learned parameters (state_dict) are written.
torch.save(model.state_dict(), 'model.pkl')

# Visual spot check: predict one test batch and show its images as a grid.
X_test, y_test = next(iter(test_loader))
# The model's weights live on `device`, so the inputs must be moved there too;
# a CPU batch would fail against CUDA weights. Plain tensors replace the
# deprecated autograd.Variable wrapper.
inputs = X_test.to(device)
with torch.no_grad():  # inference only: no autograd graph needed
    pred = model(inputs)
_, pred = torch.max(pred, 1)

# Convert to plain Python ints so both label lists print readably
# (printing a generator expression would only show "<generator object ...>").
print("Predict Label is:", pred.cpu().tolist())
print("Real Label is :", y_test.tolist())

# make_grid tiles the batch into a single (3, H, W) image; OpenCV wants (H, W, 3).
img = torchvision.utils.make_grid(X_test)
img = img.numpy().transpose(1, 2, 0)
# The datasets use only ToTensor() (values already in [0, 1], no Normalize),
# so there is nothing to undo here. cv2.imshow scales float images in [0, 1]
# to [0, 255] for display.
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)
cv2.destroyAllWindows()

（图：pytorch卷积神经网络实现手写数字识别_第1张图片——可视化测试中一批测试图片拼接后的显示结果）

Test Accuracy of the model on the 10000 test images: 99.01 %
Predict Label is: <generator object <genexpr> at 0x000002A02B024138>
Real Label is : [tensor(7), tensor(2), tensor(1), tensor(0), tensor(4), tensor(1), tensor(4), tensor(9), tensor(5), tensor(9), tensor(0), tensor(6), tensor(9), tensor(0), tensor(1), tensor(5), tensor(9), tensor(7), tensor(3), tensor(4), tensor(9), tensor(6), tensor(6), tensor(5), tensor(4), tensor(0), tensor(7), tensor(4), tensor(0), tensor(1), tensor(3), tensor(1), tensor(3), tensor(4), tensor(7), tensor(2), tensor(7), tensor(1), tensor(2), tensor(1), tensor(1), tensor(7), tensor(4), tensor(2), tensor(3), tensor(5), tensor(1), tensor(2), tensor(4), tensor(4), tensor(6), tensor(3), tensor(5), tensor(5), tensor(6), tensor(0), tensor(4), tensor(1), tensor(9), tensor(5), tensor(7), tensor(8), tensor(9), tensor(3), tensor(7), tensor(4), tensor(6), tensor(4), tensor(3), tensor(0), tensor(7), tensor(0), tensor(2), tensor(9), tensor(1), tensor(7), tensor(3), tensor(2), tensor(9), tensor(7), tensor(7), tensor(6), tensor(2), tensor(7), tensor(8), tensor(4), tensor(7), tensor(3), tensor(6), tensor(1), tensor(3), tensor(6), tensor(9), tensor(3), tensor(1), tensor(4), tensor(1), tensor(7), tensor(6), tensor(9)]

你可能感兴趣的:(pytorch,基础)