【yolov4目标检测】(2) 多目标跟踪,案例:车辆行人的跟踪和计数,附python完整代码和数据集

各位同学好,今天和大家分享一下如何使用 YOLOv4 目标检测完成对道路上的车辆、行人的检测、跟踪和计数。先放张图看效果。

绿框代表检测出的目标,粉色点代表每个检测框的中心点,红色数字用于跟踪该目标。当目标在画面上消失时,红色索引自动消失,有新目标出现时,生成新的索引。

由于篇幅原因,yolov4的训练过程就不讲了,在下几个章节中再写。本节使用已经训练好的模型检测目标。本节的yolo模型,python代码和数据集有需要的自取。

链接:https://pan.baidu.com/s/1mhmzOWGS6KGYhEul5Lej-w 

提取码:1234


1. 读入视频资源

cv2.VideoCapture() 来导入指定路径的视频,如果有需要的话可以使用 cv2.VideoCapture(1) 调用外接摄像头,用于实时的目标检测。cv2.read() 读取视频中每一帧的图像,返回图像是否读取成功,以及成功读入的图像。cv2.namedWindow('name', 0) 用来自定义调整图像显示窗口的大小,窗口的名称要和 cv2.imshow() 中的窗口名称相同。

import cv2
import numpy as np
import time

#(1)导入视频
filepath = 'C:\\GameDownload\\Deep Learning\\car.flv'
cap = cv2.VideoCapture(filepath)

pTime = 0  # 设置第一帧开始处理的起始时间

#(2)处理每一帧图像
while True:
    
    # 接收图片是否导入成功、帧图像
    success, img = cap.read()

    # 查看FPS
    cTime = time.time() #处理完一帧图像的时间
    fps = 1/(cTime-pTime)
    pTime = cTime  #重置起始时间
    
    # 在视频上显示fps信息,先转换成整数再变成字符串形式,文本显示坐标,文本字体,文本大小
    cv2.putText(img, str(int(fps)), (70,50), cv2.FONT_HERSHEY_PLAIN, 3, (255,0,0), 3)  
    
    #(3)显示图像,输入窗口名及图像数据
    cv2.namedWindow("img", 0)  # 窗口大小可调整
    cv2.imshow('img', img)    
    if cv2.waitKey(20) & 0xFF==27:  #每帧滞留20毫秒后消失,ESC键退出
        break

# 释放视频资源
cap.release()
cv2.destroyAllWindows()

原始视频图像如下,左上角的33代表视频的FPS值


2. 检测目标

导入我放在压缩包里面的 object_detection.py 文件中的 ObjectDetection() 目标检测方法。该模型训练时使用的是COCO数据集,能有效检测80种目标。这里只需要用来检测行人、汽车、摩托在、自行车等。

od = ObjectDetection() 将定义的检测类传给变量od。使用 od.detect(),调用目标检测方法。返回:class_ids图像属于哪个分类;scores图像属于这个分类的概率;boxes目标检测的识别框信息,包括检测框左上角的坐标,以及检测框的宽和高。

通过左上角坐标 (x, y)右下角坐标 (x+w, y+h) 计算检测框的中心坐标,并将其保存,接下来的目标跟踪需要重点研究中心坐标。

有兴趣的同学也可以把这个目标属于哪个分类以及它的概率显示在图像上

在上述代码中补充:

import cv2
import numpy as np
import time
from object_detection import ObjectDetection  # 导入定义好的目标检测方法

#(1)获取目标检测方法
od = ObjectDetection()

#(2)导入视频
filepath = 'C:\\GameDownload\\Deep Learning\\car.flv'
cap = cv2.VideoCapture(filepath)

pTime = 0  # 设置第一帧开始处理的起始时间

count = 0  # 记录帧数

center_points_prev = []  # 存放前一帧检测框的中心点

