pyqt5结合keras+tensorflow实现深度学习训练过程GUI界面编写

pyqt5结合keras+tensorflow实现深度学习训练过程GUI界面编写

  • 主要实现功能
    • 直接上代码
    • 重要代码
    • 出错解决
    • 最终效果
    • 未解决问题
    • 补充说明
    • 写在最后

主要实现功能

第一次写博客,主要是想记录下最近踩的坑。最近想做一个集深度学习训练过程与缺陷检测过程为一体的界面,但是中间遇到许多问题,其中解决耗时最长的问题就是如何将深度学习训练过程实时显示在GUI界面的Textbrowser上,实现Textbrowser作为控制台输出的功能。

直接上代码

这里放的是GUI运行核心代码,其他的代码我将上传到CSDN下载中,有需要的小伙伴可以去下载,地址:https://download.csdn.net/download/weixin_42532587/12352345

import ctypes
import win32con
import sys
from PyQt5.QtWidgets import QMainWindow, QApplication, QDialog, QFileDialog, QMessageBox
from PyQt5 import QtCore, QtGui
from PyQt5.QtCore import QThread, pyqtSignal
from mainwindow import Ui_MainWindow
from Model_training import Ui_Dialog
from Detection import Ui_Dialog1
import global_var as gl
from model_train import model_training
from model_prediction import predict, prediction
import pandas as pd

class EmittingStream(QtCore.QObject):
    textWritten = QtCore.pyqtSignal(str)
    def write(self, text):
        self.textWritten.emit(str(text))
    def flush(self):  # real signature unknown; restored from __doc__
        """ flush(self) """
        pass


class MainUI(QMainWindow, Ui_MainWindow):
    def __init__(self):
        super(MainUI, self).__init__()
        self.setupUi(self)
        self.pushButton_Training.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_Detection.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_Tuichu.setStyleSheet('color:red')
        self.pushButton_Tuichu.clicked.connect(self.close)
        self.Exit.triggered.connect(self.close)
        # self.pushButton_Detection.clicked.connect()


