YOLOv5_DeepSORT_Pytorch训练自己的多目标跟踪模型

目录

  • 1 准备
  • 2 训练目标检测模型
  • 3 训练目标跟踪模型
    • 3.1 准备数据
    • 3.2 训练目标跟踪模型
    • 3.3 测试目标跟踪模型
  • 4 训练行人ReID数据集
    • 4.1 数据集介绍
    • 4.2 数据集处理

1 准备

  1. 环境配置:https://blog.csdn.net/weixin_50008473/article/details/115250986?spm=1001.2014.3001.5501
  2. YOLOv5_DeepSORT_Pytorch代码地址:
    https://github.com/mikel-brostrom/Yolov5_DeepSort_Pytorch
  3. DeepSORT论文:https://arxiv.org/pdf/1703.07402.pdf

2 训练目标检测模型

训练自己的目标检测模型参考:https://blog.csdn.net/weixin_50008473/article/details/115331067?spm=1001.2014.3001.5501

(1)将使用的yolov5对应版本modelsutils文件夹放入Yolov5_DeepSort_Pytorch文件夹下
(2)将训练好的检测模型放入weights文件夹
目录结构如下图:
YOLOv5_DeepSORT_Pytorch训练自己的多目标跟踪模型_第1张图片

3 训练目标跟踪模型

3.1 准备数据

Yolov5_DeepSort_Pytorch\deep_sort_pytorch\deep_sort\deep\checkpoint\ckpt.t7是DeepSort在行人ReID数据集训练出来,用于提取行人的外观特征的权重。

prepare.py抠出标注的真实框中的数据用来训练自己的跟踪模型如下:

import os
import cv2
import numpy as np 
import xml.etree.ElementTree as ET
import xml.dom.minidom
import argparse
 
 
def main():
    img_path = 'data/images/'
    anno_path = 'data/annotations/'
    cut_path = 'data/train/'
    if not os.path.exists(cut_path):
        os.makedirs(cut_path)
    imagelist = os.listdir(img_path)
    for image in imagelist:
        image_pre, ext = os.path.splitext(image)
        img_file = img_path + image
        img = cv2.imread(img_file)
        xml_file = anno_path + image_pre + '.xml'
 
        tree = ET.parse(xml_file)
        root = tree.getroot()
        obj_i = 0
        for obj in root.iter('object'):
            obj_i += 1
            cls = obj.find('name').text
            xmlbox = obj.find('bndbox')
            b = [int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)),
                 int(float(xmlbox.find('ymax').text))]
            img_cut = img[b[1]:b[3], b[0]:b[2], :]
            path = os.path.join(cut_path, cls)
            mkdirlambda = lambda x: os.makedirs(x) if not os.path.exists(x) else True
            mkdirlambda(path)
            cv2.imwrite(os.path.join(cut_path, cls, '{}_{:0>2d}.jpg'.format(image_pre, obj_i)), img_cut)
            print("&&&&")
 
 
if __name__ == '__main__':
    main()

目录如下:
YOLOv5_DeepSORT_Pytorch训练自己的多目标跟踪模型_第2张图片

3.2 训练目标跟踪模型

1.修改train.py中对训练数据集的预处理部分
YOLOv5_DeepSORT_Pytorch训练自己的多目标跟踪模型_第3张图片
2.将model.py(位于Yolov5_DeepSort_Pytorch\deep_sort_pytorch\deep_sort\deep下)中的num_classes修改为自己的类别数(此处为6)
在这里插入图片描述
3.在train.py中默认的epoch为40,可根据需要进行修改
YOLOv5_DeepSORT_Pytorch训练自己的多目标跟踪模型_第4张图片
执行train.py开始训练:
YOLOv5_DeepSORT_Pytorch训练自己的多目标跟踪模型_第5张图片

3.3 测试目标跟踪模型

1.修改加载检测模型和跟踪模型的路径,执行track.py
在这里插入图片描述
2.若要跟踪多个类别:

python track.py --classes 1 2 

4 训练行人ReID数据集

4.1 数据集介绍

—bounding_box_test (测试集)
—bounding_box_train (训练集)
—gt_bbox (手绘边框)
—gt_query (包含实际标注)
—query (共750个身份)

图片命名如:0107_c1s1_018076_03.jpg
0107:表示每个人的标签编号,从0001到1501,共有1501个人
c1 表示第一个摄像头,共有6个摄像头
s1 表示第一个录像片段,每个摄像机都有多个录像片段
018076表示 c1s1 的第018076帧图片,视频帧率fps为25
03表示 c1s1_018076这一帧上的第3个检测框

Market-1501数据集链接:https://pan.baidu.com/s/1P3iqzKhbG3vugZGy2QDVCA
提取码:8h5k

4.2 数据集处理

(1)下载数据集,解压至Yolov5_DeepSort_Pytorch\deep_sort_pytorch\deep_sort\deep
处理参考:https://github.com/layumi/Person_reID_baseline_pytorch/blob/master/prepare.py
(2)处理后Market-1501-v15.09.15文件夹下生成pytorch文件夹
(3)将新生成pytorch文件夹下的traintest拷贝到data文件夹
(4)同上执行train.py完成训练

你可能感兴趣的:(目标检测,目标跟踪,pytorch,深度学习)