#(3)处理每一帧图像
while True:
    
    count += 1  # 记录当前是第几帧
    print('------------------------')
    print('NUM:', count)
    
    # 接收图片是否导入成功、帧图像
    success, img = cap.read()
    
    # 如果读入不到图像就退出
    if success == False:
        break
    
    center_points_current = []  # 储存当前帧的所有目标的中心点坐标

    #(4)目标检测
    # 将每一帧的图像传给目标检测方法
    # 返回class_ids图像属于哪个分类;scores图像属于某个分类的概率;boxes目标检测的识别框
    class_ids, scores, boxes  = od.detect(img)
    
    # 绘制检测框,boxes中包含每个目标检测框的左上坐标和每个框的宽、高
    for box in boxes:
        (x, y, w, h) = box
        
        # 获取每一个框的中心点坐标,像素坐标是整数
        cx, cy = int((x+x+w)/2), int((y+y+h)/2) 
        
        # 存放每一帧的所有框的中心点坐标
        center_points_current.append((cx,cy))
        
        # 绘制矩形框。传入帧图像,框的左上和右下坐标,框颜色,框的粗细
        cv2.rectangle(img, (x,y), (x+w,y+h), (0,255,0), 2)
    
    # 显示所有检测框的中心点,pt代表所有中心点坐标
    for pt in center_points_current:
        cv2.circle(img, pt, 5, (0,0,255), -1)
        
    # 打印前一帧的中心点坐标
    print('prevent center points')
    print(center_points_prev)
    
    # 打印当前帧的中心点坐标
    print('current center points')
    print(center_points_current)       
    
    #(5)图像显示
    # 查看FPS
    cTime = time.time() #处理完一帧图像的时间
    fps = 1/(cTime-pTime)
    pTime = cTime  #重置起始时间
    
    # 在视频上显示fps信息,先转换成整数再变成字符串形式,文本显示坐标,文本字体,文本大小
    # cv2.putText(img, str(int(fps)), (70,50), cv2.FONT_HERSHEY_PLAIN, 3, (255,0,0), 3)  
    
    # 显示图像,输入窗口名及图像数据
    cv2.imshow('img', img)    
    
    # 复制当前帧的中心点坐标
    center_points_prev = center_points_current.copy()

    # 每帧滞留20毫秒后消失,ESC键退出
    if cv2.waitKey(0) & 0xFF==27:  # 设置为0代表只显示当前帧
        break

# 释放视频资源
cap.release()
cv2.destroyAllWindows()

目标检测结果如下图所示。

打印前后两帧中心点坐标。下面是第7帧和第8帧的前后两帧的各个检测到的目标的中心点坐标。

NUM: 7
prevent center points
[(1670, 545), (1568, 531), (1395, 430), (1326, 402), (586, 548), (768, 467), (1231, 847), (1102, 628), (331, 735), (851, 366), (715, 367), (1059, 338), (1083, 364), (1181, 404), (826, 323), (995, 316), (938, 324), (851, 292), (1023, 288), (171, 630), (1107, 473), (653, 290)]
current center points
[(1671, 545), (1563, 532), (1395, 429), (1326, 402), (1228, 840), (767, 470), (582, 552), (1102, 623), (321, 748), (850, 367), (714, 367), (1058, 338), (1181, 403), (1083, 363), (994, 316), (825, 324), (851, 291), (937, 324), (1023, 289), (172, 630), (1107, 470), (653, 290)]
------------------------
NUM: 8
prevent center points
[(1671, 545), (1563, 532), (1395, 429), (1326, 402), (1228, 840), (767, 470), (582, 552), (1102, 623), (321, 748), (850, 367), (714, 367), (1058, 338), (1181, 403), (1083, 363), (994, 316), (825, 324), (851, 291), (937, 324), (1023, 289), (172, 630), (1107, 470), (653, 290)]
current center points
[(1562, 530), (1672, 545), (1326, 402), (1394, 429), (765, 471), (1099, 620), (579, 557), (1225, 832), (312, 757), (1057, 337), (1179, 401), (713, 368), (1082, 362), (849, 368), (993, 315), (825, 324), (850, 292), (937, 324), (1023, 289), (878, 261), (170, 630), (1107, 470), (653, 290)]

3. 追踪目标

