PyTorch实现MNIST手写数字识别

MNIST 训练集和测试集中的手写数字都是黑底白字，而自己准备的测试图片一般是白底黑字，直接识别效果很差，所以在预处理中加了一行反转颜色的代码：

# Invert colors (black <-> white) so a white-background test image matches the
# black-background, white-digit MNIST training images.
img = PIL.ImageOps.invert(img)

一、导入包

import torch
from torch.nn import Sequential
from matplotlib import pyplot as plt
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

二、加载训练集和测试集

def _load_mnist_split(is_train):
    """Load one MNIST split from ``../mnist`` without downloading.

    ``is_train`` picks the training split (True) or the test split (False).
    Every image is passed through ``transforms.ToTensor()`` so the dataset
    yields tensors instead of PIL images.
    """
    return datasets.MNIST(root='../mnist',
                          train=is_train,
                          transform=transforms.ToTensor(),
                          download=False)


train_dataset = _load_mnist_split(True)
test_dataset = _load_mnist_split(False)

三、定义数据加载器（DataLoader）

# Number of samples per batch for both loaders (the final batch may be smaller).
batch_size = 64


def _make_loader(dataset, shuffle):
    """Wrap *dataset* in a DataLoader that uses the module-level ``batch_size``."""
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)


# Shuffle the training data each epoch; keep the test order fixed.
train_loader = _make_loader(train_dataset, shuffle=True)
test_loader = _make_loader(test_dataset, shuffle=False)

四、展示数据集

# Pull one training batch to inspect tensor shapes: the images print as
# (64, 1, 28, 28) -- 64 samples, 1 channel, 28x28 pixels.
inputs, labels = next(iter(train_loader))
print(inputs.shape)
print(labels.shape)

# Show one test image together with its label.
idx = 555
# Each dataset item is an (image_tensor, label) pair.
sample_image, sample_label = test_dataset[idx]
img = sample_image.numpy()
# img[0] selects the single channel so imshow receives a 2-D array.
plt.imshow(img[0], cmap='gray')
print("标签是:", sample_label)
plt.show()
PyTorch实现MNIST手写数字识别_第1张图片

五、神经网络

class CNN(nn.Module):
    """Small two-stage convolutional classifier for 1x28x28 digit images.

    conv1: 1 -> 4 channels, 5x5 kernel, stride 1, padding 2, ReLU, 2x2 max-pool (28 -> 14).
    conv2: 4 -> 8 channels, same pattern (14 -> 7).
    fc1:   8 * 7 * 7 = 392 features -> 50, with dropout 0.5 and ReLU.
    fc2:   50 -> 10 class scores.

    ``forward`` returns raw, unnormalized logits of shape (batch, 10).  There is
    deliberately no Softmax layer: ``nn.CrossEntropyLoss`` applies log-softmax
    internally, so an explicit Softmax in front of it normalizes twice and
    flattens the gradients.  ``torch.max(output, 1)`` selects the same class
    from logits as from probabilities; use ``torch.softmax(output, dim=1)``
    when probabilities are needed.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = Sequential(
            nn.Conv2d(1, 4, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.conv2 = Sequential(
            nn.Conv2d(4, 8, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.fc1 = Sequential(
            nn.Linear(392, 50),
            nn.Dropout(0.5),
            nn.ReLU(),
        )
        # Kept as a Sequential (not a bare Linear) so parameter names remain
        # ``fc2.0.weight`` / ``fc2.0.bias`` and existing checkpoints still load.
        self.fc2 = Sequential(
            nn.Linear(50, 10),
        )

    def forward(self, x):
        """Map a (batch, 1, 28, 28) image batch to (batch, 10) logits."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten the (batch, 8, 7, 7) feature maps into (batch, 392) vectors.
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        return self.fc2(x)
# Instantiate the network; place it on the GPU when CUDA is available, else keep it on CPU.
_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNN().to(_device)

六、学习率、损失函数和优化器

# Learning rate for plain SGD.
learning_rate = 0.01
# Cross-entropy loss over the 10 digit classes (applies log-softmax internally).
loss_fn = nn.CrossEntropyLoss()
# Move the loss to the GPU only when CUDA exists, matching how ``model`` is
# placed; an unconditional ``.cuda()`` is wrong on CPU-only machines.
if torch.cuda.is_available():
    loss_fn = loss_fn.cuda()
# SGD over all model parameters.
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

七、训练、测试与保存模型

def train():
    """Run one epoch of SGD over ``train_loader``, updating ``model`` in place.

    Relies on the module-level ``model``, ``train_loader``, ``loss_fn`` and
    ``optimizer``.  Each batch is moved to the device holding the model's
    parameters, so the loop works whether or not a GPU is present.
    """
    model.train()  # training mode: dropout is active
    device = next(model.parameters()).device
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        # Forward pass and loss.
        output = model(inputs)
        loss = loss_fn(output, targets)
        # Clear gradients left over from the previous step, then backpropagate.
        optimizer.zero_grad()
        loss.backward()
        # Apply the parameter update.
        optimizer.step()

