Pytorch快速搭建Alexnet实现手写英文字母识别+PyQt实现鼠标绘图

Pytorch快速搭建Alexnet实现手写英文字母识别+PyQt实现鼠标绘图

  • 前言
  • 一、案例要求
  • 二、训练数据准备
    • 1.下载手写英文字母数据集
    • 2.构建自己的数据集
  • 三、AlexNet实现
    • 1.AlexNet简介
    • 2. AlexNet模型代码
    • 3.AlexNet训练代码
    • 4.AlexNet对任意手写字母图片的预测
  • 四、PyQt鼠标手绘字母界面
    • 1.PyQt界面代码
    • 2. 效果图
  • 五、总结
  • 参考文献

前言

很快就要到计算机应用技术的DDL了,大部分同学应该都确定了开发工具,在编写代码的路上了吧。这里就给还有没实现的或者缺少想法的同学提供一些指引。希望看过这个篇文章的你能有所收获。值得一提的是,借鉴固然是好的,但是不要无脑照搬,希望你能写出有自己风格特色的代码。

一、案例要求

目标:实现一个简单的卷积神经网络,并测试运行简单的手写英文字母识别。
实现要求:使用Tensorflow或者Pytorch,编写LeNet5,或更复杂的卷积神经网络代码,下载一个手写英文字母数据集并训练,最终能实现手写英文字母的识别。
初步分析:1、在TensorFlow和Pytorch的选择上,我选择了Pytorch。原因如下,TensorFlow是谷歌推出的深度学习框架,Pytorch则是Facebook推出的,前者先于后者,这也就导致实际生产中TensorFlow的使用居多,Pytorch则是在实验室科研方面的使用居多,并且前者是属于静态的,后者是动态的,静态就意味着TensorFlow需要实现定义TensorFlow专有的数据类型变量,这就导致使用方面比不上Pytorch的方便易上手学习成本低的优势。
2、题目说明用LeNet,而本人选择AlexNet,原因很简单,就是AlexNet的知名度和它带给深度学习领域的贡献最为突出深远。

二、训练数据准备

1.下载手写英文字母数据集

工欲善其事必先利其器,准备好数据集就是第一步,数据集的好坏也会直接影响训练结果。
The Chars74K dataset
这是网上能查到比较靠谱的数据集下载网站了,进入网站后是这样,往下滑到Download选择EnglishHnd.tgz就可以了。
Pytorch快速搭建Alexnet实现手写英文字母识别+PyQt实现鼠标绘图_第1张图片
在windows下要对tgz解压两次,然后就可以看到里面有52个大小写英文和10个数字的数据集了,即62个文件夹,一个文件夹下55张png图片。

2.构建自己的数据集

由于只需要训练识别英文字母,选择其中的Sample011到Sample062复制到任意目录下
Pytorch快速搭建Alexnet实现手写英文字母识别+PyQt实现鼠标绘图_第2张图片
这里我对数据集进行了规整,由于英文字母里有些字母大小写及其相近,比如S和s、Y和y、O和o、P和p,我就规整成到同一个文件夹里了,为了提高最后的识别精度。并且我通过脚本对文件夹重新命名,更具辨识度,如果需要这份数据集,可以到我的博客下载资源处白嫖。

三、AlexNet实现

1.AlexNet简介

Pytorch快速搭建Alexnet实现手写英文字母识别+PyQt实现鼠标绘图_第3张图片
AlexNet网络结构是Hinton和他的学生Alex Krizhevsky在2012年ImageNet挑战赛上使用的模型结构,刷新了Image分类的记录,从此深度学习在Image这块开始一次次超过当前最好结果,甚至可以打败人类。AlexNet可以说是具有历史意义的一个网络结构,在AlexNet之前,深度学习已经沉寂了很久。历史的转折在2012年到来,AlexNet在当年的ImageNet图像分类竞赛中,Top-5错误率是比上一年的冠军下降了十个百分点,而且远超当年的亚军。
AlexNet之所以能够成功,深度学习之所以能够重回历史舞台,原因在于如下几方面:
1.非线性激活函数:ReLU(Rectified Linear Unit,校正线性单元),即f(x)=max(0,x)
2.防止过拟合的方法:Dropout,Data augmentation
3.大数据训练:百万级ImageNet图像数据
4.其他:GPU实现,LRN归一化层的使用。

