YOLOv5 Code Explained



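Contents

  1. Object Detection
    • Dataset
    • Model
    • Loss function
    • Training: train.py
    • Inference: detect.py
  2. Classification
    • Dataset
    • Model
    • Loss function
    • Training: train.py
    • Inference: detect.py
  3. Segmentation
    • Dataset
    • Model
    • Loss function
    • Training: train.py
    • Inference: detect.py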

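YOLOv5 is a powerful model family that covers classification, detection, and segmentation (a true all-rounder; many thanks to the authors). It brings major improvements to the network architecture, data augmentation, matching between anchors (prior boxes) and ground-truth boxes, and the loss function, and it improves on earlier models in both inference speed and COCO mAP. For training it also adds new tricks such as multi-scale training, rectangular training (rect), and hyperparameter search with a genetic algorithm.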
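Of these tricks, multi-scale training is easy to show in isolation: each batch is resized to a random size before the forward pass. The sketch below assumes the sampling rule YOLOv5's train.py applies when `--multi-scale` is set (a size between 0.5x and 1.5x `imgsz`, floored to a multiple of the grid size `gs`); `sample_train_size` is a hypothetical helper, and the actual image resizing is omitted.

```python
import random
from typing import Optional


def sample_train_size(imgsz: int = 640, gs: int = 32, rng: Optional[random.Random] = None) -> int:
    """Hypothetical helper: pick a random training size for one batch.

    Samples from [0.5 * imgsz, 1.5 * imgsz + gs) and floors the result to a multiple of gs.
    For imgsz=640 and gs=32 this yields a multiple of 32 in [320, 960].
    """
    rng = rng or random.Random()
    sz = rng.randrange(int(imgsz * 0.5), int(imgsz * 1.5) + gs)
    return sz // gs * gs  # floor to a multiple of the largest stride


rng = random.Random(0)
print([sample_train_size(640, 32, rng) for _ in range(5)])  # five multiples of 32 in [320, 960]
```

Flooring to a multiple of `gs` keeps every sampled size compatible with the network's downsampling, so feature maps at every stride have integer sizes.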

1. Object Detection

1.1 Dataset

  1. Dataset directory layout
--datasets
    -- images
        -- train
        -- val
        -- test
    -- labels
        -- train
        -- val
        -- test
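The folder names in this layout are not arbitrary: YOLOv5 locates each label file by replacing the last `/images/` in an image path with `/labels/` and swapping the extension for `.txt` (the job of `img2label_paths` in `utils/dataloaders.py`). A minimal re-implementation for a single path, with `/` hardcoded as the separator for illustration (the real function uses `os.sep`):

```python
def img2label_path(img_path: str, sep: str = "/") -> str:
    """Map an image path to its label path (sketch of YOLOv5's img2label_paths for one path)."""
    sa, sb = f"{sep}images{sep}", f"{sep}labels{sep}"
    # Replace only the LAST '/images/' so parent folders named 'images' are untouched,
    # then strip the image extension and append '.txt'.
    return sb.join(img_path.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt"


print(img2label_path("datasets/images/train/0001.jpg"))  # datasets/labels/train/0001.txt
# A misspelled folder such as 'imags' is never rewritten, so no label file would be found:
print(img2label_path("datasets/imags/train/0001.jpg"))   # datasets/imags/train/0001.txt
```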

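The datasets folder contains two folders, images and labels. The images folder holds the image files (.jpg). The labels folder holds one label file (.txt) per image, with the same file name. Each line of a label file describes one object: its class index followed by the ground-truth box (cx, cy, w, h), where all coordinates are normalized by the image width and height to the range 0 to 1. data/coco128.yaml stores the dataset paths and class information needed for training.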
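As a worked example of the label format, the sketch below converts a pixel-space box given as corners (x1, y1, x2, y2) into one normalized label line `class cx cy w h`, and back. The function names and the 6-decimal formatting are illustrative choices, not YOLOv5 code.

```python
def xyxy_to_label_line(cls_id: int, box, img_w: int, img_h: int) -> str:
    """Pixel corners (x1, y1, x2, y2) -> 'class cx cy w h', coordinates normalized to [0, 1]."""
    x1, y1, x2, y2 = box
    cx = (x1 + x2) / 2 / img_w  # box center x, divided by image width
    cy = (y1 + y2) / 2 / img_h  # box center y, divided by image height
    w = (x2 - x1) / img_w       # box width, divided by image width
    h = (y2 - y1) / img_h       # box height, divided by image height
    return f"{cls_id} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}"


def label_line_to_xyxy(line: str, img_w: int, img_h: int):
    """Inverse: 'class cx cy w h' -> (class, (x1, y1, x2, y2)) in pixels."""
    cls_id, cx, cy, w, h = line.split()
    cx, w = float(cx) * img_w, float(w) * img_w
    cy, h = float(cy) * img_h, float(h) * img_h
    return int(cls_id), (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


line = xyxy_to_label_line(0, (100, 200, 300, 400), img_w=1280, img_h=720)
print(line)  # 0 0.156250 0.416667 0.156250 0.277778
print(label_line_to_xyxy(line, 1280, 720))  # approximately (0, (100, 200, 300, 400)) after rounding
```

Because the coordinates are relative, the same label file stays valid if the image is resized, which is what lets YOLOv5 train at arbitrary input sizes.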

  2. Dataset processing pipeline
    Function that builds the training dataloader (called in train.py):
    train_loader, dataset = create_dataloader(train_path,
                                              imgsz,
                                              batch_size // WORLD_SIZE,
                                              gs,
                                              single_cls,
                                              hyp=hyp,
                                              augment=True,
                                              cache=None if opt.cache == 'val' else opt.cache,
                                              rect=opt.rect,
                                              rank=LOCAL_RANK,
                                              workers=workers,
                                              image_weights=opt.image_weights,
                                              quad=opt.quad,
                                              prefix=colorstr('train: '),
                                              shuffle=True,
                                              seed=opt.seed)
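Two arguments in this call are derived values. `batch_size // WORLD_SIZE` is the per-process batch when training with DDP across `WORLD_SIZE` processes, and `gs` is the grid size, i.e. the model's largest stride but at least 32 (`gs = max(int(model.stride.max()), 32)` in train.py). Image sizes are rounded up to a multiple of `gs` by `check_img_size`, which relies on `make_divisible`. A simplified sketch of that rounding, where `check_img_size_sketch` is a hypothetical stand-in that omits the warning the real function logs:

```python
import math


def make_divisible(x: int, divisor: int) -> int:
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


def check_img_size_sketch(imgsz: int, stride: int = 32, floor: int = 0) -> int:
    """Hypothetical stand-in for check_img_size on an int size: snap up to a stride multiple, never below floor."""
    return max(make_divisible(imgsz, stride), floor)


max_stride = 32           # YOLOv5s detection strides are 8, 16, 32
gs = max(max_stride, 32)  # grid size
print(check_img_size_sketch(640, gs, floor=gs * 2))  # 640 (already a multiple of 32)
print(check_img_size_sketch(650, gs, floor=gs * 2))  # 672 (rounded up)

batch_size, world_size = 64, 4
print(batch_size // world_size)  # 16 images per process
```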
                                            



1.2 Model



1.3 Loss Function



1.4 Training: train.py



1.5 Inference: detect.py

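Detection pipeline

  1. Parse arguments
  2. Determine the source type
  3. Set up the directory for saving results
  4. Load the model
  5. Load the data
  6. Inference: preprocess images, predict, non-maximum suppression
  7. Draw boxes and save results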
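Step 6 ends with non-maximum suppression (NMS): drop low-confidence boxes, then, visiting boxes from highest to lowest confidence, discard any box whose IoU with an already kept box of the same class exceeds the threshold. The plain-Python sketch below illustrates that greedy procedure on a list of `(x1, y1, x2, y2, conf, cls)` tuples. It is not YOLOv5's `non_max_suppression`, which works on batched tensors, multiplies objectness by class scores, and calls `torchvision.ops.nms`, separating classes by offsetting boxes by class index rather than comparing class labels directly; `simple_nms` and `box_iou` here are hypothetical names.

```python
from typing import List, Sequence, Tuple

Det = Tuple[float, float, float, float, float, int]  # x1, y1, x2, y2, conf, cls


def box_iou(a: Sequence[float], b: Sequence[float]) -> float:
    """IoU of two boxes in (x1, y1, x2, y2) format (extra elements are ignored)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))  # overlap width
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))  # overlap height
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def simple_nms(dets: List[Det], conf_thres: float = 0.25, iou_thres: float = 0.45,
               agnostic: bool = False, max_det: int = 1000) -> List[Det]:
    """Greedy NMS sketch: confidence filter, then suppress overlaps (per class unless agnostic)."""
    candidates = sorted((d for d in dets if d[4] > conf_thres), key=lambda d: d[4], reverse=True)
    keep: List[Det] = []
    for d in candidates:
        suppressed = any(
            (agnostic or k[5] == d[5]) and box_iou(k, d) > iou_thres
            for k in keep
        )
        if not suppressed:
            keep.append(d)
        if len(keep) == max_det:  # cap the number of detections per image
            break
    return keep


dets = [
    (100, 100, 200, 200, 0.9, 0),
    (105, 105, 205, 205, 0.8, 0),  # IoU ~0.82 with the first box, same class -> suppressed
    (300, 300, 400, 400, 0.7, 0),  # no overlap -> kept
    (102, 102, 202, 202, 0.6, 1),  # overlaps the first box but is class 1 -> kept unless agnostic
    (500, 500, 550, 550, 0.1, 0),  # below conf_thres -> dropped
]
print(len(simple_nms(dets)))                 # 3
print(len(simple_nms(dets, agnostic=True)))  # 2
```

The `agnostic` flag mirrors the meaning of `--agnostic-nms` in detect.py: with it, overlapping boxes suppress each other regardless of class.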
# YOLOv5  by Ultralytics, GPL-3.0 license
"""
Run YOLOv5 detection inference on images, videos, directories, globs, YouTube, webcam, streams, etc.

Usage - sources:
    $ python detect.py --weights yolov5s.pt --source 0                               # webcam
                                                     img.jpg                         # image
                                                     vid.mp4                         # video
                                                     screen                          # screenshot
                                                     path/                           # directory
                                                     list.txt                        # list of images
                                                     list.streams                    # list of streams
                                                     'path/*.jpg'                    # glob
                                                     'https://youtu.be/Zgi9g1ksQHc'  # YouTube
                                                     'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream

Usage - formats:
    $ python detect.py --weights yolov5s.pt                 # PyTorch
                                 yolov5s.torchscript        # TorchScript
                                 yolov5s.onnx               # ONNX Runtime or OpenCV DNN with --dnn
                                 yolov5s_openvino_model     # OpenVINO
                                 yolov5s.engine             # TensorRT
                                 yolov5s.mlmodel            # CoreML (macOS-only)
                                 yolov5s_saved_model        # TensorFlow SavedModel
                                 yolov5s.pb                 # TensorFlow GraphDef
                                 yolov5s.tflite             # TensorFlow Lite
                                 yolov5s_edgetpu.tflite     # TensorFlow Edge TPU
                                 yolov5s_paddle_model       # PaddlePaddle
"""

import argparse
import os
import platform
import sys
from pathlib import Path

import torch

FILE = Path(__file__).resolve()  # absolute path of this file, e.g. '/Users/liushuang/Downloads/yolov5-master/detect.py'
ROOT = FILE.parents[0]           # YOLOv5 root directory (parent of this file), e.g. '/Users/liushuang/Downloads/yolov5-master'
if str(ROOT) not in sys.path:    # module search path
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative path, e.g. '.'

from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
                           increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, smart_inference_mode


@smart_inference_mode()
def run(
        weights=ROOT / 'yolov5s.pt',  # model path or triton URL
        source=ROOT / 'data/images',  # file/dir/URL/glob/screen/0(webcam)
        data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
        imgsz=(640, 640),  # inference size (height, width)
        conf_thres=0.25,  # confidence threshold
        iou_thres=0.45,  # NMS IOU threshold
        max_det=1000,  # maximum detections per image
        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        view_img=False,  # show results
        save_txt=False,  # save results to *.txt
        save_conf=False,  # save confidences in --save-txt labels
        save_crop=False,  # save cropped prediction boxes
        nosave=False,  # do not save images/videos
        classes=None,  # filter by class: --class 0, or --class 0 2 3
        agnostic_nms=False,  # class-agnostic NMS
        augment=False,  # augmented inference
        visualize=False,  # visualize features
        update=False,  # update all models
        project=ROOT / 'runs/detect',  # save results to project/name
        name='exp',  # save results to project/name
        exist_ok=False,  # existing project/name ok, do not increment
        line_thickness=3,  # bounding box thickness (pixels)
        hide_labels=False,  # hide labels
        hide_conf=False,  # hide confidences
        half=False,  # use FP16 half-precision inference
        dnn=False,  # use OpenCV DNN for ONNX inference
        vid_stride=1,  # video frame-rate stride
):
    source = str(source)   #  'data/images'
    save_img = not nosave and not source.endswith('.txt')  # save inference images
    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)  # is the source an image/video file? e.g. suffix '.jpg' -> True
    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))   # URL or stream address
    webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)  # webcam or stream
    screenshot = source.lower().startswith('screen')  # False
    if is_url and is_file:
        source = check_file(source)  # download

    # Directories
    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run ; PosixPath('runs/detect/exp3')
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

    # Load model
    device = select_device(device)   # YOLOv5  2023-4-15 Python-3.10.10 torch-2.0.0 CPU
    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)  # pick the backend for the weights format; logs e.g. YOLOv5s summary: 213 layers, 7225885 parameters, 0 gradients, 16.4 GFLOPs
    stride, names, pt = model.stride, model.names, model.pt      # 32, class names, True if the model is PyTorch
    imgsz = check_img_size(imgsz, s=stride)                      # check image size [640, 640]; rounded up to a multiple of stride (32) if needed

    # Dataloader
    bs = 1  # batch_size
    if webcam:   # webcam or stream
        view_img = check_imshow(warn=True)
        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
        bs = len(dataset)
    elif screenshot:
        dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
    else:    # image/video files, directories, globs
        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
    vid_path, vid_writer = [None] * bs, [None] * bs

    # Run inference
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup  imgsz = (1, 3, 640, 640)
    seen, windows, dt = 0, [], (Profile(), Profile(), Profile())  # image counter, window list, timers for pre-process / inference / NMS
    for path, im, im0s, vid_cap, s in dataset:  # /Users/liushuang/Downloads/yolov5-master/data/images/bus.jpg;None; image 1/2 /Users/liushuang/Downloads/yolov5-master/data/images/bus.jpg:
        with dt[0]:   # time pre-processing (Profile is a context manager that accumulates elapsed time)
            im = torch.from_numpy(im).to(model.device)   # torch.Size([3, 384, 640])
            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:
                im = im[None]  # expand for batch dim   torch.Size([1, 3, 384, 640])

        # Inference
        with dt[1]:
            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False    # False
            pred = model(im, augment=augment, visualize=visualize)  # pred[0].shape: torch.Size([1, 15120, 85]), 15120 = 3 * (48*80 + 24*40 + 12*20); pred[1][0]: torch.Size([1, 3, 48, 80, 85]); pred[1][1]: torch.Size([1, 3, 24, 40, 85]); pred[1][2]: torch.Size([1, 3, 12, 20, 85])

        # NMS
        with dt[2]:
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)  # list with one tensor per image; e.g. pred[0].shape: torch.Size([4, 6]) = 4 boxes x (xyxy, conf, cls)

        # Second-stage classifier (optional)
        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

        # Process predictions
        for i, det in enumerate(pred):  # per image
            seen += 1   # number of images processed so far
            if webcam:  # batch_size >= 1
                p, im0, frame = path[i], im0s[i].copy(), dataset.count
                s += f'{i}: '
            else:
                p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)  # path,(720, 1280, 3),0

            p = Path(p)  # to Path
            save_path = str(save_dir / p.name)  # im.jpg   #'runs/detect/exp3/zidane.jpg'
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt  #  'runs/detect/exp3/labels/zidane'
            s += '%gx%g ' % im.shape[2:]  # print string   # s:image 2/2 /Users/liushuang/Downloads/yolov5-master/data/images/zidane.jpg  ; im.shape[2:] : '384x640 '
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh   # (720, 1280, 3) -->  tensor([1280,  720, 1280,  720])
            imc = im0.copy() if save_crop else im0  # for save_crop
            annotator = Annotator(im0, line_width=line_thickness, example=str(names))  # drawing helper
            if len(det):
                # Rescale boxes from img_size to im0 size (map the detected boxes back onto the original image)
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()   # im.shape[2:]:torch.Size([384, 640]) ; im0.shape: (720, 1280, 3)

                # Print results
                for c in det[:, 5].unique():
                    n = (det[:, 5] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                for *xyxy, conf, cls in reversed(det):
                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh: box position relative to the original image size
                        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
                        with open(f'{txt_path}.txt', 'a') as f:
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')

                    if save_img or save_crop or view_img:  # Add bbox to image (draw on the original image)
                        c = int(cls)  # integer class
                        label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')  # 'person 0.88'
                        annotator.box_label(xyxy, label, color=colors(c, True))
                    if save_crop:
                        save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)  # (684, 416, 3)

            # Stream results
            im0 = annotator.result()
            if view_img:
                if platform.system() == 'Linux' and p not in windows:
                    windows.append(p)
                    cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
                    cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 1 millisecond

            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'image':
                    cv2.imwrite(save_path, im0)
                else:  # 'video' or 'stream'
                    if vid_path[i] != save_path:  # new video
                        vid_path[i] = save_path
                        if isinstance(vid_writer[i], cv2.VideoWriter):
                            vid_writer[i].release()  # release previous video writer
                        if vid_cap:  # video
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        else:  # stream
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                        save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                    vid_writer[i].write(im0)

        # Print time (inference-only)
        LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")

    # Print results
    t = tuple(x.t / seen * 1E3 for x in dt)  # speeds per image in ms: dt holds the three timers, seen is the number of images processed
    LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
    if update:
        strip_optimizer(weights[0])  # update model (to fix SourceChangeWarning)


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path or triton URL')   # weights
    parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')    # input source
    parser.add_argument('--data', type=str, default=ROOT / 'data/my_data.yaml', help='(optional) dataset.yaml path')      # dataset config
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')      # confidence threshold
    parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')          # IoU threshold for NMS
    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')   # max detections per image
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')       # compute device
    parser.add_argument('--view-img', action='store_true', help='show results')                     # display predictions
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')            # save labels
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')   # save crops of detected objects
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')  # keep only the classes of interest
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')              # test-time augmentation
    parser.add_argument('--visualize', action='store_true', help='visualize features')
    parser.add_argument('--update', action='store_true', help='update all models')                  # re-save weights via strip_optimizer
    parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')       # output directory
    parser.add_argument('--name', default='exp', help='save results to project/name')                         # run subfolder name
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') # reuse project/name instead of creating exp2, exp3, ...
    parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')      # box line width
    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')   # half precision (FP16)
    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')    # run ONNX models with OpenCV DNN
    parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')       # process every n-th video frame
    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand  [640]-->[640,640]
    print_args(vars(opt))   # print all arguments
    return opt


def main(opt):
    check_requirements(exclude=('tensorboard', 'thop'))    # check that the packages in requirements.txt are installed
    run(**vars(opt))


if __name__ == '__main__':
    opt = parse_opt()   # parse command-line arguments
    main(opt)
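The shape comments in `run()` show a 720x1280 image (zidane.jpg) entering the network as 384x640, and `scale_boxes` mapping the boxes back afterwards. The 384 comes from letterboxing with `auto=pt` (True for PyTorch weights): scale the long side to 640, then pad only up to the next multiple of the stride instead of to a full 640x640 square. A simplified plain-Python sketch of both directions follows; the function names are hypothetical, and it ignores `scaleup`, the actual pixel resizing and padding, and batching.

```python
from typing import Tuple


def letterbox_geometry(h0: int, w0: int, new_shape=(640, 640), stride: int = 32, auto: bool = True):
    """Return (padded (h, w), scale ratio, (pad_x, pad_y) per side) for letterboxing an h0 x w0 image."""
    r = min(new_shape[0] / h0, new_shape[1] / w0)         # scale so the image fits inside new_shape
    new_w, new_h = round(w0 * r), round(h0 * r)           # resized size before padding
    dw, dh = new_shape[1] - new_w, new_shape[0] - new_h   # total padding to reach new_shape
    if auto:  # minimal rectangle: pad only up to the next multiple of stride
        dw, dh = dw % stride, dh % stride
    return (new_h + dh, new_w + dw), r, (dw / 2, dh / 2)


def scale_box_back(box, img1_shape: Tuple[int, int], img0_shape: Tuple[int, ...]):
    """Map an (x1, y1, x2, y2) box from the letterboxed image back onto the original image."""
    gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
    pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2
    pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2
    h0, w0 = img0_shape[0], img0_shape[1]

    def clip(v: float, hi: int) -> float:
        return min(max(v, 0.0), hi)  # keep coordinates inside the original image

    x1, y1, x2, y2 = box
    return (clip((x1 - pad_x) / gain, w0), clip((y1 - pad_y) / gain, h0),
            clip((x2 - pad_x) / gain, w0), clip((y2 - pad_y) / gain, h0))


shape, r, pad = letterbox_geometry(720, 1280)
print(shape, r, pad)  # (384, 640) 0.5 (0.0, 12.0)
print(scale_box_back((100, 62, 300, 212), shape, (720, 1280, 3)))  # (200.0, 100.0, 600.0, 400.0)
```

With `auto=False` the same image would be padded to a full 640x640, wasting computation on 256 extra rows of padding; the minimal rectangle is why inference on wide images is faster.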

2. Classification

2.1 Dataset

2.2 Model

2.3 Loss Function

2.4 Training: train.py

2.5 Inference: detect.py

3. Segmentation

3.1 Dataset

3.2 Model

3.3 Loss Function

3.4 Training: train.py

3.5 Inference: detect.py