目标追踪的流程是,(1)检测到目标的存在,并给这个目标加上标记(2)比较前后两帧图像的检测目标的中心点之间的距离,如果小于指定的值就认为是相同的目标,标记不变;(3)当目标在画面上消失后,删除目标的标记(4)在某一帧图像中有新目标出现时,在检测到的目标中先给上一帧确定是目标的物体更新中心点坐标,再给剩余的检测目标(即新目标)添加标记

我们只需要比较前两帧图像(count<=2)目标中心点之间的距离,用来确定初始有哪些目标,使用track_id来标记这个目标。当我们确定中心点之间的距离小于多少时判定为目标时,要注意在图像上,离相机越近的车辆速度,在画面上的速度越快,两帧之间的距离就变得较大,这时要合理选择阈值,比较前后两帧之间中心点的距离。

在确定 count>2 时的中心点距离时,要比较当前帧图像检测到的目标的中心点坐标上一帧中追踪的目标的中心点坐标。如果坐标距离小于阈值,那么就证明前后两帧是同一个物体只需要更新上一帧中已追踪目标的中心点坐标即可,不需要对它附新标记。

当把已追踪目标更新完之后,当前帧中剩下的检测出是目标的物体就是新出现的目标,需要赋予新的标签track_idtrack_id每次赋值完之后+1,起计数作用。

如果跟踪的目标消失了,即当前帧检测出的目标和上一帧追踪的目标之间的距离相差大于指定阈值时,表示目标消失,将它的标签从追踪对象的字典track_objects中删除删除。

在上述的代码中补充。

import cv2
import numpy as np
import math
from object_detection import ObjectDetection  # 导入定义好的目标检测方法

#(1)获取目标检测方法
od = ObjectDetection()

#(2)导入视频
filepath = 'C:\\GameDownload\\Deep Learning\\car.flv'
cap = cv2.VideoCapture(filepath)

count = 0  # 记录帧数

center_points_prev = []  # 存放前一帧检测框的中心点

track_objects = {}  # 存放需要追踪的对象

track_id = 0  # 记录追踪对象的索引,把

