AI手写输入法 - pytorch从入门到入道(二)

本章承接上一篇的手写数字识别,利用训练好的模型,结合pyqt画板,实现简易手写输入法,为"hello world"例子增添乐趣。

PyQt 是用于开发图形界面的框架,安装及基础用法可以百度查找相关资料了解。我搭建的环境是 PyCharm + PyQt5 + Qt Designer,配置好之后的界面长这样:

AI手写输入法 - pytorch从入门到入道(二)_第1张图片

在左边的项目中右键某个文件,也可以打开qt菜单

具体怎么画界面不展开了,直接看下代码:

  1 # coding: utf-8
  2 from PyQt5.QtWidgets import *
  3 from PyQt5.QtGui import *
  4 from PyQt5.QtCore import *
  5 import sys
  6 sys.path.append(r'../ml/torch')
  7 from digit_recog import Net
  8 import torch
  9 import os
 10 import numpy as np
 11 import matplotlib.pyplot as plt
 12 from PIL import Image
 13 
 14 
 15 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 16 net = Net().to(device)
 17 # 加载参数
 18 nn_state = torch.load(os.path.join('../ml/torch/model/', 'net.pth'))
 19 # 参数加载到指定模型
 20 net.load_state_dict(nn_state)
 21 net.eval()
 22 
 23 
 24 def predict(img):
 25     # 读取图片并重设尺寸
 26     image = Image.open(img).resize((28, 28))
 27     # 灰度图
 28     gray_image = image.convert('L')
 29     # plt.imshow(gray_image)
 30     # plt.show()
 31     # 图片数据处理
 32     im_data = np.array(gray_image)
 33     im_data = torch.from_numpy(im_data).float()
 34     im_data = im_data.view(1, 1, 28, 28)
 35     # 神经网络运算
 36     outputs = net(im_data)
 37     # 取最大预测值
 38     _, pred = torch.max(outputs, 1)
 39     return pred.item()
 40 
 41 
 42 class SimpleDrawingBoard(QWidget):
 43     win = ''
 44     wins = []
 45 
 46     @classmethod
 47     def showWin(cls):
 48         # 聚焦到已有窗口
 49         if not cls.win:
 50             cls.win = cls()
 51             cls.win.show()
 52         else:
 53             cls.win.activateWindow()
 54 
 55     def __init__(self, parent=None):
 56         super(SimpleDrawingBoard, self).__init__(parent)
 57 
 58         self.setWindowTitle(u"手写数字识别")
 59         self.setWindowFlags(Qt.WindowStaysOnTopHint)
 60         self.size = (400, 350)
 61         self.resize(*self.size)
 62         self.setWindowFlag(Qt.FramelessWindowHint)  # 隐藏边框
 63         # self.setWindowOpacity(0.9)  # 设置窗口透明度
 64         # self.setAttribute(Qt.WA_TranslucentBackground)  # 设置窗口背景透明
 65 
 66         self.canvasSize = (280, 350)
 67         self.sizeOffset = [a - b for a, b in zip(self.size, self.canvasSize)]
 68         self.canvas = QPixmap(*self.canvasSize)
 69         self.canvas.fill(Qt.black)
 70         self.tempCanvas = QPixmap()
 71         self.lastPoint = QPoint()
 72         self.endPoint = QPoint()
 73         self.isDrawing = False
 74         self.penSize = 15
 75 
 76         self.initUI()
 77 
 78     def initUI(self):
 79         self.penSizeLabel = QLabel(u'画笔粗细')
 80         self.penSizeSpinBox = QSpinBox()
 81         self.penSizeSpinBox.setValue(self.penSize)
 82         self.penSizeSpinBox.valueChanged.connect(self.penSizeSpinBox_valueChanged)
 83         self.penSizeSpinBox.setFixedWidth(80)
 84 
 85         self.clearButton = QPushButton(u'清空')
 86         self.clearButton.setFixedWidth(80)
 87         self.clearButton.clicked.connect(self.clearPainter)
 88 
 89         self.closeButton = QPushButton(u'关闭')
 90         self.closeButton.setFixedWidth(80)
 91         self.closeButton.clicked.connect(self.close)
 92 
 93         self.inputLabel = QLabel(self)
 94         self.inputLabel.setFixedSize(80, 200)
 95         self.inputLabel.setAutoFillBackground(True)
 96         self.inputLabel.setAlignment(Qt.AlignCenter)
 97         self.inputLabel.setStyleSheet('''QLabel{background:#F76677;border-radius:5px;font-size:60px;font-weight:bolder;}''')
 98 
 99         mainLayout = QVBoxLayout(self)
