YOLOv5:修改 detect 模块以方便调用(单类别目标)

        前言:本人最近在学习使用 YOLOv5。为了方便在自己的代码中调用,我阅读了 detect.py 的源码并做了一些修改。不过由于训练的模型只有一个类别,我还没弄清推理结果中哪个数值代表“类别”,所以以下内容只针对单类别模型,等琢磨明白后会再更新。希望有大佬能指导一下。

在 detect.py 中添加如下代码块(下方调用示例通过 `import detect_remake` 导入,即保存修改后代码的文件名,请按实际文件名调整):

class yolo_detector:
    """Thin wrapper around a YOLOv5 ``DetectMultiBackend`` model for per-frame inference.

    Load the weights once in the constructor, then call :meth:`run` on each
    BGR frame (e.g. from ``cv2.VideoCapture.read``) to get bounding boxes in
    that frame's pixel coordinates.
    """

    def __init__(self,
                 weights='./Weights/last.pt',    # .pt file produced by train.py
                 imgsz=(640, 640),
                 conf_thres=0.25,
                 iou_thres=0.45,
                 half=False,
                 device='0',
                 max_det=5,
                 ):
        """Load the model and warm it up.

        Args:
            weights: path to the trained weights (e.g. the ``last.pt`` written by train.py).
            imgsz: inference size as (h, w) or a single int; rounded to a multiple
                of the model stride by ``check_img_size``.
            conf_thres: confidence threshold passed to NMS.
            iou_thres: IoU threshold passed to NMS.
            half: request FP16 inference; only honored for PyTorch weights on a
                non-CPU device.
            device: device string for ``select_device`` (e.g. '0' for the first
                CUDA GPU, 'cpu' for CPU). Defaults to '0', the previously
                hard-coded value.
            max_det: maximum number of boxes NMS keeps per frame.
        """
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.max_det = max_det

        self.device = select_device(device)
        self.model = DetectMultiBackend(weights, device=self.device)  # load model
        self.stride, self.names, self.pt = self.model.stride, self.model.names, self.model.pt
        if isinstance(imgsz, int):
            imgsz = (imgsz, imgsz)  # warmup below unpacks imgsz, so keep it a sequence
        self.imgsz = check_img_size(imgsz, s=self.stride)  # check image size
        half &= self.pt and self.device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
        self.half = half
        if self.pt:
            if half:
                self.model.model.half()
            else:
                self.model.model.float()
        self.view_img = check_imshow()
        cudnn.benchmark = True  # set True to speed up constant image size inference
        self.model.warmup(imgsz=(1, 3, *self.imgsz), half=self.half)

    @staticmethod
    def _letterbox(img, new_shape, stride, auto=True, color=(114, 114, 114)):
        """Resize ``img`` keeping its aspect ratio, then pad it toward ``new_shape`` (h, w).

        Padding is split evenly between opposite sides, which is what
        ``scale_coords`` assumes when mapping boxes back. With ``auto=True``
        the padding is reduced to the minimum that keeps both sides multiples
        of ``stride``.
        """
        h, w = img.shape[:2]
        r = min(new_shape[0] / h, new_shape[1] / w)
        new_w, new_h = int(round(w * r)), int(round(h * r))
        dw, dh = new_shape[1] - new_w, new_shape[0] - new_h
        if auto:
            dw, dh = dw % stride, dh % stride
        dw, dh = dw / 2, dh / 2
        if (w, h) != (new_w, new_h):
            img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        return cv2.copyMakeBorder(img, top, bottom, left, right,
                                  cv2.BORDER_CONSTANT, value=color)

    def run(self, frame):
        """Detect objects in one BGR image.

        Args:
            frame: HxWx3 uint8 BGR image, as returned by ``cv2.VideoCapture.read``
                or ``cv2.imread``.

        Returns:
            List of ``((x1, y1), (x2, y2))`` integer tuples, one per detection.
            Each tuple holds the top-left and bottom-right corners in ``frame``
            pixel coordinates. At most ``max_det`` entries are returned, in NMS
            output order, so ``results[0]`` is still the first (top) box as before.
            The list is empty if nothing is detected.
        """
        # Resize + pad to the network input size; arbitrary camera resolutions
        # (e.g. 1280x720) are not stride multiples and would break the network.
        img = self._letterbox(frame, self.imgsz, self.stride, auto=self.pt)
        # HWC -> CHW and BGR -> RGB (YOLOv5 is trained on RGB). The copy makes
        # the reversed view contiguous, which torch.from_numpy requires.
        img = img.transpose((2, 0, 1))[::-1].copy()

        im = torch.from_numpy(img).to(self.device)
        im = im.half() if self.half else im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        im = im[None]  # add batch dim: (1, 3, h, w)

        # Inference only: avoid building an autograd graph on every frame.
        with torch.no_grad():
            pred = self.model(im)
            pred = non_max_suppression(pred, self.conf_thres, self.iou_thres,
                                       max_det=self.max_det)

        results = []
        for det in pred:  # one tensor per image in the batch (a single image here)
            if not len(det):
                continue
            # Rescale boxes from the letterboxed input size back to the frame size
            det[:, :4] = scale_coords(im.shape[2:], det[:, :4], frame.shape).round()
            # NOTE(review): per YOLOv5's non_max_suppression docs each row is
            # [x1, y1, x2, y2, conf, cls]; cls indexes self.names. Confirm for
            # your YOLOv5 version.
            for x1, y1, x2, y2 in det[:, :4].int().tolist():
                results.append(((x1, y1), (x2, y2)))
        return results

       调用示例:

import cv2
import detect_remake  # NOTE(review): module that holds yolo_detector; rename to match the file you edited

# Webcam demo: draw every detected box on each frame; press 'q' to quit.
cap = cv2.VideoCapture(0)
a = detect_remake.yolo_detector()
try:
    while True:
        rec, img = cap.read()
        if not rec:  # no frame (camera closed/unavailable): img is None and run() would crash
            break
        results = a.run(img)

        for lu, rd in results:  # lu = top-left, rd = bottom-right corner
            cv2.rectangle(img, lu, rd, (0, 0, 255), 2)
        cv2.imshow("video", img)

        if cv2.waitKey(1) == ord('q'):
            break
finally:
    # Always free the camera and close the preview window, even on errors.
    cap.release()
    cv2.destroyAllWindows()

你可能感兴趣的:(深度学习,pytorch,计算机视觉)