#(3)处理每一帧图像
while True:
    
    count += 1  # 记录当前是第几帧
    print('------------------------')
    print('NUM:', count)
    
    # 接收图片是否导入成功、帧图像
    success, img = cap.read()
    
    # 如果读入不到图像就退出
    if success == False:
        break
    
    center_points_current = []  # 储存当前帧的所有目标的中心点坐标
    
    #(4)目标检测
    # 将每一帧的图像传给目标检测方法
    # 返回class_ids图像属于哪个分类;scores图像属于某个分类的概率;boxes目标检测的识别框
    class_ids, scores, boxes  = od.detect(img)
    
    # 绘制检测框,boxes中包含每个目标检测框的左上坐标和每个框的宽、高
    for box in boxes:
        (x, y, w, h) = box
        
        # 获取每一个框的中心点坐标,像素坐标是整数
        cx, cy = int((x+x+w)/2), int((y+y+h)/2) 
        
        # 存放每一帧的所有框的中心点坐标
        center_points_current.append((cx,cy))
        
        # 绘制矩形框。传入帧图像,框的左上和右下坐标,框颜色,框的粗细
        cv2.rectangle(img, (x,y), (x+w,y+h), (0,255,0), 2)

    #(5)目标追踪
    # 内容:1.追踪目标,2.删除消失了的目标的标记,3.有新目标出现时添加标记
    
    # 只在前两帧图像中比较前后两帧检测到的物体的中心点的距离
    if count <= 2:
        # 当前后两帧的同一目标的中心点移动距离小于规定值,认为是同一物体,对其追踪
        for pt1 in center_points_current:  # 当前帧的中心点
            for pt2 in center_points_prev:  # 前一帧的中心点
                
                # 计算距离,勾股定理
                distance = math.hypot(pt2[0]-pt1[0], pt2[1]-pt1[1])
                
                # 如果距离小于20各像素则认为是同一个目标
                if distance < 20:
                    
                    # 记录当前目标的索引及其当前帧中心点坐标
                    track_objects[track_id] = pt1
                    track_id += 1
    
    
    # 在后续的帧中比较的是,当前帧检测物体的中心点与被追踪目标中心点之间的距离
    else:
        # 由于在循环过程中不能删除字典的元素,先复制一份
        track_objects_new = track_objects.copy()
        
        # 被追踪目标的中心点坐标
        for object_id, pt2 in track_objects_new.items(): 
            
            # 假设在当前帧中,我们在上一帧中跟踪的对象不存在了
            object_exist = False  # 当目标在屏幕上消失后,将其对应的标记消除
            
            # 当前帧检测到的物体的中心点
            for pt1 in center_points_current:  
            
                # 计算两者间的距离
                distance = math.hypot(pt2[0]-pt1[0], pt2[1]-pt1[1])
                
                # 如果两者之间的像素距离小于20, 那么就认为是同一个目标
                if distance < 20:
                    
                    # 更新被追踪目标的前一帧的中心点坐标,等于当前帧中的中心点坐标
                    track_objects[object_id] = pt1
                    
                    # 距离小于20,证明在当前帧中,检测的目标还存在
                    object_exist = True
                    
                    # 在当前帧所有已检测目标的中心点坐标中,删除已经更新过的中心点坐标
                    # 已检测目标中剩余的坐标就是新出现的目标,需要添加标记
                    center_points_current.remove(pt1)
                    continue
                    
            # 如果追踪的对象消失了,删除它的标记
            if object_exist == False:
                track_objects.pop(object_id)
                
        #(6)添加新目标
        for pt in center_points_current:  # 删除更新坐标后剩余的检测到的坐标点
            
            # 给新出现的目标加上标记
            track_objects[track_id] = pt
            track_id += 1
                
    #(7)显示出每一帧需要追踪的对象
    for object_id, pt in track_objects.items():
        
        # 在追踪目标的中心点画圈
        cv2.circle(img, pt, 5, (255,0,255), -1)
        # 显示该目标的id
        cv2.putText(img, str(object_id), (pt[0], pt[1]-5), 0, 1.5, (0,0,255),3)
   
    # 打印目标的坐标
    print('tracking objects')
    print(track_objects)
        
    # 打印前一帧的中心点坐标
    print('prevent center points')
    print(center_points_prev)
    
    # 打印当前帧的中心点坐标
    print('current center points')
    print(center_points_current)

    #(9)显示图像,输入窗口名及图像数据
    cv2.namedWindow("img", 0)  #调节显示窗口大小
    cv2.imshow('img', img)    
    
    # 复制当前帧的中心点坐标
    center_points_prev = center_points_current.copy()

    # 每帧滞留20毫秒后消失,ESC键退出
    if cv2.waitKey(0) & 0xFF==27:  # 设置为0代表只显示当前帧
        break

# 释放视频资源
cap.release()
cv2.destroyAllWindows()

检测结果图像如下。

打印每一帧的检测结果的中心点坐标信息如下:

------------------------
NUM: 48
tracking objects
{0: (1699, 572), 1: (1528, 520), 2: (1381, 435), 5: (692, 692), 6: (1118, 658), 7: (1077, 522), 9: (822, 432), 10: (659, 422), 13: (1132, 364), 14: (792, 366), 16: (922, 309), 17: (819, 310), 19: (162, 639), 20: (1076, 416), 43: (1339, 866), 58: (931, 775), 59: (1308, 406), 62: (1790, 1034), 63: (331, 888), 66: (643, 296)}
prevent center points
[(970, 309), (643, 296)]
current center points
[]
------------------------
NUM: 49
tracking objects
{0: (1700, 574), 1: (1528, 521), 2: (1381, 435), 5: (688, 704), 6: (1115, 656), 7: (1078, 521), 9: (822, 434), 10: (658, 422), 13: (1130, 363), 14: (791, 367), 17: (820, 316), 19: (162, 639), 20: (1076, 415), 43: (1334, 857), 58: (930, 763), 59: (1308, 407), 62: (1782, 1026), 66: (642, 296), 67: (317, 903), 68: (1001, 286), 69: (1027, 324)}
prevent center points
[]
current center points
[(317, 903), (1001, 286), (1027, 324)]

你可能感兴趣的:(yolo目标检测,目标跟踪,目标检测,python,yolo,opencv)