100 
101         toolbarLayout = QGridLayout()
102         # toolbarLayout.setSpacing(20)
103         toolbarLayout.addWidget(self.penSizeLabel, 0, 0, 1, 1)
104         toolbarLayout.addWidget(self.penSizeSpinBox, 1, 0, 1, 1)
105         toolbarLayout.addWidget(self.clearButton, 2, 0, 1, 1)
106         toolbarLayout.addWidget(self.closeButton, 3, 0, 1, 1)
107         toolbarLayout.addWidget(self.inputLabel, 4, 0, 1, 1)
108 
109         toolbarLayout.setAlignment(Qt.AlignLeft)
110 
111         mainLayout.addLayout(toolbarLayout)
112         mainLayout.addStretch(1)
113 
114     def penSizeSpinBox_valueChanged(self):
115         # 设置画笔粗细
116         self.penSize = self.penSizeSpinBox.value()
117 
118     def paintEvent(self, event):
119         pp = QPainter(self.canvas)
120         pen = QPen(QColor(255, 255, 255), self.penSize)
121         pp.setPen(pen)
122         if self.lastPoint != self.endPoint:
123             pp.drawLine(self.lastPoint - QPoint(*self.sizeOffset), self.endPoint - QPoint(*self.sizeOffset))
124         painter = QPainter(self)
125         painter.drawPixmap(self.sizeOffset[0], self.sizeOffset[1], self.canvas)
126         self.lastPoint = self.endPoint
127 
128     def clearPainter(self):
129         print('clear...')
130         self.canvas.fill(Qt.black)
131         painter = QPainter(self)
132         painter.drawPixmap(self.sizeOffset[0], self.sizeOffset[1], self.canvas)
133         self.lastPoint = self.endPoint
134         self.update()
135         self.inputLabel.clear()
136 
137     def mousePressEvent(self, event):
138         # 按下左键
139         if event.button() == Qt.LeftButton:
140             self.lastPoint = event.pos()
141             self.endPoint = self.lastPoint
142             self.isDrawing = True
143 
144     def mouseMoveEvent(self, event):
145         if self.isDrawing:
146             self.update()
147             self.endPoint = event.pos()
148 
149     def mouseReleaseEvent(self, event):
150         if event.button() == Qt.LeftButton:
151             self.isDrawing = False
152             self.endPoint = event.pos()
153             self.update()
154             self.canvas.toImage().save('input.png')
155             input = predict('input.png')
156             self.inputLabel.setText(str(input))
157             print('你输入的是{}'.format(input))
158 
159 
if __name__ == '__main__':
    # Reuse an already-created QApplication if there is one; otherwise
    # create a new one from the command-line arguments.
    app = QApplication.instance() or QApplication(sys.argv)
    SimpleDrawingBoard.showWin()
    # Enter the Qt event loop until the application quits.
    app.exec_()

上面的代码引入了前一章训练好的模型。由于定义模型的 digit_recog 模块位于另一个文件夹内,导入之前需要先加上这一行代码,把该文件夹加入模块搜索路径:

sys.path.append(r'../ml/torch')

看下运行效果:

AI手写输入法 - pytorch从入门到入道(二)_第2张图片

AI手写输入法 - pytorch从入门到入道(二)_第3张图片

上面写了两个数字,识别输出正确!

hello world 例子本身比较枯燥,而通过动手参与、与 AI 交互,可以增强信心、增添乐趣。信心是一步步建立起来的,大的突破亦是如此。后面会持续围绕简单的例子,深入发掘 AI 的乐趣与应用场景。

 

 

 

你可能感兴趣的:(AI手写输入法 - pytorch从入门到入道(二))