class Training_Dialog(QDialog, Ui_Dialog):
    def __init__(self):
        super(Training_Dialog, self).__init__()
        self.setupUi(self)
        self.pushButton_training.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_validation.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_Start.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_Stop.setStyleSheet('background:rgb(255, 0, 0)')
        self.comboBox_BS.addItems(['2', '4', '8', '16', '32', '64', '128'])
        self.comboBox_EP.addItems(['1', '20', '50', '100'])
        self.comboBox_LR.addItems(['0.1', '0.01', '0.001', '0.0001', '0.00001'])
        self.radioButton_AlexNet.setChecked(True)
        self.pushButton_Start.setEnabled(False)
        self.pushButton_training.clicked.connect(self.openfile)
        self.pushButton_validation.clicked.connect(self.openfile1)
        sys.stdout = EmittingStream(textWritten=self.normalOutputWritten)
        sys.stder = EmittingStream(textWritten=self.normalOutputWritten)
        self.pushButton_Start.clicked.connect(self.run_training)
        self.pushButton_Stop.clicked.connect(self.stop_training)
        self.my_thread = MyThread()  # 实例化线程对象

    def hyper_para(self):
        Epoch = self.comboBox_EP.currentText()
        gl.Epoch = int(Epoch)
        print('迭代次数为 %d' % gl.Epoch)
        batch_size = self.comboBox_BS.currentText()
        gl.batch_size = int(batch_size)
        print('批量尺寸为 %d' % gl.batch_size)
        Learning_rate = self.comboBox_LR.currentText()
        gl.learning_rate = float(Learning_rate)
        print('学习率为 %f' % gl.learning_rate)

    def stop_training(self):
        self.my_thread.is_on = False
        ret = ctypes.windll.kernel32.TerminateThread(  # @UndefinedVariable
            self.my_thread.handle, 0)
        print('终止训练', self.my_thread.handle, ret)

    def openfile(self):
        directory = QFileDialog.getExistingDirectory(self, "请选择文件夹路径",
                                                     "F:/Deep_CarrotNet/Carrot_resize_clear_split_clear")
        gl.gl_str_i = directory
        if len(gl.gl_str_i1) == 0:
            QMessageBox.critical(self, '提示', '请选择正确文件夹')

        print('成功加载训练文件', '训练文件夹所在位置:%s' % gl.gl_str_i)

    def openfile1(self):
        directory = QFileDialog.getExistingDirectory(self, "请选择文件夹路径",
                                                     "F:/Deep_CarrotNet/Carrot_resize_clear_split_clear")
        gl.gl_str_i1 = directory
        if len(gl.gl_str_i1) == 0:
            QMessageBox.critical(self, '提示', '请选择正确文件夹')
        else:
            self.pushButton_Start.setEnabled(True)
        print('成功加载验证文件', '验证文件夹所在位置:%s' % gl.gl_str_i1)

    def i_count(self):
        if self.radioButton_CarrotNet.text() == 'CarrotNet':
            if self.radioButton_CarrotNet.isChecked() == True:
                gl.gl_int_i = 2
                print('model is CarrotNet')

            elif self.radioButton_AlexNet.text() == 'AlexNet':
                if self.radioButton_AlexNet.isChecked() == True:
                    gl.gl_int_i = 1
                    print('model is AlexNet')

    def run_training(self):
        self.pushButton_Start.setEnabled(False)
        self.textBrowser.clear()
        self.i_count()
        self.hyper_para()
        if gl.gl_str_i == 'one':
            QMessageBox.critical(self, '错误', '请加载训练图片')
            self.my_thread.is_on = False
        elif gl.gl_str_i1 == 'one':
            QMessageBox.critical(self, '错误', '请加载验证图片')
            self.my_thread.is_on = False
        else:
            self.my_thread.is_on = True
        self.my_thread.start()  # 启动线程
        self.pushButton_Start.setEnabled(True)

    def normalOutputWritten(self, text):
        """Append text to the QTextEdit."""
        # Maybe QTextEdit.append() works as well, but this is how I do it:
        cursor = self.textBrowser.textCursor()
        cursor.movePosition(QtGui.QTextCursor.End)
        cursor.insertText(text)
        self.textBrowser.setTextCursor(cursor)
        self.textBrowser.ensureCursorVisible()


class MyThread(QThread):  # 线程类
    # my_signal = pyqtSignal(str)  # 自定义信号对象。参数str就代表这个信号可以传一个字符串

    def __init__(self):
        super(MyThread, self).__init__()
        # self.count = 0
        self.is_on = True

    def run(self):  # 线程执行函数
        self.handle = ctypes.windll.kernel32.OpenThread(  # @UndefinedVariable
            win32con.PROCESS_ALL_ACCESS, False, int(QThread.currentThreadId()))
        while self.is_on:
            model_training(gl.gl_int_i, gl.gl_str_i, gl.gl_str_i1, gl.Epoch,
                           gl.batch_size, gl.learning_rate)
            self.is_on = False


class EmittingStream1(QtCore.QObject):
    textWritten = QtCore.pyqtSignal(str)

    def write(self, text):
        self.textWritten.emit(str(text))

    def flush(self):  # real signature unknown; restored from __doc__
        """ flush(self) """
        pass