def _accuracy(loader, dataset_size, device):
    """Return the fraction of samples from *loader* that ``model`` classifies correctly."""
    correct = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        output = model(inputs)
        # Column index of the highest score in each row is the predicted digit.
        _, pred = torch.max(output, 1)
        correct += (pred == targets).sum()
    # int() accepts both the accumulated tensor and the initial 0 (empty loader).
    return int(correct) / dataset_size


def test():
    """Print ``model``'s accuracy on the training set, then on the test set.

    Relies on the module-level ``model``, ``train_loader``/``train_dataset``
    and ``test_loader``/``test_dataset``.  The first printed line is the
    training-set accuracy, the second the test-set accuracy.
    """
    model.eval()  # evaluation mode: dropout disabled
    device = next(model.parameters()).device
    # No autograd graph is needed for either pass; without no_grad the
    # training-set pass would record a graph for every batch.
    with torch.no_grad():
        print('准确率:', _accuracy(train_loader, len(train_dataset), device))
        print('准确率:', _accuracy(test_loader, len(test_dataset), device))

# Train for a fixed number of epochs, reporting accuracy after each one.
NUM_EPOCHS = 15
for epoch in range(NUM_EPOCHS):
    print('epoch:', epoch)
    train()
    test()

# Save only the learned parameters (the state_dict), not the module object itself.
checkpoint_path = r'D:\Python\ppp\first\hand_writer_0-9\15.pth'
torch.save(model.state_dict(), checkpoint_path)
PyTorch实现MNIST手写数字识别_第2张图片
PyTorch实现MNIST手写数字识别_第3张图片
PyTorch实现MNIST手写数字识别_第4张图片

八、模型测试（用自己的图片进行识别）

import PIL.ImageOps
import torch
from PIL import Image
from torch import nn
from torch.nn import Sequential
from torchvision.transforms import transforms

class CNN(nn.Module):
    """Inference copy of the training network; layer layout must match the checkpoint.

    conv1: 1 -> 4 channels, 5x5 kernel, stride 1, padding 2, ReLU, 2x2 max-pool (28 -> 14).
    conv2: 4 -> 8 channels, same pattern (14 -> 7).
    fc1:   392 -> 50, dropout 0.5, ReLU.
    fc2:   50 -> 10 class scores.

    ``forward`` returns raw logits of shape (batch, 10).  Softmax has no
    parameters, so leaving it out changes neither the state_dict keys nor the
    class chosen by ``torch.max(output, 1)``; apply
    ``torch.softmax(output, dim=1)`` if probabilities are wanted.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = Sequential(
            nn.Conv2d(1, 4, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.conv2 = Sequential(
            nn.Conv2d(4, 8, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.fc1 = Sequential(
            nn.Linear(392, 50),
            nn.Dropout(0.5),
            nn.ReLU(),
        )
        # Sequential wrapper keeps the ``fc2.0.*`` parameter names of the checkpoint.
        self.fc2 = Sequential(
            nn.Linear(50, 10),
        )

    def forward(self, x):
        """Map a (batch, 1, 28, 28) image batch to (batch, 10) logits."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten (batch, 8, 7, 7) feature maps into (batch, 392) vectors.
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        return self.fc2(x)
# Rebuild the training architecture and load the saved parameters into it.
model = CNN()
# map_location='cpu': the training script moves the model to the GPU when one
# is available before saving its state_dict, so the checkpoint may contain CUDA
# tensors.  Without remapping, torch.load fails on a CPU-only machine.  This
# script feeds CPU tensors to the model, so the weights must live on the CPU.
model.load_state_dict(torch.load(r'D:\Python\ppp\first\hand_writer_0-9\15.pth', map_location='cpu'))


# Read the user's digit image; the with-block closes the underlying file once
# the resized copy has been produced.
with Image.open(r'D:\Python\ppp\first\number_3.png') as raw_img:
    # Resize to the 28x28 input size the network was trained on.
    img = raw_img.resize((28, 28))
# Convert to single-channel grayscale.
img = img.convert('L')
# Invert colors: MNIST digits are white on black, while self-made test images
# are usually dark strokes on a white background.
img = PIL.ImageOps.invert(img)
# Convert the PIL image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])
# unsqueeze(0) adds a leading batch axis: the conv layers need 4-D input.
img = transform(img).unsqueeze(0)
# eval() disables dropout; no_grad() skips building the autograd graph.
model.eval()
with torch.no_grad():
    output = model(img)
    _, pred = torch.max(output, 1)
print("数字:", pred.item())
PyTorch实现MNIST手写数字识别_第5张图片

你可能感兴趣的:(pytorch)