pytorch CNN 手写数字识别

一个被放弃的入门级的例子终于被我实现了,虽然还不太完美,但还是想记录下

1.预处理

  相比较从库里下载数据集(关键是经常失败,格式也看不懂),更喜欢直接拿图片,从网上找了半天,最后从CSDN上下载了一个,真的是良心啊,都分好类了,有需要的可以找我

  (1)图片大小,灰度,格式处理:虽然这里用不到,以后可能用到,所以还是写了

  (2)图片打标:个人想法,图片名称含有标签,训练检测的时候方便拿

代码

 1 from PIL import Image
 2 import glob
 3 import os
 4 
 5 
 6 def load_image():
 7     """
 8     图片预处理
 9     将图片大小强制处理为28x28
10     转换为png格式
11     """
12     width = length = 28
13     train_path = 'D:/AI/MR_AIStudy/MNIST/dataset/train/*'
14     test_path = 'D:/AI/MR_AIStudy/MNIST/dataset/test/*'
15     img_path = glob.glob(test_path)  # 图片读取路径
16     try:
17         for file in img_path:
18             path, ext = os.path.splitext(file)
19             # print(path, ext)
20             img = Image.open(file)
21             # out = img.resize((width, length), Image.ANTIALIAS)
22             out = img.convert('L')
23             file_name = '{}{}'.format(path, '.png')
24             print(file_name)
25             out.save(file_name, quality=100)
26             print('success')
27             # img = Image.open(file)
28             # out = img.resize((width, length), Image.ANTIALIAS)
29             # out = out.convert('L')
30             # file_name = '{}{}'.format(path, ext)
31             # out.save(file_name, quality=100)
32     except Exception as e:
33         print(e)
34     # 图片预处理,将图片缩放到30px30px
35     # img_path = glob.glob('D:/AI/MR_AIStudy/opencv4/images/*.png')  # 图片读取路径
36     # for file in img_path:
37         # name = os.path.join(path_save, file)
38         # im = Image.open(file)
39         # im.thumbnail((30, 30))
40         # print(im.size)
41         # im.save(name, 'png')
42         # img = Image.open(file)
43         # data = img.getdata()
44         # data = np.matrix(data)
45         # data = np.reshape(data, (30, 30))
46         # print(data.size)
47 
48 
49 def rename():
50     # 修改文件名称为  序号-标签.bmp (123-2.bmp)  另存到D:/AI/MR_AIStudy/MNIST/dataset/train目录下
51     for label in range(10):
52         print(label)
53         # path = 'D:/AI/MR_AIStudy/MNIST/dataset/trainimage/{}/*.bmp'.format(label)
54         path = 'D:/AI/MR_AIStudy/MNIST/dataset/testimage/{}/*.bmp'.format(label)
55         # path_save = 'D:/AI/MR_AIStudy/MNIST/dataset/train'
56         path_save = 'D:/AI/MR_AIStudy/MNIST/dataset/test'
57         print('path', path)
58         img_path = glob.glob(path)
59         try:
60             for index, file in enumerate(img_path):
61                 # index用来区分相同标签不同图片
62                 path, ext = os.path.splitext(file)
63                 # print(path, ext)
64                 img = Image.open(file)
65                 out = img.convert('L')
66                 file_name = '{}-{}{}'.format(index, label, ext)  # 修改文件名称,将其打标
67                 print(file_name)
68                 # out.save(file_name, quality=100)
69                 out.save(os.path.join(path_save, os.path.basename(file_name)))  # 文件存到指定路径
70                 # break
71                 # print('success')
72 
73         except Exception as e:
74             print(e)
75         # break
76 
77 
78 if __name__ == '__main__':
79     load_image()
80     # change_ext()
81     # rename()

2.卷积神经网络

  本来是有归一化,softmax,独热方法的,但是我加上后不好使(加上softmax后不收敛了),就手动实现了一下归一化和独热

代码

import torch
import torch.nn as nn
import torch.utils.data as Data
import glob
import os
import numpy as np
from PIL import Image
import datetime
from torchvision import transforms
import torch.nn.functional as F
# 6272=8x32x32

EPOCH = 1
BATCH_SIZE = 50


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.con1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
        )
        self.con2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
        )
        self.fc = nn.Sequential(
            # 线性分类器
            nn.Linear(128*7*7, 128),  # 修改图片大小后要重新计算
            nn.ReLU(),
            nn.Linear(128, 10),
            # nn.Softmax(dim=1),
        )
        self.mls = nn.MSELoss()
        self.opt = torch.optim.Adam(params=self.parameters(), lr=1e-3)
        self.start = datetime.datetime.now()

    def forward(self, inputs):
        out = self.con1(inputs)
        out = self.con2(out)
        out = out.view(out.size(0), -1)  # 展开成一维
        out = self.fc(out)
        # out = F.log_softmax(out, dim=1)
        return out

    def train(self, x, y):
        out = self.forward(x)
        loss = self.mls(out, y)
        print('loss: ', loss)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

    def test(self, x):
        out = self.forward(x)
        return out


