利用PYQT5结合YOLOX搭建检测系统

今天给大家分享用pyqt5桌面小组件搭建一个检测系统,暂定为公共场合猫狗检测系统,检测算法为YOLOX。后续会更新YOLOv8+pyqt5教程

系统界面展示: 

该系统可以进行图片检测,实时检测,视频检测

首先创建一个 .py 文件,并将下面的代码复制进去:

from PIL import Image
import numpy as np
import time
import os
from PyQt5 import QtWidgets, QtCore, QtGui
from PyQt5.QtGui import *
import cv2
import sys
from PyQt5.QtWidgets import *
# from detect_qt5 import main_detect,my_lodelmodel
from demo import main


'''摄像头和视频实时检测界面'''


class Ui_MainWindow(QWidget):
    """Camera / video real-time detection window.

    The left label shows the raw frame; the right label shows the YOLOX
    detection result produced by ``demo.main().demoimg``.  Switching back
    to the single-image window relies on the module-level ``cam_t`` /
    ``ui_p`` instances created in the ``__main__`` block.
    """

    # One shared stylesheet for every push button; the original code
    # repeated this literal inline for each of the five buttons.
    _BUTTON_STYLE = '''
        QPushButton
        {text-align : center;
        background-color : white;
        font: bold;
        border-color: gray;
        border-width: 2px;
        border-radius: 10px;
        padding: 6px;
        height : 14px;
        border-style: outset;
        font : 14px;}
        QPushButton:pressed
        {text-align : center;
        background-color : light gray;
        font: bold;
        border-color: gray;
        border-width: 2px;
        border-radius: 10px;
        padding: 6px;
        height : 14px;
        border-style: outset;
        font : 14px;}
        '''

    def __init__(self, parent=None):
        super(Ui_MainWindow, self).__init__(parent)

        # One timer per display mode:
        #   timer_camera1 - raw camera preview (left label)
        #   timer_camera2 - camera frames run through YOLOX (right label)
        #   timer_camera4 - video playback (left) + detection (right)
        # timer_camera3 is kept only because the original code created it.
        self.timer_camera1 = QtCore.QTimer()
        self.timer_camera2 = QtCore.QTimer()
        self.timer_camera3 = QtCore.QTimer()
        self.timer_camera4 = QtCore.QTimer()
        self.cap = cv2.VideoCapture()

        self.CAM_NUM = 0  # default webcam index

        self.__flag_work = 0
        self.x = 0
        self.count = 0
        self.setWindowTitle("公共场合猫狗检测系统")
        # BUGFIX: build paths with os.path.join instead of hard-coded
        # Windows '\\' separators so they also work on Linux/macOS.
        self.setWindowIcon(QIcon(os.path.join(os.getcwd(), 'data', 'source_image', 'Detective.ico')))
        self.setFixedSize(1600, 900)
        self.yolo = main()  # YOLOX detector wrapper from demo.py

        self.button_open_camera = QPushButton(self)
        self.button_open_camera.setText(u'打开摄像头')
        self.button_open_camera.setStyleSheet(self._BUTTON_STYLE)
        self.button_open_camera.move(10, 40)
        self.button_open_camera.clicked.connect(self.button_open_camera_click)

        # BUGFIX: the original stored both this button and the
        # "检测视频文件" button in ``self.btn1``, silently dropping the first
        # reference.  ``self.btn1`` keeps its final meaning (the
        # video-detect button below) for backward compatibility.
        self.btn_detect_camera = QPushButton(self)
        self.btn_detect_camera.setText("检测摄像头")
        self.btn_detect_camera.setStyleSheet(self._BUTTON_STYLE)
        self.btn_detect_camera.move(10, 80)
        self.btn_detect_camera.clicked.connect(self.button_open_camera_click1)

        self.open_video = QPushButton(self)
        self.open_video.setText("打开视频")
        self.open_video.setStyleSheet(self._BUTTON_STYLE)
        self.open_video.move(10, 160)
        self.open_video.clicked.connect(self.open_video_button)

        self.btn1 = QPushButton(self)
        self.btn1.setText("检测视频文件")
        self.btn1.setStyleSheet(self._BUTTON_STYLE)
        self.btn1.move(10, 200)
        self.btn1.clicked.connect(self.detect_video)

        btn2 = QPushButton(self)
        btn2.setText("返回上一界面")
        btn2.setStyleSheet(self._BUTTON_STYLE)
        btn2.move(10, 240)
        btn2.clicked.connect(self.back_lastui)

        # Display area: left label = source frame, right label = result.
        self.label_show_camera = QLabel(self)
        self.label_move = QLabel()
        self.label_move.setFixedSize(100, 100)
        self.label_show_camera.setFixedSize(700, 500)
        self.label_show_camera.setAutoFillBackground(True)
        self.label_show_camera.move(110, 80)
        # NOTE: rgb(300,300,300,120) is out of the 0-255 range; Qt clamps
        # it — kept as-is to preserve the original appearance.
        self.label_show_camera.setStyleSheet("QLabel{background:#F5F5DC;}"
                                             "QLabel{color:rgb(300,300,300,120);font-size:10px;font-weight:bold;font-family:宋体;}"
                                             )
        self.label_show_camera1 = QLabel(self)
        self.label_show_camera1.setFixedSize(700, 500)
        self.label_show_camera1.setAutoFillBackground(True)
        self.label_show_camera1.move(850, 80)
        self.label_show_camera1.setStyleSheet("QLabel{background:#F5F5DC;}"
                                              "QLabel{color:rgb(300,300,300,120);font-size:10px;font-weight:bold;font-family:宋体;}"
                                              )

        self.timer_camera1.timeout.connect(self.show_camera)
        self.timer_camera2.timeout.connect(self.show_camera1)
        # timer_camera4 drives both the raw-frame and the detection view.
        self.timer_camera4.timeout.connect(self.show_camera2)
        self.timer_camera4.timeout.connect(self.show_camera3)
        self.clicked = False

        self.frame_s = 3  # frame counter advanced by show_camera3
        # Window background image.
        palette1 = QPalette()
        palette1.setBrush(self.backgroundRole(), QBrush(QPixmap('R-C.png')))
        self.setPalette(palette1)

    @staticmethod
    def _fit_label(img, width_new, height_new):
        """Resize ``img`` to fit a ``width_new`` x ``height_new`` box while
        keeping the aspect ratio (factored out of every display slot)."""
        height, width = img.shape[:2]
        if width / height >= width_new / height_new:
            return cv2.resize(img, (width_new, int(height * width_new / width)))
        return cv2.resize(img, (int(width * height_new / height), height_new))

    def back_lastui(self):
        """Stop all capture/detection and return to the picture window."""
        self.timer_camera1.stop()
        self.cap.release()
        self.label_show_camera.clear()
        self.timer_camera2.stop()
        self.label_show_camera1.clear()
        cam_t.close()  # module-level window instances from __main__
        ui_p.show()

    def button_open_camera_click(self):
        """Toggle the raw camera preview (timer_camera1)."""
        if not self.timer_camera1.isActive():
            flag = self.cap.open(self.CAM_NUM)
            if not flag:
                QtWidgets.QMessageBox.warning(self, u"Warning", u"请检测相机与电脑是否连接正确",
                                              buttons=QtWidgets.QMessageBox.Ok,
                                              defaultButton=QtWidgets.QMessageBox.Ok)
            else:
                self.timer_camera1.start(30)  # ~33 fps refresh
                self.button_open_camera.setText(u'关闭摄像头')
        else:
            self.timer_camera1.stop()
            self.cap.release()
            self.label_show_camera.clear()
            self.timer_camera2.stop()
            self.label_show_camera1.clear()
            self.button_open_camera.setText(u'打开摄像头')

    def show_camera(self):
        """Timer slot: draw the current raw camera frame on the left label."""
        flag, self.image = self.cap.read()
        if not flag:  # BUGFIX: a failed read used to crash on a None frame
            return

        # The latest frame is also dumped to disk (kept from the original
        # code; presumably another tool consumes this file — TODO confirm).
        camera_source = os.path.join(os.getcwd(), "data", "test", "2.jpg")
        cv2.imwrite(camera_source, self.image)

        show = self._fit_label(self.image, 700, 500)
        show = cv2.cvtColor(show, cv2.COLOR_BGR2RGB)  # OpenCV BGR -> Qt RGB
        showImage = QtGui.QImage(show.data, show.shape[1], show.shape[0],
                                 3 * show.shape[1], QtGui.QImage.Format_RGB888)
        self.label_show_camera.setPixmap(QtGui.QPixmap.fromImage(showImage))

    def button_open_camera_click1(self):
        """Toggle camera detection (timer_camera2 -> show_camera1)."""
        if not self.timer_camera2.isActive():
            flag = self.cap.open(self.CAM_NUM)
            if not flag:
                QtWidgets.QMessageBox.warning(self, u"Warning", u"请检测相机与电脑是否连接正确",
                                              buttons=QtWidgets.QMessageBox.Ok,
                                              defaultButton=QtWidgets.QMessageBox.Ok)
            else:
                self.timer_camera2.start(30)
                self.button_open_camera.setText(u'关闭摄像头')
        else:
            self.timer_camera2.stop()
            self.cap.release()
            self.label_show_camera1.clear()
            self.button_open_camera.setText(u'打开摄像头')

    def show_camera1(self):
        """Timer slot: run YOLOX on the current camera frame and display it
        on the right label, with an FPS overlay when pets are detected."""
        t1 = time.time()
        flag, self.image = self.cap.read()
        if not flag:  # BUGFIX: a failed read used to crash on a None frame
            return
        self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
        im0, nums, ti = self.yolo.demoimg(self.image)
        im0 = cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)

        # BUGFIX: the original computed this resize into a local ``show``
        # and then displayed the unresized frame; display the resized one.
        im0 = self._fit_label(im0, 640, 640)

        if nums >= 1:
            # BUGFIX: the original averaged the instantaneous rate with a
            # fresh local 0.0 every call, permanently halving the shown fps.
            fps = 1.0 / (time.time() - t1)
            im0 = cv2.putText(im0, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            im0 = cv2.putText(im0, "No pets allowed", (0, 150), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

        showImage = QtGui.QImage(im0, im0.shape[1], im0.shape[0],
                                 3 * im0.shape[1], QtGui.QImage.Format_RGB888)
        self.label_show_camera1.setPixmap(QtGui.QPixmap.fromImage(showImage))

    def open_video_button(self):
        """Pick a video file and toggle playback (timer_camera4)."""
        if not self.timer_camera4.isActive():
            imgName, imgType = QFileDialog.getOpenFileName(self, "打开视频", "", "*.mp4;;*.AVI;;*.rmvb;;All Files(*)")
            self.cap_video = cv2.VideoCapture(imgName)
            if not self.cap_video.isOpened():
                # BUGFIX: the original showed the camera-connection message
                # for a video-file failure.
                QtWidgets.QMessageBox.warning(self, u"Warning", u"无法打开该视频文件,请检查路径是否正确",
                                              buttons=QtWidgets.QMessageBox.Ok,
                                              defaultButton=QtWidgets.QMessageBox.Ok)
            else:
                self.show_camera2()  # draw the first frame immediately
                self.open_video.setText(u'关闭视频')
        else:
            self.cap_video.release()
            self.label_show_camera.clear()
            self.timer_camera4.stop()
            self.frame_s = 3
            self.label_show_camera1.clear()
            self.open_video.setText(u'打开视频')

    def detect_video(self):
        """Toggle detection on the opened video file (timer_camera4)."""
        if not self.timer_camera4.isActive():
            # BUGFIX: clicking "检测视频文件" before opening a video used to
            # raise AttributeError because self.cap_video did not exist yet.
            if not hasattr(self, 'cap_video') or not self.cap_video.isOpened():
                QtWidgets.QMessageBox.warning(self, u"Warning", u"请先打开一个视频文件",
                                              buttons=QtWidgets.QMessageBox.Ok,
                                              defaultButton=QtWidgets.QMessageBox.Ok)
            else:
                self.timer_camera4.start(30)
        else:
            self.timer_camera4.stop()
            self.cap_video.release()
            self.label_show_camera1.clear()

    def show_camera2(self):
        """Timer slot: draw the current raw video frame on the left label.

        NOTE(review): show_camera2 and show_camera3 each call
        ``cap_video.read()`` on the same timer tick, so the left and right
        labels show consecutive frames — kept from the original design.
        """
        flag, self.image1 = self.cap_video.read()
        if flag:
            show = self._fit_label(self.image1, 700, 500)
            show = cv2.cvtColor(show, cv2.COLOR_BGR2RGB)
            showImage = QtGui.QImage(show.data, show.shape[1], show.shape[0],
                                     3 * show.shape[1], QtGui.QImage.Format_RGB888)
            self.label_show_camera.setPixmap(QtGui.QPixmap.fromImage(showImage))
        else:
            # End of stream: release everything and reset the UI.
            self.cap_video.release()
            self.label_show_camera.clear()
            self.timer_camera4.stop()
            self.label_show_camera1.clear()
            self.open_video.setText(u'打开视频')

    def show_camera3(self):
        """Timer slot: run YOLOX on the current video frame and display the
        annotated result on the right label."""
        flag, self.image1 = self.cap_video.read()
        self.frame_s += 1
        if flag:
            im0, nums, ti = self.yolo.demoimg(self.image1)
            im0 = self._fit_label(im0, 700, 500)
            if nums >= 1:
                im0 = cv2.putText(im0, "Warning", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
                im0 = cv2.putText(im0, f"nums:{nums}", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
            showImage = QtGui.QImage(im0, im0.shape[1], im0.shape[0],
                                     3 * im0.shape[1], QtGui.QImage.Format_RGB888)
            self.label_show_camera1.setPixmap(QtGui.QPixmap.fromImage(showImage))


'''单张图片检测'''


class picture(QWidget):
    """Single-image detection window.

    Lets the user open one image, runs YOLOX on it via
    ``demo.main().demoimg`` and shows the original (left label) and the
    annotated result (right label).  Switching to the camera/video window
    relies on the module-level ``cam_t`` / ``ui_p`` instances created in
    the ``__main__`` block.
    """

    # One shared stylesheet for every push button; the original code
    # repeated this literal inline for each of the three buttons.
    _BUTTON_STYLE = '''
        QPushButton
        {text-align : center;
        background-color : white;
        font: bold;
        border-color: gray;
        border-width: 2px;
        border-radius: 10px;
        padding: 6px;
        height : 14px;
        border-style: outset;
        font : 14px;}
        QPushButton:pressed
        {text-align : center;
        background-color : light gray;
        font: bold;
        border-color: gray;
        border-width: 2px;
        border-radius: 10px;
        padding: 6px;
        height : 14px;
        border-style: outset;
        font : 14px;}
        '''

    def __init__(self):
        super(picture, self).__init__()

        self.str_name = '0'
        self.yolo = main()  # YOLOX detector wrapper from demo.py

        self.resize(1600, 900)
        # BUGFIX: build paths with os.path.join instead of hard-coded
        # Windows '\\' separators so they also work on Linux/macOS.
        self.setWindowIcon(QIcon(os.path.join(os.getcwd(), 'data', 'source_image', 'Detective.ico')))
        self.setWindowTitle("公共场合猫狗检测系统")

        # Window background image.
        palette2 = QPalette()
        palette2.setBrush(self.backgroundRole(), QBrush(QPixmap('4.jpg')))
        self.setPalette(palette2)

        # Folder that the camera/video window dumps frames into.
        camera_or_video_save_path = os.path.join('data', 'test')
        if not os.path.exists(camera_or_video_save_path):
            os.makedirs(camera_or_video_save_path)

        # Left label: the picture waiting to be detected.
        self.label1 = QLabel(self)
        self.label1.setText("   待检测图片")
        self.label1.setFixedSize(700, 500)
        self.label1.move(110, 80)
        # NOTE: rgb(300,300,300,120) is out of the 0-255 range; Qt clamps
        # it — kept as-is to preserve the original appearance.
        self.label1.setStyleSheet("QLabel{background:#7A6969;}"
                                  "QLabel{color:rgb(300,300,300,120);font-size:20px;font-weight:bold;font-family:宋体;}"
                                  )
        # Right label: the detection result.
        self.label2 = QLabel(self)
        self.label2.setText("检测结果")
        self.label2.setFixedSize(700, 500)
        self.label2.move(850, 80)
        self.label2.setStyleSheet("QLabel{background:#7A6969;}"
                                  "QLabel{color:rgb(300,300,300,120);font-size:20px;font-weight:bold;font-family:宋体;}"
                                  )

        # Spare status label (left empty by the original code).
        self.label3 = QLabel(self)
        self.label3.setText("")
        self.label3.move(1200, 620)
        self.label3.setStyleSheet("font-size:20px;")
        self.label3.adjustSize()

        btn = QPushButton(self)
        btn.setText("打开图片")
        btn.setStyleSheet(self._BUTTON_STYLE)
        btn.move(10, 30)
        btn.clicked.connect(self.openimage)

        btn1 = QPushButton(self)
        btn1.setText("检测图片")
        btn1.setStyleSheet(self._BUTTON_STYLE)
        btn1.move(10, 80)
        btn1.clicked.connect(self.button1_test)

        btn3 = QPushButton(self)
        btn3.setText("视频和摄像头检测")
        btn3.setStyleSheet(self._BUTTON_STYLE)
        btn3.move(10, 160)
        btn3.clicked.connect(self.camera_find)

        # '0' is the sentinel for "no image chosen yet" (kept from the
        # original code; button1_test checks against it).
        self.imgname1 = '0'

    @staticmethod
    def _fit_label(img, width_new, height_new):
        """Resize ``img`` to fit a ``width_new`` x ``height_new`` box while
        keeping the aspect ratio (factored out of the display methods)."""
        height, width = img.shape[:2]
        if width / height >= width_new / height_new:
            return cv2.resize(img, (width_new, int(height * width_new / width)))
        return cv2.resize(img, (int(width * height_new / height), height_new))

    def camera_find(self):
        """Switch to the camera/video detection window."""
        ui_p.close()
        cam_t.show()

    def openimage(self):
        """Pick an image file and preview it on the left label."""
        imgName, imgType = QFileDialog.getOpenFileName(self, "打开图片", "D://",
                                                       "Image files (*.jpg *.gif *.png *.jpeg)")
        if imgName != '':
            self.im0 = cv2.imread(imgName)
            # BUGFIX: an unreadable file made cv2.imread return None and
            # the original crashed on .shape; it also recorded the bad
            # path so "检测图片" would crash too.  Validate before storing.
            if self.im0 is None:
                QMessageBox.information(self, '错误', '无法读取该图片文件', QMessageBox.Yes, QMessageBox.Yes)
                return
            self.imgname1 = imgName
            show = self._fit_label(self.im0, 700, 500)
            # cv2.imread yields BGR; swap channels for QImage's RGB888
            # (the original used COLOR_RGB2BGR — the identical swap).
            im0 = cv2.cvtColor(show, cv2.COLOR_BGR2RGB)
            showImage = QtGui.QImage(im0, im0.shape[1], im0.shape[0],
                                     3 * im0.shape[1], QtGui.QImage.Format_RGB888)
            self.label1.setPixmap(QtGui.QPixmap.fromImage(showImage))

    def button1_test(self):
        """Run YOLOX on the opened image and show the annotated result
        (inference time, warning and count overlays) on the right label."""
        if self.imgname1 != '0':
            image = cv2.imread(self.imgname1)
            # BUGFIX: the original unpacked the inference time into a local
            # named ``time``, shadowing the time module for this method.
            im0, nums, infer_time = self.yolo.demoimg(image)
            im0 = self._fit_label(im0, 700, 700)
            im0 = cv2.putText(im0, f"Infertime:{round(infer_time, 2)}s", (410, 80),
                              cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
            if nums >= 1:
                im0 = cv2.putText(im0, "Warning", (410, 20), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
                im0 = cv2.putText(im0, f"nums:{nums}", (410, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
            image_name = QtGui.QImage(im0, im0.shape[1], im0.shape[0],
                                      3 * im0.shape[1], QtGui.QImage.Format_RGB888)
            self.label2.setPixmap(QtGui.QPixmap.fromImage(image_name))
        else:
            QMessageBox.information(self, '错误', '请先选择一个图片文件', QMessageBox.Yes, QMessageBox.Yes)


if __name__ == '__main__':
    app = QApplication(sys.argv)
    # Splash screen shown while the two windows (and their YOLOX models)
    # are constructed.
    splash = QSplashScreen(QPixmap(os.path.join('data', 'source_image', 'logo.png')))
    splash.setFont(QFont('Microsoft YaHei UI', 12))
    splash.show()
    splash.showMessage("程序初始化中... 0%", QtCore.Qt.AlignLeft | QtCore.Qt.AlignBottom, QtCore.Qt.black)
    # BUGFIX: without draining pending paint events the splash text never
    # renders before the blocking sleep / model construction below.
    app.processEvents()
    time.sleep(0.3)

    splash.showMessage("正在加载模型配置文件...60%", QtCore.Qt.AlignLeft | QtCore.Qt.AlignBottom, QtCore.Qt.black)
    app.processEvents()
    # Both windows are intentionally module-level globals: the classes
    # reference ``cam_t`` / ``ui_p`` directly to switch between them.
    cam_t = Ui_MainWindow()
    splash.showMessage("正在加载模型配置文件...100%", QtCore.Qt.AlignLeft | QtCore.Qt.AlignBottom, QtCore.Qt.black)
    app.processEvents()

    ui_p = picture()
    ui_p.show()  # the single-image window is the entry screen
    splash.close()

    sys.exit(app.exec_())

系统名称和背景图片可以按自己的喜好自行更换。

然后将YOLOX demo.py文件移动至根目录下,并将下面内容复制过去:

import argparse
import os
import time
from loguru import logger

import cv2

import torch

from yolox.data.data_augment import ValTransform
from yolox.data.datasets import COCO_CLASSES
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess, vis

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]


def make_parser():
    """Build the command-line argument parser for the YOLOX demo.

    Covers the demo mode (image/video/webcam), input path, model/experiment
    selection, inference thresholds, and the fp16/legacy/fuse/trt switches.
    """
    p = argparse.ArgumentParser("YOLOX Demo!")

    # Demo mode and input source.
    p.add_argument(
        "--demo", default="image", help="demo type, eg. image, video and webcam"
    )
    p.add_argument("-expn", "--experiment-name", type=str, default=None)
    p.add_argument("-n", "--name", type=str, default=None, help="model name")
    p.add_argument(
        "--path", default="./assets/dog.jpg", help="path to images or video"
    )
    p.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    p.add_argument(
        "--save_result",
        action="store_true",
        help="whether to save the inference result of image/video",
    )

    # Experiment description file and checkpoint.
    p.add_argument(
        "-f",
        "--exp_file",
        default='exps/default/yolox_s.py',
        type=str,
        help="please input your experiment description file",
    )
    p.add_argument("-c", "--ckpt", default='yolox_s.pth', type=str, help="ckpt for eval")
    p.add_argument(
        "--device",
        default="cpu",
        type=str,
        help="device to run our model, can either be cpu or gpu",
    )

    # Inference thresholds and input size.
    p.add_argument("--conf", default=0.01, type=float, help="test conf")
    p.add_argument("--nms", default=0.45, type=float, help="test nms threshold")
    p.add_argument("--tsize", default=640, type=int, help="test img size")

    # Boolean feature switches share the same shape; declare them in one pass.
    for flag, dest, desc in (
        ("--fp16", "fp16", "Adopting mix precision evaluating."),
        ("--legacy", "legacy", "To be compatible with older versions"),
        ("--fuse", "fuse", "Fuse conv and bn for testing."),
        ("--trt", "trt", "Using TensorRT model for testing."),
    ):
        p.add_argument(flag, dest=dest, default=False, action="store_true", help=desc)

    return p


def get_image_list(path, exts=(".jpg", ".jpeg", ".webp", ".bmp", ".png")):
    """Recursively collect image file paths under *path*.

    Args:
        path: Root directory to walk.
        exts: Extensions to accept; defaults to the same set as the
            module-level IMAGE_EXT. Matched case-insensitively.

    Returns:
        List of full paths to matching image files.
    """
    # Normalise once so ".JPG", ".Png" etc. are no longer silently skipped
    # (the original compared the raw extension against a lower-case list).
    accepted = tuple(e.lower() for e in exts)
    image_names = []
    for maindir, _subdirs, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            if os.path.splitext(apath)[1].lower() in accepted:
                image_names.append(apath)
    return image_names


class Predictor(object):
    """Wraps a YOLOX model with the pre- and post-processing needed for
    single-image inference.

    Thresholds and input size are copied from the experiment config so the
    GUI can run detection without touching the exp object directly.
    """
    def __init__(
        self,
        model,
        exp,
        cls_names=COCO_CLASSES,
        trt_file=None,
        decoder=None,
        device="cpu",
        fp16=False,
        legacy=False,
    ):
        # model: eval-mode YOLOX module (replaced by a TRTModule below when
        # a TensorRT engine file is supplied).
        self.model = model
        self.cls_names = cls_names
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        self.fp16 = fp16
        self.preproc = ValTransform(legacy=legacy)
        if trt_file is not None:
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

            # Warm-up forward pass on the original model before swapping in
            # the TensorRT module; requires CUDA.
            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            self.model(x)
            self.model = model_trt

    def inference(self, img):
        """Run the model on one image array.

        Returns (outputs, img_info, elapsed_seconds): post-processed
        detections, a dict with the raw image / size / resize ratio used by
        ``visual``, and the time spent in the inference block.
        """
        # img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        img_info = {"id": 0}
        # if isinstance(img, str):
        #     img_info["file_name"] = os.path.basename(img)
        #     img = cv2.imread(img)
        # else:
        img_info["file_name"] = None
        # NOTE(review): input is converted RGB->BGR here, so callers are
        # assumed to pass RGB frames — verify for any new call site.
        img= cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        # Scale factor of the letterboxed resize; visual() divides the boxes
        # by this ratio to map them back onto the raw image.
        ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
        img_info["ratio"] = ratio

        img, _ = self.preproc(img, None, self.test_size)
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.float()
        if self.device == "gpu":
            img = img.cuda()
            if self.fp16:
                img = img.half()  # to FP16

        with torch.no_grad():
            t0 = time.time()
            outputs = self.model(img)
            # decoder is only set for TensorRT models (decode_in_inference off).
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre,
                self.nmsthre, class_agnostic=True
            )
            logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info,time.time() - t0

    def visual(self, output, img_info, cls_conf=0.35):
        """Draw detections onto the raw image.

        Returns (annotated_image, count) where count comes from the
        customised ``vis`` helper (number of boxes actually drawn).
        """
        ratio = img_info["ratio"]
        img = img_info["raw_img"]
        # No detections passed postprocess: return the raw image unchanged.
        if output is None:
            return img,0
        output = output.cpu()

        bboxes = output[:, 0:4]

        # preprocessing: resize
        bboxes /= ratio

        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]

        vis_res,k = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
        return vis_res,k


def image_demo(predictor, current_time, image):
    """Run detection on a single in-memory image.

    Args:
        predictor: object exposing ``inference(img)`` ->
            (outputs, img_info, infer_time) and
            ``visual(output, img_info, conf)`` -> (result_image, nums),
            plus a ``confthre`` attribute (e.g. Predictor above).
        current_time: unused; kept for backward compatibility with callers.
        image: image array to run detection on.

    Returns:
        (result_image, nums, infer_time) tuple.

    The original body was a single iteration left over from a removed
    file loop: dead commented-out code and stray 8-space indentation are
    dropped here without changing behavior.
    """
    outputs, img_info, ti = predictor.inference(image)
    result_image, nums = predictor.visual(outputs[0], img_info, predictor.confthre)
    return result_image, nums, ti


def imageflow_demo(predictor, vis_folder, current_time, args):
    """Run detection on a video file or webcam stream.

    Opens ``args.path`` when ``args.demo == "video"``, otherwise camera
    ``args.camid``; runs the predictor on every frame and either writes the
    annotated frames to a timestamped folder under *vis_folder*
    (``args.save_result``) or shows them in a window. Press Esc/q/Q to stop.

    Fixes vs. the original: the capture/writer setup had been commented out,
    so ``cap`` and ``vid_writer`` were undefined (NameError on first call),
    and the 3-tuple from ``inference`` / 2-tuple from ``visual`` were
    mis-unpacked.
    """
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    vid_writer = None
    if args.save_result:
        save_folder = os.path.join(
            vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
        )
        os.makedirs(save_folder, exist_ok=True)
        if args.demo == "video":
            save_path = os.path.join(save_folder, os.path.basename(args.path))
        else:
            save_path = os.path.join(save_folder, "camera.mp4")
        logger.info(f"video save_path is {save_path}")
        vid_writer = cv2.VideoWriter(
            save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
        )
    try:
        while True:
            ret_val, frame = cap.read()
            if not ret_val:
                break
            # inference returns (outputs, img_info, infer_time) in this file;
            # visual returns (annotated_frame, num_detections).
            outputs, img_info, _ = predictor.inference(frame)
            result_frame, _ = predictor.visual(outputs[0], img_info, predictor.confthre)
            if args.save_result:
                vid_writer.write(result_frame)
            else:
                cv2.namedWindow("yolox", cv2.WINDOW_NORMAL)
                cv2.imshow("yolox", result_frame)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
    finally:
        # Release the camera/file handle and flush any pending video output.
        cap.release()
        if vid_writer is not None:
            vid_writer.release()


class main(object):
    """Loads a YOLOX predictor once at start-up for the GUI.

    Parses command-line arguments, builds the model from the experiment
    file, loads the checkpoint (or TensorRT engine), and exposes
    per-image and video detection entry points.
    """

    def __init__(self):
        args = make_parser().parse_args()
        exp = get_exp(args.exp_file, args.name)
        if not args.experiment_name:
            args.experiment_name = exp.exp_name

        # Output directory for this experiment (checkpoints, TRT engine).
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        os.makedirs(file_name, exist_ok=True)

        if args.trt:
            # TensorRT engines only run on GPU.
            args.device = "gpu"

        logger.info("Args: {}".format(args))

        # Command-line overrides of the experiment defaults.
        if args.conf is not None:
            exp.test_conf = args.conf
        if args.nms is not None:
            exp.nmsthre = args.nms
        if args.tsize is not None:
            exp.test_size = (args.tsize, args.tsize)

        model = exp.get_model()
        logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

        if args.device == "gpu":
            model.cuda()
            if args.fp16:
                model.half()  # to FP16
        model.eval()

        if not args.trt:
            if args.ckpt is None:
                ckpt_file = os.path.join(file_name, "best_ckpt.pth")
            else:
                ckpt_file = args.ckpt
            logger.info("loading checkpoint")
            ckpt = torch.load(ckpt_file, map_location="cpu")
            # load the model state dict
            model.load_state_dict(ckpt["model"])
            logger.info("loaded checkpoint done.")

        if args.fuse:
            logger.info("\tFusing model...")
            model = fuse_model(model)

        if args.trt:
            assert not args.fuse, "TensorRT model is not support model fusing!"
            trt_file = os.path.join(file_name, "model_trt.pth")
            assert os.path.exists(
                trt_file
            ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
            # Let the decoder run outside the model so TRT sees raw outputs.
            model.head.decode_in_inference = False
            decoder = model.head.decode_outputs
            logger.info("Using TensorRT to inference")
        else:
            trt_file = None
            decoder = None

        # Keep args so demovido can forward them to imageflow_demo.
        self.args = args
        self.predictor = Predictor(
            model, exp, COCO_CLASSES, trt_file, decoder,
            args.device, args.fp16, args.legacy,
        )

    def demoimg(self, img):
        """Detect on a single image; returns (result_image, nums, infer_time)."""
        current_time = time.localtime()
        return image_demo(self.predictor, current_time, img)

    def demovido(self, img):
        """Run the video/webcam demo loop; *img* is the output folder
        (vis_folder) used when ``--save_result`` is set.

        Fix: the original referenced undefined globals ``predictor`` and
        ``current_time`` and omitted the ``args`` argument, so it raised
        NameError/TypeError on every call.
        """
        imageflow_demo(self.predictor, img, time.localtime(), self.args)


if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)

    main(exp, args)

需要的参数直接在上面修改即可。由于这里没有单独训练猫狗数据集,而是直接使用 YOLOX-s 的预训练权重文件,因此需要把检测索引限定为猫和狗对应的分类索引,即修改 visualize.py 文件:

def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None):
    """Draw detection boxes for cats and dogs onto *img*.

    Only COCO class ids 15 (cat) and 16 (dog) are drawn, and only when the
    score reaches *conf*. Returns the annotated image together with the
    number of boxes drawn, so callers can show a detection count.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    k = 0  # number of boxes actually drawn

    for box, raw_cls, score in zip(boxes, cls_ids, scores):
        cls_id = int(raw_cls)
        # Skip everything except the cat/dog classes this system cares about.
        if cls_id not in (15, 16):
            continue
        if score < conf:
            continue
        k += 1

        x0, y0 = int(box[0]), int(box[1])
        x1, y1 = int(box[2]), int(box[3])

        color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist()
        text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100)
        # Choose a label colour that stays readable against the box colour.
        txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255)

        txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
        cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)

        # Filled background behind the label text so it stays legible.
        txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist()
        cv2.rectangle(
            img,
            (x0, y0 + 1),
            (x0 + txt_size[0] + 1, y0 + int(1.5 * txt_size[1])),
            txt_bk_color,
            -1
        )
        cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1)

    return img, k

因为希望统计检测到的目标数量,所以这里增加了计数变量 k,并将其随结果一并 return。

最后运行一下之前的pyqt5文件就可以进行检测了:

示例如下:

 利用PYQT5结合YOLOX搭建检测系统_第1张图片

 

你可能感兴趣的:(python,开发语言)