YOLOV5 Detetct.py 流程分析

文章目录

  • 前言
  • 初始化
  • 加载数据
  • 预测
    • 预测数据格式
  • 后置处理
  • 完整注释后代码

前言

今天放松一下,随便看看这个YOLOV5 的识别部分的代码是怎么做的,先前的话我们自己手动实现了一个非常简易的分类框架,HuClassfiy(已经上传Gitee,方便各位访问),那么这里的话想要使用YOLOV5做点好玩的,也必须要对整个的代码流程进行梳理。原理就不用说了,老复杂了,所以先从简单的来探索。

我们原来的实现这个detect的代码非常简单,后面会贴出,我注释后的detect代码

import argparse
from PIL import Image
from utils.DataSet.MyDataSet import MyDataSet
from utils.DataSet.TransformAtions import TransFormAtions

"""
这里不想写那么多东西,就是简单地去做一个测试就ok了。
其实做法就是在那个train里面的训练
"""

import argparse
import torch
from torch.utils.data import DataLoader
from models.LeNet import LeNet
from data.ModelConfig import *
import outProcess
def detect():


    ways = opt.valid_imgs
    transformations = TransFormAtions()

    net = LeNet(classes=Classes)
    state_dict_load = torch.load(opt.path_state_dict)
    net.load_state_dict(state_dict_load)

    if(ways):

        test_data = MyDataSet(data_dir=opt.valid_dir, transform=transformations.valid_transform)
        valid_loader = DataLoader(dataset=test_data, batch_size=1)

        net.eval()
        with torch.no_grad():
            for i, data in enumerate(valid_loader):
                # forward
                inputs, labels = data
                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                # 输出处理器

                outProcess.Function(predicted.numpy()[0])
    else:
        #指定的是单张图片,少给我来奇奇怪怪的输入,这个版本容错很差滴!!!
        path_img = opt.valid_dir
        if(".jpg" not in path_img):
            raise Exception("小爷打不开这图片")
        image = Image.open(path_img)
        image = transformations.valid_transform(image)
        image = torch.reshape(image, (1, 3, 32, 32))

        net.eval()
        with torch.no_grad():
            out = net(image)

            outProcess.Function(out.argmax(1).item())


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # False表示识别单张图片,True表示多张图片,此时指定路径即可。
    parser.add_argument('--valid_imgs',type=bool,default=False)
    parser.add_argument('--valid_dir', type=str, default=r'F:\projects\PythonProject\MyClassfication\mydata\train\100\1.jpg')
    parser.add_argument('--path_state_dict', type=str, default='runs/train/epx2/weights/best.pth')
    opt = parser.parse_args()
    detect()

在YOLO V 5 里面也不复杂,也就比我多了100多行代码。

我们这边大致的流程就三个。
YOLOV5 Detetct.py 流程分析_第1张图片
然后每一个环节都可以有很多细节优化啥的,由于俺们那个是很简陋的,所以没有哈。

好了,我们开始正式进入这个YOLOV5的实际环节。

初始化

我们先来这看到这个环节,这里一共是做了两件事情嘛,读取超参数,加载模型权重文件,加载驱动
YOLOV5 Detetct.py 流程分析_第2张图片
这里可以注意到这个函数
YOLOV5 Detetct.py 流程分析_第3张图片
这个的话不用想的那么复杂,就是这个玩意
YOLOV5 Detetct.py 流程分析_第4张图片
目的就是返回一个 可以正常使用的驱动,要是我写的话,我压根不会管那么多,不行就玩命报错,然后输出日志文件。

加载数据

然后第二步是加载数据,这个说实话,没什么好说的,分两个,一个是读取网络摄像头,一个是读取一张图片,或者视频,本地摄像头。这些逻辑处理细节不一样,但是结果都是一样的。
YOLOV5 Detetct.py 流程分析_第5张图片
就是把数据给我封装的dataset里面,然后读取。

预测

YOLOV5 Detetct.py 流程分析_第6张图片
我的注释写还是挺不错的。

预测数据格式

这里我们说说那个预测的格式。
我这里还是拿上次的一张图片做演示

这里有两个目标框,所以拿到的数据是这样的
在这里插入图片描述
我们发现pred 是一个长度为1,里面有两个list的玩意
之后我们发现最后一个直接是0
这个的话,是这样的
YOLOV5 Detetct.py 流程分析_第7张图片

后置处理

之后就是拿到东西之后处理。在yolo里面默认是实现了一个自己绘图的玩意。
当然有时候,我们不仅仅要这玩意,我们想要实现AI压枪的话还需要那啥。
YOLOV5 Detetct.py 流程分析_第8张图片

完整注释后代码

import argparse
import time
from pathlib import Path

import cv2
import torch
import torch.backends.cudnn as cudnn
from numpy import random

from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
    scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized


def detect(save_img=False):
    # 读取初始化参数
    source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
    save_img = not opt.nosave and not source.endswith('.txt')  # save inference images
    webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
        ('rtsp://', 'rtmp://', 'http://', 'https://'))

    # Directories
    save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

    # Initialize
    set_logging()
    device = select_device(opt.device)
    half = device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
    #加载模型,这一块,weights是我们传入的参数,是我们权重文件地址
    #注意到这个model就是我们Huclassfiy的net
    model = attempt_load(weights, map_location=device)  # load FP32 model
    stride = int(model.stride.max())  # model stride,维度的变换步长,这个和YOLO的网络结构有关,先忽略
    #imgsz是我们图片资源,对图片尺寸进行检查
    imgsz = check_img_size(imgsz, s=stride)  # check img_size
    #Pytorch 模型加速,这个需要GPU加速,需要先加载模型权重的!!!
    if half:
        model.half()  # to FP16


    # Second-stage classifier
    classify = False
    if classify:
        modelc = load_classifier(name='resnet101', n=2)  # initialize
        modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval()

    # Set Dataloader
    vid_path, vid_writer = None, None
    #如果是网络摄像头的数据这样处理
    if webcam:
        view_img = check_imshow()
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz, stride=stride)
    else:
        #这部分是加载dataset 和我们那个也是类似的,只不过对于单张图片,我们直接转化为了一个tensor在HuClassFiy
        dataset = LoadImages(source, img_size=imgsz, stride=stride)

    # Get names and colors
    names = model.module.names if hasattr(model, 'module') else model.names
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]

    # Run inference
    if device.type != 'cpu':
        model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
    t0 = time.time()
    for path, img, im0s, vid_cap in dataset:
        #这里就和我们的那个进入验证是类似的了
        #path 是你的图片路径
        # img 自然是image转化为了tensor
        #im0s 是做了一个转化img0 = cv2.imread(path)  # BGR
        #vid_cap 就是说这玩意是不是一个视频,我们读入图片当然不是所以是None

        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0 #归一化
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference
        t1 = time_synchronized()
        pred = model(img, augment=opt.augment)[0]
        #这个是预测的结果,但是按照那个网络的工作原理,还需要进行NMS非极大值抑制筛选目标框框

        # Apply NMS
        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
        t2 = time_synchronized()
        print("预测结果是",pred)
        #按照我们在YOLV1论文里面的推出,应该是有5个参数
        #x,y,w,h,k可信度,但是这里要显示所以还有一个对应的条件概率
        #所以应该有6个参数,但是对应参数k,我们的概率计算是需要k的,结合参数opt.iou_thres
        #所以此时那个参数k应该是iou,之后对应的概率,这里最后通过debug我发现那个完整的参数是这样的
        #左上角,右下角,然后可信度,然后所属类别,注意那里显示的是按照屏幕100%来的,我的笔记本是125%
        #得到的坐标是需要除以1.25的

        # Apply Classifier
        if classify:
            pred = apply_classifier(pred, modelc, img, im0s)

        # Process detections
        #这部分,就是我们的后置处理了。说实话,应该把这玩意拆开的,这个部分是给Opencv画图用的
        for i, det in enumerate(pred):  # detections per image
            if webcam:  # batch_size >= 1
                p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
            else:
                p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)

            p = Path(p)  # to Path
            save_path = str(save_dir / p.name)  # img.jpg
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
            s += '%gx%g ' % img.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                for *xyxy, conf, cls in reversed(det):
                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh)  # label format
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')

                    if save_img or view_img:  # Add bbox to image
                        label = f'{names[int(cls)]} {conf:.2f}'
                        # plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
                        im0 = plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
            # Print time (inference + NMS)
            print(f'{s}Done. ({t2 - t1:.3f}s)')

            # Stream results
            if view_img:
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 1 millisecond

            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'image':
                    cv2.imwrite(save_path, im0)
                else:  # 'video' or 'stream'
                    if vid_path != save_path:  # new video
                        vid_path = save_path
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release previous video writer
                        if vid_cap:  # video
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        else:  # stream
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                            save_path += '.mp4'
                        vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                    vid_writer.write(im0)

    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        print(f"Results saved to {save_dir}{s}")

    print(f'Done. ({time.time() - t0:.3f}s)')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--weights', nargs='+', type=str, default='runs/train/exp2/weights/best.pt', help='model.pt path(s)')
    # http://admin:[email protected]:8081
    parser.add_argument('--source', type=str, default=r'F:\projects\PythonProject\yolov5-5.0\mydata\images\003.jpg', help='source')  # file/folder, 0 for webcam
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='display results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--update', action='store_true', help='update all models')
    parser.add_argument('--project', default='runs/detect', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    opt = parser.parse_args()
    print(opt)
    check_requirements(exclude=('pycocotools', 'thop'))

    with torch.no_grad():
        if opt.update:  # update all models (to fix SourceChangeWarning)
            for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
                detect()
                strip_optimizer(opt.weights)
        else:
            detect()

那么接下来我们要做的就是提取detect,把这个玩意套在我们自己的项目里面。为了后面便于使用这个yolo,我决定后面对这个玩意进行工程化规范,便于直接进行二次使用,开发。毕竟核心的话其实就和HuClassfiy一样,就那几个块。还是那句话,yolo的难点不在工程上,在原理实现上面…

你可能感兴趣的:(人工智能,python,计算机视觉,pytorch)