PyTorch Implementation of MNIST Digit Recognition

1. Download mnist.pkl.gz

URL: http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz

The dataset file is stored at data2/mnist/mnist.pkl.gz.
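
If the file is not already on disk, a minimal download sketch (the requests dependency is an assumption; urllib.request would work just as well):

from pathlib import Path
import requests

DATA_PATH = Path("./data2")
PATH = DATA_PATH / "mnist"
URL = "http://www.iro.umontreal.ca/~lisa/deep/data/mnist/"
FILENAME = "mnist.pkl.gz"

PATH.mkdir(parents=True, exist_ok=True)
if not (PATH / FILENAME).exists():
    # fetch the gzipped pickle and write it to data2/mnist/
    (PATH / FILENAME).write_bytes(requests.get(URL + FILENAME).content)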

2. Load the data

from pathlib import Path
import pickle
import gzip

DATA_PATH = Path("./data2")
PATH = DATA_PATH / "mnist"
FILENAME = "mnist.pkl.gz"

with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
# x_train (50000, 784), y_train (50000,), x_valid (10000, 784), y_valid (10000,)

Display one sample digit:

# ========== 28*28 = 784 ========== display the sample at index 50 (a 5)
import matplotlib.pyplot as plt
import numpy as np
plt.imshow(x_train[50].reshape((28, 28)), cmap="gray")
plt.show()
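
To eyeball several samples at once, a small grid sketch (run it before the tensor conversion below, while x_train is still a NumPy array):

fig, axes = plt.subplots(1, 5, figsize=(10, 2))
for i, ax in enumerate(axes):
    ax.imshow(x_train[i].reshape(28, 28), cmap="gray")
    ax.set_title(str(y_train[i]))  # the true label as the panel title
    ax.axis("off")
plt.show()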

[Figure 1: the digit at index 50 rendered as a 28×28 grayscale image]

Convert the data to tensors

# ================= The data must be converted to tensors before training ===
import torch
x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
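
A quick check of the converted tensors (a sketch; the labels must end up with an integer dtype for cross-entropy later):

n, c = x_train.shape
print(n, c, x_train.dtype, y_train.dtype)
# expected: 50000 784 torch.float32 torch.int64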

3. Set the loss function to cross-entropy

#=====torch.nn.functional==========
import torch.nn.functional as F
loss_func=F.cross_entropy
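
F.cross_entropy combines log_softmax and negative log-likelihood in one call, so the network should output raw logits with no softmax layer at the end. A quick sanity check on made-up inputs:

# 3 samples, 10 classes, raw (unnormalized) logits
logits = torch.randn(3, 10)
targets = torch.tensor([0, 3, 9])
print(loss_func(logits, targets))  # a scalar mean loss over the batch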

4. Create the model class

from torch import nn
class Mnist_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)
        self.dropout = nn.Dropout(0.5)
    def forward(self, x):
        x = F.relu(self.hidden1(x))
        # fully connected layer + dropout to reduce overfitting
        x = self.dropout(x)

        x = F.relu(self.hidden2(x))
        x = self.dropout(x)

        x = self.out(x)
        return x

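Before inspecting parameters, a quick forward-pass sanity check (a sketch with a made-up batch) confirms the output shape:

# a made-up batch of 64 flattened images
dummy = torch.randn(64, 784)
print(Mnist_NN()(dummy).shape)  # torch.Size([64, 10])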

Print the network to see what it looks like:

net = Mnist_NN()
print(net)

# print each parameter's name, values (the weights w and biases b), and shape
for name, parameter in net.named_parameters():
    print(name, parameter, parameter.size())
Mnist_NN(
  (hidden1): Linear(in_features=784, out_features=128, bias=True)
  (hidden2): Linear(in_features=128, out_features=256, bias=True)
  (out): Linear(in_features=256, out_features=10, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
hidden1.weight Parameter containing:
tensor([[-1.7000e-02, -7.5721e-03, -1.7358e-03,  ...,  7.6538e-03,
          7.2789e-03, -6.3162e-03],
        ...,
        [-6.9892e-03, -3.8977e-03, -2.6180e-02,  ..., -3.3327e-02,
         -1.4922e-02,  2.6092e-02]], requires_grad=True) torch.Size([128, 784])
hidden1.bias Parameter containing:
tensor([ 2.5751e-02, -1.0377e-02, -1.3521e-02,  ..., -1.0504e-02,
        -2.6949e-02, -1.3683e-02], requires_grad=True) torch.Size([128])
hidden2.weight Parameter containing:
tensor([[-0.0092, -0.0385,  0.0392,  ...,  0.0202, -0.0780, -0.0181],
        ...,
        [ 0.0037, -0.0460, -0.0352,  ...,  0.0280,  0.0473, -0.0177]],
       requires_grad=True) torch.Size([256, 128])
hidden2.bias Parameter containing:
tensor([ 7.0455e-03, -8.5585e-02, -1.1738e-02,  ..., -5.6938e-02,
        -8.0130e-02, -1.9742e-02], requires_grad=True) torch.Size([256])
out.weight Parameter containing:
tensor([[-0.0173,  0.0522,  0.0494,  ..., -0.0579, -0.0439, -0.0522],
        ...,
        [-0.0340,  0.0300,  0.0400,  ...,  0.0522, -0.0565,  0.0542]],
       requires_grad=True) torch.Size([10, 256])
out.bias Parameter containing:
tensor([-0.0154, -0.0028, -0.0574, -0.0608, -0.0276,  0.0483,  0.0503,  0.0112,
        -0.0352, -0.0498], requires_grad=True) torch.Size([10])
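
Printing full tensors floods the console; a terser variant (a sketch) lists only each parameter's name and shape:

for name, parameter in net.named_parameters():
    print(name, tuple(parameter.size()))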

5. Use TensorDataset and DataLoader to wrap the data into batched datasets

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

bs = 64
train_ds = TensorDataset(x_train, y_train)
valid_ds = TensorDataset(x_valid, y_valid)

def get_data(train_ds, valid_ds, bs):
    # shuffle the training set each epoch; the validation loader can use a
    # larger batch size since no gradients are stored during evaluation
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )
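
A quick inspection of one batch (a sketch; the loaders are rebuilt again in step 7):

train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
xb, yb = next(iter(train_dl))
print(xb.shape, yb.shape)  # torch.Size([64, 784]) torch.Size([64])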

6. Define the training loop

import numpy as np
val_losses = []
# steps: number of passes over the training data (each step is one epoch)
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()  # training mode: dropout is active
        # xb (64, 784), yb (64); both are tensors
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)
        # eval mode: dropout (and BatchNorm, if present) switch to inference behavior
        model.eval()
        with torch.no_grad():
            # losses, nums = (loss, batch_size), (loss, batch_size), ...
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )

        # average loss over the whole validation set, weighted by batch size
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        val_losses.append(val_loss)
        print("step: " + str(step), "validation loss: " + str(val_loss))

from torch import optim
def get_model():
    model = Mnist_NN()
    # return the model and an optimizer; swap in optim.SGD(model.parameters(), lr=0.001) for plain SGD
    return model, optim.Adam(model.parameters(), lr=0.001)


def loss_batch(model, loss_func, xb, yb, opt=None):
    # compute the loss from the predictions and the true labels
    loss = loss_func(model(xb), yb)

    if opt is not None:
        loss.backward()  # backpropagate to compute gradients
        opt.step()       # update the parameters
        opt.zero_grad()  # zero the gradients so they don't leak into the next update
    return loss.item(), len(xb)

7. Train the model

train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(20, model, loss_func, opt, train_dl, valid_dl)


correct = 0
total = 0
with torch.no_grad():  # no gradients needed for evaluation
    # xb (128, 784), yb (128)
    for xb, yb in valid_dl:
        # output (128, 10): one row of 10 class scores (logits) per sample
        output = model(xb)
        # predicted = index of the largest score in each row
        _, predicted = torch.max(output.data, 1)  # (max values, indices)
        # size(0) is the batch size; item() converts a 0-d tensor to a Python number
        total += yb.size(0)
        correct += (predicted == yb).sum().item()
print("Accuracy of the network on the 10000 test images: %d %%" % (
    100 * correct / total
))
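
To see which digits the network finds hardest, a per-class accuracy sketch (not part of the original run; it reuses model and valid_dl from above):

class_correct = torch.zeros(10)
class_total = torch.zeros(10)
with torch.no_grad():
    for xb, yb in valid_dl:
        preds = model(xb).argmax(dim=1)
        for cls in range(10):
            mask = yb == cls
            class_total[cls] += mask.sum()
            class_correct[cls] += (preds[mask] == cls).sum()
print(class_correct / class_total)  # one accuracy value per digit 0-9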

plt.figure()
plt.title("loss during training")
plt.plot(np.arange(1,21,1),val_losses)
plt.show()
step: 0 validation loss: 0.19325110550522803
step: 1 validation loss: 0.13869898459613322
step: 2 validation loss: 0.11913147141262889
step: 3 validation loss: 0.10589157585203647
step: 4 validation loss: 0.09970801477096974
step: 5 validation loss: 0.09848284918610006
step: 6 validation loss: 0.08794679024070501
step: 7 validation loss: 0.08894123120522127
step: 8 validation loss: 0.0905570782547351
step: 9 validation loss: 0.0877237871955149
step: 10 validation loss: 0.08790379901565612
step: 11 validation loss: 0.08826288345884532
step: 12 validation loss: 0.08438722904250026
step: 13 validation loss: 0.08695273711904883
step: 14 validation loss: 0.08459821079988032
step: 15 validation loss: 0.08047270769253373
step: 16 validation loss: 0.0862937849830836
step: 17 validation loss: 0.08164657156261383
step: 18 validation loss: 0.08109720230847597
step: 19 validation loss: 0.08208743708985858
Accuracy of the network on the 10000 test images: 97 %

[Figure 2: validation loss over the 20 training steps]
