mnist手写数字识别,dnn实现代码解读

mnist手写数字识别,dnn实现代码解读

  • 代码及注释?
  • 模型结构
  • 相关问题
    • net.train() 和 net.eval()的作用?
    • 为什么是output.max(1)
    • optim.zero_grad()、pred=model(input)、loss=criterion(pred,tgt)、loss.backward()、optim.step()的作用

代码及注释?

# coding: utf-8

'''
通过dnn识别手写数字集
'''
import os
import sys
sys.path.append(os.path.abspath(
    os.path.dirname(os.path.abspath(__file__)) + os.path.sep + ".."))

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from config import MINIST_DATASET

#加载数据集--训练集
data_train = torchvision.datasets.MNIST(root=MINIST_DATASET,
                            transform=torchvision.transforms.ToTensor(),
                            train=True,
                            download=True)
#批量化处理数据集,一个批次64个数据--训练集
loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                batch_size=64,
                                                shuffle=True)
#加载数据集--测试集
data_test = torchvision.datasets.MNIST(root=MINIST_DATASET,
                            transform=torchvision.transforms.ToTensor(),
                            train=False,
                            download=True)
#批量化处理数据集,一个批次64个数据--测试集
loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                                batch_size=64,
                                                shuffle=True)

#定义一个网络,继承自pytorch.nn.module
class Net(nn.Module):
    #继承父亲初始化的方法,并定义两个全链接层,分别是输入784维、输出100维;输入100维输出10维
    def __init__(self):
        super(Net, self).__init__()
        self.line1 = nn.Linear(784, 100)
        self.line2 = nn.Linear(100, 10)
    #定义网络向前传播的结构
    def forward(self, x):
        #将图像的二维数据变换为一 维数据
        x = x.reshape(-1, 784)
        #将数据输入第一层神经网络,并将其输出通过激活函数激活
        x = F.relu(self.line1(x))
        #dropout层
        x = F.dropout(x, 0.2, training=self.training)
        #将数据输入第二层神经网络
        x = self.line2(x)
        #再通过激活函数
        x = F.softmax(x, dim=1)
        #模型输出一个10维的张量
        return x

#模型训练
def train_model(num, save_model=True):
    #网络实例化
    net = Net()
    #定义优化器,参数调整的方式
    optim = torch.optim.Adam(params=net.parameters())
    #定义损失函数
    loss_function = nn.CrossEntropyLoss()

    #开始训练,
    for epoch in range(num):
        #模型训练模式
        net.train()
        running_loss = 0
        #批次训练,每次只有64个图片参与计算
        for data, label in loader_train:
            #计算模型结果
            output = net(data)
            #计算损失值
            loss = loss_function(output, label)
            #梯度清空
            optim.zero_grad()
            #反向传播,计算每个参数的梯度
            loss.backward()
            #使用优化器调整参数
            optim.step()
            #计算一个epoch的损失值
            running_loss += loss.item()
        print("loss:{}".format(running_loss))

        #模型验证模式
        net.eval()
        test_correct = 0
        #批量化验证
        for data, label in loader_test:
            #使用模型计算结果
            output = net(data)
            #取出10维数据中最大值的索引,即位预测结果
            _, output = output.max(dim=1)
            #计算准确率
            test_correct += (label == output).sum().item()
        print("正确率:{}%".format(round(test_correct/100.0, 2)))

    if save_model:
        # 保存模型
        torch.save(net, os.path.join(base_dir, 'test.pth.tar'))


def use_save_model():
    # 加载模型
    net = torch.load(os.path.join(base_dir, 'test.pth.tar'))
    test_correct = 0
    for data, label in loader_test:
        # 使用模型计算结果
        output = net(data)
        # 取出10维数据中最大值的索引,即位预测结果
        _, output = output.max(dim=1)
        #计算预测正确率
        test_correct += (label == output).sum().item()
    print("存储的模型分类正确率:{}%".format(round(test_correct / 100.0, 2)))


if __name__ == '__main__':
    train_model(20, False)

模型结构

mnist手写数字识别,dnn实现代码解读_第1张图片

相关问题

net.train() 和 net.eval()的作用?

参考:net.train() 和 net.eval()的作用

为什么是output.max(1)

参考:output.max(1)

optim.zero_grad()、pred=model(input)、loss=criterion(pred,tgt)、loss.backward()、optim.step()的作用

参考:optim.zero_grad()、pred=model(input)、loss=criterion(pred,tgt)、loss.backward()、optim.step()的作用

你可能感兴趣的:(学习笔记,dnn,深度学习,python)