2. AlexNet模型代码

model.py

import torch.nn as nn
import torch

class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(  # 打包
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55] 自动舍去小数点后
            nn.ReLU(inplace=True),  # inplace 可以载入更大模型
            nn.MaxPool2d(kernel_size=3, stride=2),  # output[48, 27, 27] kernel_num为原论文一半
            nn.Conv2d(48, 128, kernel_size=5, padding=2),  # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),  # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),  # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),  # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            # 全链接
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)  # 展平或者view()
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')  # 何教授方法
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)  # 正态分布赋值
                nn.init.constant_(m.bias, 0)

3.AlexNet训练代码

train.py

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.optim as optim
from model import AlexNet
import os
import time

# device : GPU or CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# 数据转换
data_transform = {
     
    "train": transforms.Compose([transforms.RandomResizedCrop(224),  # 随机裁剪
                                 transforms.RandomHorizontalFlip(),  # 随机反转
                                 transforms.ToTensor(),  # 张量转换
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),  # 初始化
    "val": transforms.Compose([transforms.Resize((224, 224)),  # 必须是(224, 224)
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

data_root = os.getcwd()  # 获取数据集根目录
image_path = data_root + "/minDataset/"  # 训练数据集目录
train_dataset = datasets.ImageFolder(root=image_path,
                                     transform=data_transform["train"])

print("训练数据集大小:", train_dataset)

batch_size = 50
train_loader = DataLoader(train_dataset,  # 加载训练集数据
                          batch_size=batch_size, shuffle=True,
                          num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path,
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = DataLoader(validate_dataset,  # 加载验证集数据
                             batch_size=batch_size, shuffle=True,
                             num_workers=0)

test_data_iter = iter(validate_loader)
test_image, test_label = next(test_data_iter)

net = AlexNet(num_classes=41, init_weights=True)  # 设置分类数为41,初始化权重
net.to(device)
# 损失函数:这里用交叉熵
loss_function = nn.CrossEntropyLoss()
# 优化器 这里用Adam
optimizer = optim.Adam(net.parameters(), lr=0.0001)
# 训练参数保存路径
save_path = './AlexNet.pth' #只要是pth文件就可以
# 训练过程中最高准确率
best_acc = 0.0

# 开始进行训练和测试,训练一轮,测试一轮
for epoch in range(30): # 30可以替换成任意正整数
    # train
    net.train()  # 训练过程中,使用之前定义网络中的dropout
    running_loss = 0.0
    t1 = time.perf_counter()
    for step, data in enumerate(train_loader):
        images, labels = data  # 提取图像数据和标签
        optimizer.zero_grad()  # 清空之前的梯度信息
        outputs = net(images.to(device))  # 将数据存放到设备
        loss = loss_function(outputs, labels.to(device))  # 计算损失值
        loss.backward()  # 损失后向传播到每个神经元
        optimizer.step()  # 更新每个神经元的参数
        running_loss += loss.item()  # 累加损失
        # 打印训练进度
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()
    print(time.perf_counter() - t1)

    # validate
    net.eval()  # 测试过程中不需要dropout,使用所有的神经元
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():  # 进行验证,并不进行梯度跟踪
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]  # 得到预测结果
            acc += (predict_y == val_labels.to(device)).sum().item()  # 累计预测准确率
        val_accurate = acc / val_num  # 求得测试集准确率
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)  # 保存模型
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / len(train_loader), val_accurate))
print('Finished Training')

4.AlexNet对任意手写字母图片的预测

predict.py

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import json


def get_predict():
    data_transform = transforms.Compose( # 数据转换模型
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    img = Image.open("test.png") # 加载图片,自定义的图片名称
    img = data_transform(img) # 图片转换为矩阵
    # 对数据维度进行扩充
    img = torch.unsqueeze(img, dim=0)
    # 创建模型
    model = AlexNet(num_classes=41)
    # 加载模型权重
    model_weight_path = "./AlexNet.pth" #与train.py里的文件名对应
    model.load_state_dict(torch.load(model_weight_path))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img)) # 图片压缩
        predict = torch.softmax(output, dim=0) # 求softmax值
        predict_cla = torch.argmax(predict).numpy() # 预测分类结果
        with open("index.json","r")as f:
            data = json.load(f)
        print("预测结果为",data[str(predict_cla)])
        return data[str(predict_cla)]