class Detection_Dialog(QDialog, Ui_Dialog1):
    def __init__(self):
        super(Detection_Dialog, self).__init__()
        self.setupUi(self)
        self.pushButton_start_detection.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_model.setStyleSheet('background:rgb(255, 0, 0)')
        self.pushButton_picture.setStyleSheet('background:rgb(255, 0, 0)')
        self.radioButton.setChecked(True)
        self.pushButton_model.setEnabled(True)
        self.pushButton_picture.setEnabled(False)
        self.pushButton_start_detection.setEnabled(False)
        self.pushButton_exit.setEnabled(False)
        self.pushButton_save.setEnabled(False)

        sys.stdout = EmittingStream1(textWritten=self.normalOutputWritten1)
        sys.stder = EmittingStream1(textWritten=self.normalOutputWritten1)
        # print('请先选择逐批检测还是逐个检测')
        self.pushButton_model.clicked.connect(self.message)
        self.pushButton_model.clicked.connect(self.load_moad)
        self.pushButton_picture.clicked.connect(self.load_image)
        self.my_thread1 = My_Thread1()  # 实例化线程对象
        self.pushButton_start_detection.clicked.connect(self.detection)
        self. pushButton_save.clicked.connect(self.save_result)
        self.pushButton_exit.clicked.connect(self.close)

    def save_result(self):
        path = QFileDialog.getExistingDirectory(self, "请选择文件路径")
        data = pd.DataFrame(gl.Y)
        data.to_csv(path + '/' + 'detection_result.csv', index=True)


    def message(self):
        QMessageBox.question(self, '提示', '请先选择逐批检测还是逐个检测')
        self.pushButton_model.setEnabled(True)
        self.pushButton_picture.setEnabled(True)
        self.pushButton_start_detection.setEnabled(True)
        self.pushButton_exit.setEnabled(True)

    def load_image(self):
        if self.radioButton.text() == '逐批检测':
            if self.radioButton.isChecked() == True:
                directory1 = QFileDialog.getExistingDirectory(self, "请选择文件路径")
                gl.gl_str_i3 = directory1
                print('成功导入检测文件', '检测文件所在位置:%s' % gl.gl_str_i3)
            elif self.radioButton_2.text() == '逐个检测':
                if self.radioButton_2.isChecked() == True:
                    fname, _ = QFileDialog.getOpenFileName(self, '选择图片', 'c:\\', 'Image files(*.jpg *.gif *.png)')
                    gl.gl_str_i4 = fname
                    print('成功导入检测图片', '检测文件所在位置:%s' % gl.gl_str_i4)
                else:
                    print('请正确选择检测文件路径')


    def load_moad(self):
        directory = QFileDialog.getExistingDirectory(self, "请选择文件路径")
        gl.gl_str_i2 = directory
        print('成功加载模型', '模型所在位置:%s' % gl.gl_str_i2)

    def normalOutputWritten1(self, text):
        """Append text to the QTextEdit."""
        # Maybe QTextEdit.append() works as well, but this is how I do it:
        cursor1 = self.textBrowser1.textCursor()
        cursor1.movePosition(QtGui.QTextCursor.End)
        cursor1.insertText(text)
        self.textBrowser1.setTextCursor(cursor1)
        self.textBrowser1.ensureCursorVisible()

    def detection(self):
        self.pushButton_start_detection.setEnabled(False)
        if self.radioButton.text() == '逐批检测':
            if self.radioButton.isChecked() == True:
                gl.i = 0
            elif self.radioButton_2.text() == '逐个检测':
                if self.radioButton_2.isChecked() == True:
                    gl.i = 1

        self.my_thread1.start()  # 启动线程
        self.pushButton_start_detection.setEnabled(True)
        self.pushButton_save.setEnabled(True)


class My_Thread1(QThread):
    def __init__(self):
        super(My_Thread1, self).__init__()

    def run(self):  # 线程执行函数
        print('测试开始')
        if gl.i == 0:
            prediction(gl.gl_str_i2, gl.gl_str_i3)
        else:
            predict(gl.gl_str_i2, gl.gl_str_i4)


if __name__ == "__main__":
    app = QApplication(sys.argv)
    main = MainUI()
    Training = Training_Dialog()
    Detection = Detection_Dialog()
    main.pushButton_Training.clicked.connect(Training.show)
    main.pushButton_Detection.clicked.connect(Detection.show)
    main.pushButton_Tuichu.clicked.connect(Training.close)
    main.pushButton_Tuichu.clicked.connect(Detection.close)
    main.Exit.triggered.connect(Training.close)
    main.Exit.triggered.connect(Detection.close)
    main.show()
    sys.exit(app.exec_())

重要代码

这里是关于如何将深度学习训练过程实时显示到GUI的Textbrowser上