class ParseImage(object):
    def __init__(self):
        self.transform1 = transforms.Compose([
            transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0] 归一化
            ]
        )

    def get_data(self, path):
        # load_image()
        # 将图片转为矩阵,标签进行独热编码
        x_data = []
        y_data = []
        img_path = glob.glob(path)  # 图片读取路径
        for file in img_path:
            one_hot = []
            img = Image.open(file)
            # img = self.transform1(img)
            # img = transforms.ToPILImage()(img)
            data = img.getdata()
            data = np.matrix(data)
            data = np.reshape(data, (28, 28))
            # ..手动归一化
            data = data/255
            x_data.append(data)
            name, ext = os.path.splitext(file)
            label = name.split('-')[1]
            print('label', label)
            for i in range(10):
                if str(i) == label:
                    one_hot.append(1)
                else:
                    one_hot.append(0)
            y_data.append(one_hot)
        # 先转为数组,在转为tensor
        x_data = np.array(x_data)
        y_data = np.array(y_data)
        x_data = torch.from_numpy(x_data).float()
        # 输入数据增加频道维度
        x_data = torch.unsqueeze(x_data, 1)
        y_data = torch.from_numpy(y_data).float()
        return x_data, y_data


if __name__ == '__main__':
    data = ParseImage()
    train_path = 'D:/AI/MR_AIStudy/MNIST/dataset/train/*.png'
    test_path = 'D:/AI/MR_AIStudy/MNIST/dataset/test/*.png'
    x_data, y_data = data.get_data(train_path)
    net = MyNet()
    # 批训练
    torch_dataset = Data.TensorDataset(x_data, y_data)
    loader = Data.DataLoader(
        dataset=torch_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2,
    )
    for epoch in range(EPOCH):
        for step, (batch_x, batch_y) in enumerate(loader):
            print(step)
            net.train(batch_x, batch_y)

    torch.save(net, 'net.pkl')  # 存储模型, 全部存储

    # 只测试的话加载模型即可
    model = torch.load('net.pkl')  # 恢复模型
    net = model

    test_x, test_y = data.get_data(test_path)
    predict = net.test(test_x)
    print(predict)
    end = datetime.datetime.now()
    print('耗时:{}s'.format(end-net.start))
# 预测结果
# tensor([[ 9.1531e-01, -2.5804e-02, 1.2001e-02, 8.3876e-03, -1.6330e-02, # -1.7501e-03, -1.0589e-02, 2.6951e-02, 2.1836e-02, -4.5546e-02], # [-6.4733e-02, 7.7697e-01, 2.2536e-02, 8.3758e-03, 4.2895e-02, # 1.1602e-02, -3.0644e-02, 2.2412e-02, 1.1579e-01, 3.2196e-02], # [ 2.6631e-02, -5.3223e-02, 7.9808e-01, 6.0601e-03, 2.2453e-02, # -3.9522e-02, 3.4775e-02, 1.5853e-02, -6.9575e-03, 1.7208e-02], # [-1.3861e-02, -1.8332e-02, 4.9981e-02, 9.6510e-01, -1.5838e-02, # 9.0347e-03, 1.9342e-02, -3.8044e-02, -5.7994e-03, 1.4480e-02], # [-2.0864e-03, -5.9021e-02, 6.5524e-02, -2.1486e-02, 1.0074e+00, # 9.3356e-03, 1.0758e-02, 6.6142e-02, 1.4841e-02, 2.2529e-03], # [-8.4950e-02, -2.4841e-02, -7.7684e-02, 1.6404e-01, 4.3458e-02, # 8.6580e-01, -3.5630e-02, 4.2452e-02, 7.0675e-02, 2.9663e-02], # [-5.4024e-02, -1.7111e-02, -3.7085e-03, 3.8194e-03, -3.0645e-02, # -4.4164e-02, 1.0109e+00, 4.4349e-03, 1.3218e-01, -2.2839e-02], # [-2.0932e-02, 6.4831e-03, -1.3301e-02, 2.8091e-02, -3.0815e-02, # -3.2140e-02, 5.2251e-03, 1.0215e+00, 3.2592e-02, 1.0505e-02], # [ 1.5922e-02, -3.9700e-02, 2.4425e-02, -1.7313e-04, -1.5997e-02, # -5.2336e-02, -7.7526e-04, -2.1901e-02, 9.7167e-01, 1.3339e-01], # [-1.9283e-02, 2.4373e-02, -7.5621e-02, 1.1338e-01, -5.7805e-02, # -5.2936e-03, 1.0090e-03, 2.2471e-02, -3.5736e-02, 1.1243e+00]], # grad_fn=) # 耗时:0:09:59.665343s

预测结果不是很美观,但是正确的  欧耶!

你可能感兴趣的:(pytorch CNN 手写数字识别)