index.json
由于预测的输出结果都是数字编号,所以事先做一张对照哈希表,可以知道输出的结果到底是哪个英文字母。

{"0": "a", "1": "b", "2": "c", "3": "d", "4": "e", "5": "f", "6": "g", "7": "h", "8": "i", "9": "j", "10": "k", "11": "l", "12": "m", "13": "n", "14": "o", "15": "p", "16": "q", "17": "r", "18": "s", "19": "t", "20": "u", "21": "v", "22": "w", "23": "x", "24": "y", "25": "z", "26": "A", "27": "B", "28": "D", "29": "E", "30": "F", "31": "G", "32": "H", "33": "I", "34": "J", "35": "L", "36": "M", "37": "N", "38": "Q", "39": "R", "40": "T"}

四、PyQt鼠标手绘字母界面

1.PyQt界面代码

这部分就不介绍了,没有过多技术可言,复制粘贴即可,然后自行修改设计。
ui.py

from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from PyQt5.QtWidgets import QApplication
from PyQt5.Qt import QPainter, QPoint, QPen
from PyQt5.QtCore import Qt
from PyQt5.Qt import QWidget, QColor, QPixmap, QIcon, QSize, QCheckBox
from PyQt5.QtWidgets import QHBoxLayout, QVBoxLayout, QPushButton,QComboBox, QLabel, QSpinBox
from predict import get_predict
import sys


def main():
    app = QApplication(sys.argv)
    mainWidget = MainWidget()  # 新建一个主界面
    mainWidget.show()  # 显示主界面
    exit(app.exec_())  # 进入消息循环


class PaintBoard(QWidget):
    def __init__(self, Parent=None):
        super().__init__(Parent)
        self.__InitData()  # 先初始化数据,再初始化界面
        self.__InitView()
        self.setWindowTitle("画笔")

    def __InitData(self):
        self.__size = QSize(480, 460)

        # 新建QPixmap作为画板,尺寸为__size
        self.__board = QPixmap(self.__size)
        self.__board.fill(Qt.white)  # 用白色填充画板

        self.__IsEmpty = True  # 默认为空画板
        self.EraserMode = False  # 默认为禁用橡皮擦模式

        self.__lastPos = QPoint(0, 0)  # 上一次鼠标位置
        self.__currentPos = QPoint(0, 0)  # 当前的鼠标位置

        self.__painter = QPainter()  # 新建绘图工具

        self.__thickness = 24  # 默认画笔粗细为10px
        self.__penColor = QColor("black")  # 设置默认画笔颜色为黑色
        self.__colorList = QColor.colorNames()  # 获取颜色列表

    def __InitView(self):
        # 设置界面的尺寸为__size
        self.setFixedSize(self.__size)

    def Clear(self):
        # 清空画板
        self.__board.fill(Qt.white)
        self.update()
        self.__IsEmpty = True

    def ChangePenColor(self, color="black"):
        # 改变画笔颜色
        self.__penColor = QColor(color)

    def ChangePenThickness(self, thickness=10):
        # 改变画笔粗细
        self.__thickness = thickness

    def IsEmpty(self):
        # 返回画板是否为空
        return self.__IsEmpty

    def GetContentAsQImage(self):
        # 获取画板内容(返回QImage)
        image = self.__board.toImage()
        return image

    def paintEvent(self, paintEvent):
        # 绘图事件
        # 绘图时必须使用QPainter的实例,此处为__painter
        # 绘图在begin()函数与end()函数间进行
        # begin(param)的参数要指定绘图设备,即把图画在哪里
        # drawPixmap用于绘制QPixmap类型的对象
        self.__painter.begin(self)
        # 0,0为绘图的左上角起点的坐标,__board即要绘制的图
        self.__painter.drawPixmap(0, 0, self.__board)
        self.__painter.end()

    def mousePressEvent(self, mouseEvent):
        # 鼠标按下时,获取鼠标的当前位置保存为上一次位置
        self.__currentPos = mouseEvent.pos()
        self.__lastPos = self.__currentPos

    def mouseMoveEvent(self, mouseEvent):
        # 鼠标移动时,更新当前位置,并在上一个位置和当前位置间画线
        self.__currentPos = mouseEvent.pos()
        self.__painter.begin(self.__board)

        if self.EraserMode == False:
            # 非橡皮擦模式
            self.__painter.setPen(QPen(self.__penColor, self.__thickness))  # 设置画笔颜色,粗细
        else:
            # 橡皮擦模式下画笔为纯白色,粗细为10
            self.__painter.setPen(QPen(Qt.white, 10))

        # 画线
        self.__painter.drawLine(self.__lastPos, self.__currentPos)
        self.__painter.end()
        self.__lastPos = self.__currentPos

        self.update()  # 更新显示

    def mouseReleaseEvent(self, mouseEvent):
        self.__IsEmpty = False  # 画板不再为空


