Today I'd like to share how to build a detection system with PyQt5 desktop widgets, tentatively a cat-and-dog detector for public places, with YOLOX as the detection algorithm. A YOLOv8 + PyQt5 tutorial will follow later.
The system supports three modes: single-image detection, real-time webcam detection, and video-file detection.
First, create a .py file and copy in the code below:
from PIL import Image
import numpy as np
import time
import os
from PyQt5 import QtWidgets, QtCore, QtGui
from PyQt5.QtGui import *
import cv2
import sys
from PyQt5.QtWidgets import *
# from detect_qt5 import main_detect,my_lodelmodel
from demo import main
# Shared stylesheet for every button (factored out of the seven identical
# inline blocks; note that Qt color names contain no spaces: "lightgray")
BUTTON_STYLE = '''
QPushButton
{text-align : center;
background-color : white;
font: bold;
border-color: gray;
border-width: 2px;
border-radius: 10px;
padding: 6px;
height : 14px;
border-style: outset;
font : 14px;}
QPushButton:pressed
{text-align : center;
background-color : lightgray;
font: bold;
border-color: gray;
border-width: 2px;
border-radius: 10px;
padding: 6px;
height : 14px;
border-style: outset;
font : 14px;}
'''

'''Camera and video real-time detection window'''
class Ui_MainWindow(QWidget):
def __init__(self, parent=None):
super(Ui_MainWindow, self).__init__(parent)
# self.face_recong = face.Recognition()
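        # One QTimer per display mode; see the timeout connections further down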
self.timer_camera1 = QtCore.QTimer()
self.timer_camera2 = QtCore.QTimer()
self.timer_camera3 = QtCore.QTimer()
self.timer_camera4 = QtCore.QTimer()
self.cap = cv2.VideoCapture()
self.CAM_NUM = 0
# self.slot_init()
self.__flag_work = 0
self.x = 0
self.count = 0
self.setWindowTitle("公共场合猫狗检测系统")
self.setWindowIcon(QIcon(os.getcwd() + '\\data\\source_image\\Detective.ico'))
self.setFixedSize(1600, 900)
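        # main() from demo.py builds the YOLOX predictor (it parses its own CLI defaults)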
self.yolo=main()
# self.my_model = my_lodelmodel()
self.button_open_camera = QPushButton(self)
self.button_open_camera.setText(u'打开摄像头')
        self.button_open_camera.setStyleSheet(BUTTON_STYLE)
self.button_open_camera.move(10, 40)
self.button_open_camera.clicked.connect(self.button_open_camera_click)
# self.button_open_camera.clicked.connect(self.button_open_camera_click1)
# btn.clicked.connect(self.openimage)
self.btn1 = QPushButton(self)
self.btn1.setText("检测摄像头")
        self.btn1.setStyleSheet(BUTTON_STYLE)
self.btn1.move(10, 80)
self.btn1.clicked.connect(self.button_open_camera_click1)
# print("QPushButton构建")
self.open_video = QPushButton(self)
self.open_video.setText("打开视频")
        self.open_video.setStyleSheet(BUTTON_STYLE)
self.open_video.move(10, 160)
self.open_video.clicked.connect(self.open_video_button)
print("QPushButton构建")
        self.btn_detect_video = QPushButton(self)  # distinct attribute so self.btn1 above is not overwritten
        self.btn_detect_video.setText("检测视频文件")
        self.btn_detect_video.setStyleSheet(BUTTON_STYLE)
        self.btn_detect_video.move(10, 200)
        self.btn_detect_video.clicked.connect(self.detect_video)
print("QPushButton构建")
# btn1.clicked.connect(self.detect())
# btn1.clicked.connect(self.button1_test)
btn2 = QPushButton(self)
btn2.setText("返回上一界面")
        btn2.setStyleSheet(BUTTON_STYLE)
btn2.move(10, 240)
btn2.clicked.connect(self.back_lastui)
        # Display areas
self.label_show_camera = QLabel(self)
self.label_move = QLabel()
self.label_move.setFixedSize(100, 100)
# self.label_move.setText(" 11 待检测图片")
self.label_show_camera.setFixedSize(700, 500)
self.label_show_camera.setAutoFillBackground(True)
self.label_show_camera.move(110, 80)
        self.label_show_camera.setStyleSheet("QLabel{background:#F5F5DC;}"
                                             "QLabel{color:rgba(255,255,255,120);font-size:10px;font-weight:bold;font-family:宋体;}"
                                             )
self.label_show_camera1 = QLabel(self)
self.label_show_camera1.setFixedSize(700, 500)
self.label_show_camera1.setAutoFillBackground(True)
self.label_show_camera1.move(850, 80)
        self.label_show_camera1.setStyleSheet("QLabel{background:#F5F5DC;}"
                                              "QLabel{color:rgba(255,255,255,120);font-size:10px;font-weight:bold;font-family:宋体;}"
                                              )
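        # Timer wiring: timer1 -> raw camera (left), timer2 -> camera detection (right),
        # timer4 -> video playback (left) and video detection (right); timer3 is unused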
self.timer_camera1.timeout.connect(self.show_camera)
self.timer_camera2.timeout.connect(self.show_camera1)
# self.timer_camera3.timeout.connect(self.show_camera2)
self.timer_camera4.timeout.connect(self.show_camera2)
self.timer_camera4.timeout.connect(self.show_camera3)
self.clicked = False
# self.setWindowTitle(u'摄像头')
self.frame_s = 3
        # Background image
palette1 = QPalette()
palette1.setBrush(self.backgroundRole(), QBrush(QPixmap('R-C.png')))
self.setPalette(palette1)
def back_lastui(self):
self.timer_camera1.stop()
self.cap.release()
self.label_show_camera.clear()
self.timer_camera2.stop()
self.label_show_camera1.clear()
cam_t.close()
ui_p.show()
    '''Camera'''
def button_open_camera_click(self):
if self.timer_camera1.isActive() == False:
flag = self.cap.open(self.CAM_NUM)
if flag == False:
msg = QtWidgets.QMessageBox.warning(self, u"Warning", u"请检测相机与电脑是否连接正确",
buttons=QtWidgets.QMessageBox.Ok,
defaultButton=QtWidgets.QMessageBox.Ok)
else:
self.timer_camera1.start(30)
self.button_open_camera.setText(u'关闭摄像头')
else:
self.timer_camera1.stop()
self.cap.release()
self.label_show_camera.clear()
self.timer_camera2.stop()
self.label_show_camera1.clear()
self.button_open_camera.setText(u'打开摄像头')
    def show_camera(self):  # camera preview, left panel
flag, self.image = self.cap.read()
dir_path = os.getcwd()
camera_source = dir_path + "\\data\\test\\2.jpg"
cv2.imwrite(camera_source, self.image)
width = self.image.shape[1]
height = self.image.shape[0]
        # target display size
        width_new = 700
        height_new = 500
        # keep the aspect ratio when resizing
if width / height >= width_new / height_new:
show = cv2.resize(self.image, (width_new, int(height * width_new / width)))
else:
show = cv2.resize(self.image, (int(width * height_new / height), height_new))
show = cv2.cvtColor(show, cv2.COLOR_BGR2RGB)
showImage = QtGui.QImage(show.data, show.shape[1], show.shape[0], 3 * show.shape[1], QtGui.QImage.Format_RGB888)
self.label_show_camera.setPixmap(QtGui.QPixmap.fromImage(showImage))
def button_open_camera_click1(self):
if self.timer_camera2.isActive() == False:
flag = self.cap.open(self.CAM_NUM)
if flag == False:
msg = QtWidgets.QMessageBox.warning(self, u"Warning", u"请检测相机与电脑是否连接正确",
buttons=QtWidgets.QMessageBox.Ok,
defaultButton=QtWidgets.QMessageBox.Ok)
else:
self.timer_camera2.start(30)
self.button_open_camera.setText(u'关闭摄像头')
else:
self.timer_camera2.stop()
self.cap.release()
self.label_show_camera1.clear()
self.button_open_camera.setText(u'打开摄像头')
def show_camera1(self):
fps = 0.0
t1 = time.time()
flag, self.image = self.cap.read()
self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
# self.image = Image.fromarray(np.uint8(self.image))
im0, nums, ti = self.yolo.demoimg(self.image)
im0= cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)
width = im0.shape[1]
height = im0.shape[0]
        # target display size
        width_new = 640
        height_new = 640
        # keep the aspect ratio when resizing
        if width / height >= width_new / height_new:
            show = cv2.resize(im0, (width_new, int(height * width_new / width)))
        else:
            show = cv2.resize(im0, (int(width * height_new / height), height_new))
        im0 = show  # display the resized frame (the original code dropped this resize)
        if nums >= 1:
            fps = (fps + (1. / (time.time() - t1))) / 2
            im0 = cv2.putText(im0, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            im0 = cv2.putText(im0, "No pets allowed", (0, 150), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
showImage = QtGui.QImage(im0, im0.shape[1], im0.shape[0], 3 * im0.shape[1], QtGui.QImage.Format_RGB888)
self.label_show_camera1.setPixmap(QtGui.QPixmap.fromImage(showImage))
    '''Video detection'''
def open_video_button(self):
if self.timer_camera4.isActive() == False:
imgName, imgType = QFileDialog.getOpenFileName(self, "打开视频", "", "*.mp4;;*.AVI;;*.rmvb;;All Files(*)")
self.cap_video = cv2.VideoCapture(imgName)
flag = self.cap_video.isOpened()
if flag == False:
                msg = QtWidgets.QMessageBox.warning(self, u"Warning", u"请检查视频文件是否有效",
buttons=QtWidgets.QMessageBox.Ok,
defaultButton=QtWidgets.QMessageBox.Ok)
else:
# self.timer_camera3.start(10)
self.show_camera2()
self.open_video.setText(u'关闭视频')
else:
# self.timer_camera3.stop()
self.cap_video.release()
self.label_show_camera.clear()
self.timer_camera4.stop()
self.frame_s = 3
self.label_show_camera1.clear()
self.open_video.setText(u'打开视频')
    def detect_video(self):
        if self.timer_camera4.isActive() == False:
            # "打开视频" must be clicked first, otherwise cap_video does not exist yet
            flag = hasattr(self, 'cap_video') and self.cap_video.isOpened()
            if flag == False:
                msg = QtWidgets.QMessageBox.warning(self, u"Warning", u"请先打开视频文件",
buttons=QtWidgets.QMessageBox.Ok,
defaultButton=QtWidgets.QMessageBox.Ok)
else:
self.timer_camera4.start(30)
else:
self.timer_camera4.stop()
self.cap_video.release()
self.label_show_camera1.clear()
    def show_camera2(self):  # raw video frames, left panel
        length = int(self.cap_video.get(cv2.CAP_PROP_FRAME_COUNT))  # total number of frames
        print(self.frame_s, length)
        flag, self.image1 = self.cap_video.read()  # image1 holds the current video frame
if flag == True:
width = self.image1.shape[1]
height = self.image1.shape[0]
            # target display size
            width_new = 700
            height_new = 500
            # keep the aspect ratio when resizing
if width / height >= width_new / height_new:
show = cv2.resize(self.image1, (width_new, int(height * width_new / width)))
else:
show = cv2.resize(self.image1, (int(width * height_new / height), height_new))
show = cv2.cvtColor(show, cv2.COLOR_BGR2RGB)
showImage = QtGui.QImage(show.data, show.shape[1], show.shape[0], 3 * show.shape[1],
QtGui.QImage.Format_RGB888)
self.label_show_camera.setPixmap(QtGui.QPixmap.fromImage(showImage))
else:
self.cap_video.release()
self.label_show_camera.clear()
self.timer_camera4.stop()
self.label_show_camera1.clear()
self.open_video.setText(u'打开视频')
def show_camera3(self):
flag, self.image1 = self.cap_video.read()
self.frame_s += 1
if flag == True:
            # if self.frame_s % 3 == 0:  # frame skipping
# face = self.face_detect.align(self.image)
# if face:
# pass
# dir_path = os.getcwd()
# camera_source = dir_path + "\\data\\test\\video.jpg"
#
# cv2.imwrite(camera_source, self.image1)
# print("im01")
# im0, label = main_detect(self.my_model, camera_source)
im0,nums,ti = self.yolo.demoimg(self.image1)
# print("imo",im0)
# print(label)
# if label == 'debug':
# print("labelkong")
# print("debug")
# im0, label = slef.detect()
# print("debug1")
width = im0.shape[1]
height = im0.shape[0]
            # target display size
            width_new = 700
            height_new = 500
            # keep the aspect ratio when resizing
            if width / height >= width_new / height_new:
                show = cv2.resize(im0, (width_new, int(height * width_new / width)))
            else:
                show = cv2.resize(im0, (int(width * height_new / height), height_new))
            im0 = show  # cv2.cvtColor(show, cv2.COLOR_RGB2BGR)
# print("debug2")
if nums >= 1:
im0 = cv2.putText(im0, "Warning", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
im0 = cv2.putText(im0, f"nums:{nums}", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
showImage = QtGui.QImage(im0, im0.shape[1], im0.shape[0], 3 * im0.shape[1], QtGui.QImage.Format_RGB888)
self.label_show_camera1.setPixmap(QtGui.QPixmap.fromImage(showImage))
'''Single-image detection'''
class picture(QWidget):
def __init__(self):
super(picture, self).__init__()
self.str_name = '0'
self.yolo = main()
# self.my_model=my_lodelmodel()
self.resize(1600, 900)
self.setWindowIcon(QIcon(os.getcwd() + '\\data\\source_image\\Detective.ico'))
self.setWindowTitle("公共场合猫狗检测系统")
# window_pale = QtGui.QPalette()
# window_pale.setBrush(self.backgroundRole(), QtGui.QBrush(
# QtGui.QPixmap(os.getcwd() + '\\data\\source_image\\backgroud.jpg')))
# self.setPalette(window_pale)
palette2 = QPalette()
palette2.setBrush(self.backgroundRole(), QBrush(QPixmap('4.jpg')))
self.setPalette(palette2)
camera_or_video_save_path = 'data\\test'
if not os.path.exists(camera_or_video_save_path):
os.makedirs(camera_or_video_save_path)
self.label1 = QLabel(self)
self.label1.setText(" 待检测图片")
self.label1.setFixedSize(700, 500)
self.label1.move(110, 80)
        self.label1.setStyleSheet("QLabel{background:#7A6969;}"
                                  "QLabel{color:rgba(255,255,255,120);font-size:20px;font-weight:bold;font-family:宋体;}"
                                  )
self.label2 = QLabel(self)
self.label2.setText("检测结果")
self.label2.setFixedSize(700, 500)
self.label2.move(850, 80)
        self.label2.setStyleSheet("QLabel{background:#7A6969;}"
                                  "QLabel{color:rgba(255,255,255,120);font-size:20px;font-weight:bold;font-family:宋体;}"
                                  )
self.label3 = QLabel(self)
self.label3.setText("")
self.label3.move(1200, 620)
self.label3.setStyleSheet("font-size:20px;")
self.label3.adjustSize()
btn = QPushButton(self)
btn.setText("打开图片")
        btn.setStyleSheet(BUTTON_STYLE)
btn.move(10, 30)
btn.clicked.connect(self.openimage)
btn1 = QPushButton(self)
btn1.setText("检测图片")
        btn1.setStyleSheet(BUTTON_STYLE)
btn1.move(10, 80)
# print("QPushButton构建")
btn1.clicked.connect(self.button1_test)
btn3 = QPushButton(self)
btn3.setText("视频和摄像头检测")
        btn3.setStyleSheet(BUTTON_STYLE)
btn3.move(10, 160)
btn3.clicked.connect(self.camera_find)
self.imgname1 = '0'
def camera_find(self):
ui_p.close()
cam_t.show()
def openimage(self):
imgName, imgType = QFileDialog.getOpenFileName(self, "打开图片", "D://",
"Image files (*.jpg *.gif *.png *.jpeg)") # "*.jpg;;*.png;;All Files(*)"
if imgName != '':
self.imgname1 = imgName
# print("imgName",imgName,type(imgName))
self.im0 = cv2.imread(imgName)
width = self.im0.shape[1]
height = self.im0.shape[0]
            # target display size
            width_new = 700
            height_new = 500
            # keep the aspect ratio when resizing
if width / height >= width_new / height_new:
show = cv2.resize(self.im0, (width_new, int(height * width_new / width)))
else:
show = cv2.resize(self.im0, (int(width * height_new / height), height_new))
im0 = cv2.cvtColor(show, cv2.COLOR_RGB2BGR)
showImage = QtGui.QImage(im0, im0.shape[1], im0.shape[0], 3 * im0.shape[1], QtGui.QImage.Format_RGB888)
self.label1.setPixmap(QtGui.QPixmap.fromImage(showImage))
# jpg = QtGui.QPixmap(imgName).scaled(self.label1.width(), self.label1.height())
# self.label1.setPixmap(jpg)
def button1_test(self):
if self.imgname1 != '0':
# QApplication.processEvents()
# image = Image.open(self.imgname1)
image = cv2.imread(self.imgname1)
# K, im0 = self.yolo.detect_image(image)
            im0, nums, ti = self.yolo.demoimg(image)  # ti = inference time; avoids shadowing the time module
print(nums)
# im0 = np.array(im0)
# QApplication.processEvents()
width = im0.shape[1]
height = im0.shape[0]
            # target display size
            width_new = 700
            height_new = 700
            # keep the aspect ratio when resizing
if width / height >= width_new / height_new:
im0 = cv2.resize(im0, (width_new, int(height * width_new / width)))
else:
im0 = cv2.resize(im0, (int(width * height_new / height), height_new))
im0 = cv2.putText(im0, f"Infertime:{round(time,2)}s", (410, 80), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
# im0 = cv2.cvtColor(show, cv2.COLOR_RGB2BGR)
if nums >= 1:
im0 = cv2.putText(im0, "Warning", (410, 20), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
im0 = cv2.putText(im0, f"nums:{nums}", (410, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
image_name = QtGui.QImage(im0, im0.shape[1], im0.shape[0], 3 * im0.shape[1], QtGui.QImage.Format_RGB888)
# label=label.split(' ')[0] #label 59 0.96 分割字符串 取前一个
self.label2.setPixmap(QtGui.QPixmap.fromImage(image_name))
# jpg = QtGui.QPixmap(image_name).scaled(self.label1.width(), self.label1.height())
# self.label2.setPixmap(jpg)
else:
QMessageBox.information(self, '错误', '请先选择一个图片文件', QMessageBox.Yes, QMessageBox.Yes)
if __name__ == '__main__':
app = QApplication(sys.argv)
splash = QSplashScreen(QPixmap(".\\data\\source_image\\logo.png"))
    # Font for the splash-screen text
    splash.setFont(QFont('Microsoft YaHei UI', 12))
    # Show the splash screen
    splash.show()
    # Progress messages
splash.showMessage("程序初始化中... 0%", QtCore.Qt.AlignLeft | QtCore.Qt.AlignBottom, QtCore.Qt.black)
time.sleep(0.3)
splash.showMessage("正在加载模型配置文件...60%", QtCore.Qt.AlignLeft | QtCore.Qt.AlignBottom, QtCore.Qt.black)
cam_t = Ui_MainWindow()
splash.showMessage("正在加载模型配置文件...100%", QtCore.Qt.AlignLeft | QtCore.Qt.AlignBottom, QtCore.Qt.black)
ui_p = picture()
ui_p.show()
splash.close()
sys.exit(app.exec_())
Rename the system and swap in your own background image as you like.
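Both windows set these near the top of their __init__ methods; a minimal sketch of the lines to edit (my_background.jpg is a placeholder file name):

self.setWindowTitle("公共场合猫狗检测系统")  # put your own system name here
palette = QPalette()
palette.setBrush(self.backgroundRole(), QBrush(QPixmap('my_background.jpg')))  # your background image
self.setPalette(palette)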
Next, move YOLOX's demo.py to the project root and replace its contents with the following:
import argparse
import os
import time
from loguru import logger
import cv2
import torch
from yolox.data.data_augment import ValTransform
from yolox.data.datasets import COCO_CLASSES
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess, vis
IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
def make_parser():
parser = argparse.ArgumentParser("YOLOX Demo!")
parser.add_argument(
"--demo", default="image", help="demo type, eg. image, video and webcam"
)
parser.add_argument("-expn", "--experiment-name", type=str, default=None)
parser.add_argument("-n", "--name", type=str, default=None, help="model name")
parser.add_argument(
"--path", default="./assets/dog.jpg", help="path to images or video"
)
parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
parser.add_argument(
"--save_result",
action="store_true",
help="whether to save the inference result of image/video",
)
# exp file
parser.add_argument(
"-f",
"--exp_file",
default='exps/default/yolox_s.py',
type=str,
help="please input your experiment description file",
)
parser.add_argument("-c", "--ckpt", default='yolox_s.pth', type=str, help="ckpt for eval")
parser.add_argument(
"--device",
default="cpu",
type=str,
help="device to run our model, can either be cpu or gpu",
)
parser.add_argument("--conf", default=0.01, type=float, help="test conf")
parser.add_argument("--nms", default=0.45, type=float, help="test nms threshold")
parser.add_argument("--tsize", default=640, type=int, help="test img size")
parser.add_argument(
"--fp16",
dest="fp16",
default=False,
action="store_true",
help="Adopting mix precision evaluating.",
)
parser.add_argument(
"--legacy",
dest="legacy",
default=False,
action="store_true",
help="To be compatible with older versions",
)
parser.add_argument(
"--fuse",
dest="fuse",
default=False,
action="store_true",
help="Fuse conv and bn for testing.",
)
parser.add_argument(
"--trt",
dest="trt",
default=False,
action="store_true",
help="Using TensorRT model for testing.",
)
return parser
def get_image_list(path):
image_names = []
for maindir, subdir, file_name_list in os.walk(path):
for filename in file_name_list:
apath = os.path.join(maindir, filename)
ext = os.path.splitext(apath)[1]
if ext in IMAGE_EXT:
image_names.append(apath)
return image_names
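# Predictor wraps the YOLOX model: ValTransform preprocessing -> forward pass -> NMS postprocessing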
class Predictor(object):
def __init__(
self,
model,
exp,
cls_names=COCO_CLASSES,
trt_file=None,
decoder=None,
device="cpu",
fp16=False,
legacy=False,
):
self.model = model
self.cls_names = cls_names
self.decoder = decoder
self.num_classes = exp.num_classes
self.confthre = exp.test_conf
self.nmsthre = exp.nmsthre
self.test_size = exp.test_size
self.device = device
self.fp16 = fp16
self.preproc = ValTransform(legacy=legacy)
if trt_file is not None:
from torch2trt import TRTModule
model_trt = TRTModule()
model_trt.load_state_dict(torch.load(trt_file))
x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
self.model(x)
self.model = model_trt
def inference(self, img):
# img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
img_info = {"id": 0}
# if isinstance(img, str):
# img_info["file_name"] = os.path.basename(img)
# img = cv2.imread(img)
# else:
img_info["file_name"] = None
img= cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
height, width = img.shape[:2]
img_info["height"] = height
img_info["width"] = width
img_info["raw_img"] = img
ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
img_info["ratio"] = ratio
img, _ = self.preproc(img, None, self.test_size)
img = torch.from_numpy(img).unsqueeze(0)
img = img.float()
if self.device == "gpu":
img = img.cuda()
if self.fp16:
img = img.half() # to FP16
with torch.no_grad():
t0 = time.time()
outputs = self.model(img)
if self.decoder is not None:
outputs = self.decoder(outputs, dtype=outputs.type())
outputs = postprocess(
outputs, self.num_classes, self.confthre,
self.nmsthre, class_agnostic=True
)
logger.info("Infer time: {:.4f}s".format(time.time() - t0))
return outputs, img_info,time.time() - t0
def visual(self, output, img_info, cls_conf=0.35):
ratio = img_info["ratio"]
img = img_info["raw_img"]
if output is None:
return img,0
output = output.cpu()
bboxes = output[:, 0:4]
# preprocessing: resize
bboxes /= ratio
cls = output[:, 6]
scores = output[:, 4] * output[:, 5]
vis_res,k = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
return vis_res,k
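# image_demo: one frame in -> (annotated image, detection count, inference time) out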
def image_demo(predictor,current_time,image):
# if os.path.isdir(path):
# files = get_image_list(path)
# else:
# files = [path]
# files.sort()
# for image_name in files:
outputs, img_info,ti = predictor.inference(image)
result_image,nums = predictor.visual(outputs[0], img_info, predictor.confthre)
return result_image,nums,ti
# if save_result:
# save_folder = os.path.join(
# vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
# )
# os.makedirs(save_folder, exist_ok=True)
# save_file_name = os.path.join(save_folder, os.path.basename(image_name))
# logger.info("Saving detection result in {}".format(save_file_name))
# cv2.imwrite(save_file_name, result_image)
# ch = cv2.waitKey(0)
# if ch == 27 or ch == ord("q") or ch == ord("Q"):
# break
def imageflow_demo(predictor, vis_folder, current_time, args):
    # Restored from the stock YOLOX demo so this function is runnable standalone
    # (the Qt windows above do their own capture and drawing instead).
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    if args.save_result:
        save_folder = os.path.join(
            vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
        )
        os.makedirs(save_folder, exist_ok=True)
        if args.demo == "video":
            save_path = os.path.join(save_folder, os.path.basename(args.path))
        else:
            save_path = os.path.join(save_folder, "camera.mp4")
        logger.info(f"video save_path is {save_path}")
        vid_writer = cv2.VideoWriter(
            save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
        )
    while True:
        ret_val, frame = cap.read()
        if ret_val:
            # inference()/visual() in this file return extra values (infer time, box count)
            outputs, img_info, ti = predictor.inference(frame)
            result_frame, nums = predictor.visual(outputs[0], img_info, predictor.confthre)
            if args.save_result:
                vid_writer.write(result_frame)
            else:
                cv2.namedWindow("yolox", cv2.WINDOW_NORMAL)
                cv2.imshow("yolox", result_frame)
                ch = cv2.waitKey(1)
                if ch == 27 or ch == ord("q") or ch == ord("Q"):
                    break
        else:
            break
class main(object):
def __init__(self):
args = make_parser().parse_args()
exp = get_exp(args.exp_file, args.name)
if not args.experiment_name:
args.experiment_name = exp.exp_name
file_name = os.path.join(exp.output_dir, args.experiment_name)
os.makedirs(file_name, exist_ok=True)
# vis_folder = None
# if args.save_result:
# vis_folder = os.path.join(file_name, "vis_res")
# os.makedirs(vis_folder, exist_ok=True)
if args.trt:
args.device = "gpu"
logger.info("Args: {}".format(args))
if args.conf is not None:
exp.test_conf = args.conf
if args.nms is not None:
exp.nmsthre = args.nms
if args.tsize is not None:
exp.test_size = (args.tsize, args.tsize)
model = exp.get_model()
logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
if args.device == "gpu":
model.cuda()
if args.fp16:
model.half() # to FP16
model.eval()
if not args.trt:
if args.ckpt is None:
ckpt_file = os.path.join(file_name, "best_ckpt.pth")
else:
ckpt_file = args.ckpt
logger.info("loading checkpoint")
ckpt = torch.load(ckpt_file, map_location="cpu")
# load the model state dict
model.load_state_dict(ckpt["model"])
logger.info("loaded checkpoint done.")
if args.fuse:
logger.info("\tFusing model...")
model = fuse_model(model)
if args.trt:
assert not args.fuse, "TensorRT model is not support model fusing!"
trt_file = os.path.join(file_name, "model_trt.pth")
assert os.path.exists(
trt_file
), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
model.head.decode_in_inference = False
decoder = model.head.decode_outputs
logger.info("Using TensorRT to inference")
else:
trt_file = None
decoder = None
self.predictor = Predictor(
model, exp, COCO_CLASSES, trt_file, decoder,
args.device, args.fp16, args.legacy,
)
    def demoimg(self, img):
        # returns (annotated image, detection count, inference time)
        current_time = time.localtime()
        im = image_demo(self.predictor, current_time, img)
        return im
    def demovido(self, path):
        # not called by the Qt windows; wired up so it can at least run standalone
        args = make_parser().parse_args()
        args.demo, args.path = "video", path
        imageflow_demo(self.predictor, None, time.localtime(), args)
if __name__ == "__main__":
args = make_parser().parse_args()
exp = get_exp(args.exp_file, args.name)
main(exp, args)
Set any parameters you need directly in the code above. Since I did not train a dedicated cat-and-dog dataset here, I use the stock YOLOX-s weights as-is and simply restrict drawing to the cat and dog class indices. To do that, modify the vis function in visualize.py:
def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None):
k = 0
for i in range(len(boxes)):
box = boxes[i]
cls_id = int(cls_ids[i])
        if cls_id in (15, 16):  # COCO class indices: 15 = cat, 16 = dog
score = scores[i]
if score < conf:
continue
k+=1
x0 = int(box[0])
y0 = int(box[1])
x1 = int(box[2])
y1 = int(box[3])
color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist()
text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100)
txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255)
font = cv2.FONT_HERSHEY_SIMPLEX
txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)
txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist()
cv2.rectangle(
img,
(x0, y0 + 1),
(x0 + txt_size[0] + 1, y0 + int(1.5*txt_size[1])),
txt_bk_color,
-1
)
cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1)
return img,k
Because we also want the number of animals detected, a counter k is incremented for each kept box and returned together with the image.
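To double-check that indices 15 and 16 really are cat and dog, you can print them straight from YOLOX's class list (a quick sanity check run from the project root):

from yolox.data.datasets import COCO_CLASSES

# COCO_CLASSES is the 80-name tuple used by vis(); this should print: cat dog
print(COCO_CLASSES[15], COCO_CLASSES[16])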
Finally, run the PyQt5 file from earlier and you can start detecting.
Example: