CNN网络实现手写数字(MNIST)识别 代码分析

CNN网络实现手写数字(MNIST)识别 代码分析(自学用)

Github代码源文件
本文是学习了使用Pytorch框架的CNN网络实现手写数字(MNIST)识别

#导入需要的包
import numpy as np   //第三方库,用于进行科学计算
import torch 
from torch import nn
from PIL import Image  // Python Image Library,python第三方图像处理库
import matplotlib.pyplot as plt //python的绘图库 pyplot:matplotlib的绘图框架
import os //提供了丰富的方法来处理文件和目录
from torchvision import datasets, transforms,utils //提供很多数据集的下载,包括COCO,ImageNet,CIFCAR等

1. 准备数据

(1)数据集介绍
MNIST数据集包含60000个训练集和10000测试数据集。分为图片和标签,图片是28*28的像素矩阵,标签为0~9共10个数字。

transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize(mean=[0.5],std=[0.5])])
                              
//Compos把多种数据处理的方法集合在一起
//使用transforms进行Tensor格式转换,将灰度范围从0-255变换到0-1之间
//批标准化(Batch Normalization),其作用就是先将输入归一化到(0,1),再使用公式"(x-mean)/std",将每个元素分布到(-1,1)
train_data = datasets.MNIST(root = "./data/"//root为数据集存放的路
                           transform=transform, //transform指定数据集导入的时候需要进行的变换
                           train = True,    //train设置为true表明导入的是训练集合,否则是测试集合
                           download = True) //如果为true,请从互联网下载数据集,然后将其放在根目录中。 如果数据集已经下载,则不是再次下载。

test_data = datasets.MNIST(root="./data/",
                          transform = transform,
                          train = False)
//train_data 的个数:60000个训练样
//test_data 的个数:10000个训练样本 
//一个样本的格式为[data,label],第一个存放数据,第二个存放标签
                     
train_loader = torch.utils.data.DataLoader(train_data,batch_size=64,
                                         shuffle=True,num_workers=2)
test_loader = torch.utils.data.DataLoader(test_data,batch_size=64,
                                         shuffle=True,num_workers=2)
//设置batch_size表示每次训练的样本数量 ,加载器中的基本单位是一个batch的数据 ,这里是64

//所以train_loader 的长度是60000/64 = 938 个batch,test_loader 的长度是10000/64= 157 个batch
                          
//shuffle 将序列的所有元素随机排序。
//num_workers 表示用多少个子进程加载数据

从二维数组生成一张图片

oneimg,label = train_data[0]
oneimg = oneimg.numpy().transpose(1,2,0) //numpy.transpose默认第一个方括号“[]”为 0轴 ,第二个方括号为 1轴...所以有着交换轴改变矩阵序列的作,(x=0,y=1,z=2),新的x是原来的y轴大小,新的y是原来的z轴大小,新的z是原来的x大小
std = [0.5]  //标准差
mean = [0.5] //平均值
oneimg = oneimg * std + mean
oneimg.resize(28,28)
plt.imshow(oneimg)
plt.show()

CNN网络实现手写数字(MNIST)识别 代码分析_第1张图片
从三维生成一张黑白图片

oneimg,label = train_data[0]
grid = utils.make_grid(oneimg) //make_grid的作用是将若干幅图像拼成一幅图像。在需要展示一批数据时很有用。
grid = grid.numpy().transpose(1,2,0) 
std = [0.5]
mean = [0.5]
grid = grid * std + mean
plt.imshow(grid)
plt.show(

CNN网络实现手写数字(MNIST)识别 代码分析_第2张图片
输出一个batch的图片和标签

images, lables = next(iter(train_loader))
//next()函数:不断返回迭代器下一个值
//iter()函数:把list,dict,str等可迭代的对象Iterable(可以用for循环的对象)转换为迭代器Iterator可以使用
img = utils.make_grid(images
img = img.numpy().transpose(1,2,0) 
std = [0.5]
mean = [0.5]
img = img * std + mean
for i in range(64):
   print(lables[i], end=" ")
   i += 1
   if i%8 is 0:
       print(end=

你可能感兴趣的:(pytorch,神经网络)