class MainWidget(QWidget):
    def __init__(self, Parent=None):
        super().__init__(Parent)
        self.__InitData()  # 先初始化数据,再初始化界面
        self.__InitView()

    def __InitData(self):
        """
        初始化成员变量
        """
        self.__paintBoard = PaintBoard(self)
        # 获取颜色列表(字符串类型)
        self.__colorList = QColor.colorNames()

    def __InitView(self):
        """
        初始化界面
        """
        self.setFixedSize(700, 480)
        self.setWindowTitle("AlexNet手写字母识别器")

        # 新建一个水平布局作为本窗体的主布局
        main_layout = QHBoxLayout(self)
        # 设置主布局内边距以及控件间距为10px
        main_layout.setSpacing(10)

        # 在主界面左侧放置画板
        main_layout.addWidget(self.__paintBoard)

        # 新建垂直子布局用于放置按键
        sub_layout = QVBoxLayout()

        # 设置此子布局和内部控件的间距为10px
        sub_layout.setContentsMargins(10, 10, 10, 10)

        self.__btn_Clear = QPushButton("清空画板")
        self.__btn_Clear.setParent(self)  # 设置父对象为本界面

        # 将按键按下信号与画板清空函数相关联
        self.__btn_Clear.clicked.connect(self.__paintBoard.Clear)
        sub_layout.addWidget(self.__btn_Clear)

        self.__btn_Quit = QPushButton("退出")
        self.__btn_Quit.setParent(self)  # 设置父对象为本界面
        self.__btn_Quit.clicked.connect(self.Quit)
        sub_layout.addWidget(self.__btn_Quit)

        self.__cbtn_Eraser = QCheckBox("  使用橡皮擦")
        self.__cbtn_Eraser.setParent(self)
        self.__cbtn_Eraser.clicked.connect(self.on_cbtn_Eraser_clicked)
        sub_layout.addWidget(self.__cbtn_Eraser)

        self.__label_penThickness = QLabel(self)
        self.__label_penThickness.setText("画笔粗细")
        self.__label_penThickness.setFixedHeight(20)
        sub_layout.addWidget(self.__label_penThickness)

        self.__spinBox_penThickness = QSpinBox(self)
        self.__spinBox_penThickness.setMaximum(24)
        self.__spinBox_penThickness.setMinimum(20)
        self.__spinBox_penThickness.setValue(24)  # 默认粗细为10
        self.__spinBox_penThickness.setSingleStep(2)  # 最小变化值为2
        self.__spinBox_penThickness.valueChanged.connect(
            self.on_PenThicknessChange)  # 关联spinBox值变化信号和函数on_PenThicknessChange
        sub_layout.addWidget(self.__spinBox_penThickness)

        self.__label_penColor = QLabel(self)
        self.__label_penColor.setText("画笔颜色")
        self.__label_penColor.setFixedHeight(20)
        sub_layout.addWidget(self.__label_penColor)

        self.__comboBox_penColor = QComboBox(self)
        self.__fillColorList(self.__comboBox_penColor)  # 用各种颜色填充下拉列表
        self.__comboBox_penColor.currentIndexChanged.connect(
            self.on_PenColorChange)  # 关联下拉列表的当前索引变更信号与函数on_PenColorChange
        sub_layout.addWidget(self.__comboBox_penColor)

        self.__btn_Save = QPushButton("AlexNet预测")
        self.__btn_Save.setParent(self)
        self.__btn_Save.clicked.connect(self.on_btn_Save_Clicked)
        sub_layout.addWidget(self.__btn_Save)

        self.__textbox = QLineEdit(self)
        # self.__textbox.move(20, 20)
        # self.__textbox.resize(10,10)
        self.__textbox.setReadOnly(True)
        sub_layout.addWidget(self.__textbox)
        main_layout.addLayout(sub_layout)  # 将子布局加入主布局

    def __fillColorList(self, comboBox):
        index_black = 0
        index = 0
        for color in self.__colorList:
            if color == "black":
                index_black = index
            index += 1
            pix = QPixmap(70, 20)
            pix.fill(QColor(color))
            comboBox.addItem(QIcon(pix), None)
            comboBox.setIconSize(QSize(70, 20))
            comboBox.setSizeAdjustPolicy(QComboBox.AdjustToContents)

        comboBox.setCurrentIndex(index_black)

    def on_PenColorChange(self):
        color_index = self.__comboBox_penColor.currentIndex()
        color_str = self.__colorList[color_index]
        self.__paintBoard.ChangePenColor(color_str)

    def on_PenThicknessChange(self):
        penThickness = self.__spinBox_penThickness.value()
        self.__paintBoard.ChangePenThickness(penThickness)

    def on_btn_Save_Clicked(self): # 按钮点击事件触发
        image = self.__paintBoard.GetContentAsQImage()
        image.save('test.png')  # 默认保存为test.png文件
        ans = get_predict() # 调用函数进行预测
        self.__textbox.setText("预测结果为:" + ans)

    def on_cbtn_Eraser_clicked(self):
        if self.__cbtn_Eraser.isChecked():
            self.__paintBoard.EraserMode = True  # 进入橡皮擦模式
        else:
            self.__paintBoard.EraserMode = False  # 退出橡皮擦模式

    def Quit(self):
        self.close()


if __name__ == '__main__':
    main()

2. 效果图

项目结构图:
Pytorch快速搭建Alexnet实现手写英文字母识别+PyQt实现鼠标绘图_第4张图片
运行效果图
Pytorch快速搭建Alexnet实现手写英文字母识别+PyQt实现鼠标绘图_第5张图片

五、总结

整个项目只要掌握了自定义训练数据集和数据集的加载原理,其他都是模板式的代码,随便换个数据集就能变成另外一个深度学习应用实例。必要一提的是model的参数设置是根据你要得到的分类个数num_classes决定的,而且在预测加载模型权重时即pth文件,要保持分类的个数和训练时的分类个数一致,否则网络模型是运行不起来的。
欢迎留言交流讨论,谢谢你耐心地读完这篇文章!

参考文献

[1] The Chars74K dataset
[2] python 鼠标绘图
[3] 使用pytorch搭建AlexNet并训练花分类数据集
[4] 实现pytorch实现AlexNet(CNN经典网络模型详解)
[5] 使用AlexNet进行手写数字识别:项目结构与代码
[6] AlexNet原论文

你可能感兴趣的:(人工智能,深度学习,神经网络,python,Pytorch)