class EmittingStream(QtCore.QObject):
    textWritten = QtCore.pyqtSignal(str)

    def write(self, text):
        self.textWritten.emit(str(text))

    def flush(self):  # real signature unknown; restored from __doc__
        """ flush(self) """
        pass

一定要加上flush函数的定义,之前在CSDN上找了很久,都没有这行,导致GUI界面上的Textbrowsers只能输出深度学习训练过程的第一行,不能实现实时刷新的功能,加上这个定义就可以完美解决

sys.stdout = EmittingStream(textWritten=self.normalOutputWritten)
sys.stder = EmittingStream(textWritten=self.normalOutputWritten)
class MyThread(QThread):  # 线程类
    # my_signal = pyqtSignal(str)  # 自定义信号对象。参数str就代表这个信号可以传一个字符串

    def __init__(self):
        super(MyThread, self).__init__()
        # self.count = 0
        self.is_on = True

    def run(self):  # 线程执行函数
        self.handle = ctypes.windll.kernel32.OpenThread(  # @UndefinedVariable
            win32con.PROCESS_ALL_ACCESS, False, int(QThread.currentThreadId())) # 是为了后面结束进程使用
        while self.is_on:
            model_training(gl.gl_int_i, gl.gl_str_i, gl.gl_str_i1, gl.Epoch,
                           gl.batch_size, gl.learning_rate)
            self.is_on = False
def stop_training(self):
        self.my_thread.is_on = False
        ret = ctypes.windll.kernel32.TerminateThread(  # @UndefinedVariable
            self.my_thread.handle, 0)
        print('终止训练', self.my_thread.handle, ret)

self.handle = ctypes.windll.kernel32.OpenThread( # @UndefinedVariable
win32con.PROCESS_ALL_ACCESS, False, int(QThread.currentThreadId())) # 是为了后面结束进程使用
def stop_training(self): 终止训练过程

出错解决

用前面的代码理论上是可以实现实时显示深度学习训练过程的,但我在刚开始使用时,总会出现== finished with exit code -1073740791 (0xC0000409)==,当把 sys.stdout = EmittingStream1(textWritten=self.normalOutputWritten1)
sys.stder = EmittingStream1(textWritten=self.normalOutputWritten1)这两行注释掉时程序可以正常运行,只不过内容没有输出到textbrowser上。在网上搜了一大圈,也没有发现适合我程序的,最后才发现是keras的版本和Tensorflow的版本不匹配造成的,但是之前不在GUI内运行不报错,在GUI框架下运行就会报错,==最终选择Keras2.2.5,tensorflow1.14.0 ==解决了问题,但是运行程序时会出现一大串警告,不过不影响最终结果

最终效果

pyqt5结合keras+tensorflow实现深度学习训练过程GUI界面编写_第1张图片

未解决问题

我这个有两个子界面,每个子界面都有一个Textbrowser,而且都想达到实时刷新的效果,但是当同时使用时会出现两个Textbrowser内容相互干扰的现象。哪位大神知道如何玩解决的话,还望不吝赐教

补充说明

本界面还使用了全局变量实现不同函数之间的互相传值,具体方法是先建个global_var.py文件,将需要传值的参数预先定义。此后各个文件import使用就行了

# coding=utf-8
# 在别的文件使用方法:
# import global_var_model as gl
#  gl.gl_int_i += 4,可以通过访问和修改gl.gl_int_i来实现python的全局变量,或者叫静态变量访问
# gl.gl_int_i
import numpy as np
gl_int_i = 1  # 这里的gl_int_i是最常用的用于标记的全局变量
gl_str_i = 'one'
gl_str_i1 = 'one'
gl_str_i2 = 'one'
gl_str_i3 = 'one'
gl_str_i4 = 'one'
batch_size = 1
Epoch = 1
learning_rate = 0.1
i = 0
Y = np.array([])

写在最后

第一次写博客,语言也不怎么精炼,文学功底不行,希望大家将就着看,整个GUI的全部代码我将在后续上传到CSDN上。当然这篇博客也借鉴了很多前人的经验,在此表示感谢

你可能感兴趣的:(tensorflow,pyqt,gui,深度学习,python)