【博主使用的Python版本:3.9.7】
【博主使用的 OpenCV版本:4.5.0】
本文所使用的资料已上传到百度网盘【https://pan.baidu.com/s/1-OyW8kGbfV58bO4q3GK0tA?pwd=j7u9】,提取码:j7u9。
OpenCV的全称是:Open Source Computer Vision Library, OpenCV是一个基于Apache2.0许可(开源)发行的跨平台计算机视觉和机器学习软件库, 其采用 C/C++ 编写,同时提供了Python、Ruby、MATLAB等语言的接口,实现了图像处理和计算机视觉方面的很多通用算法。其主要关注的是实时应用,同时,OpenCV 的另一个目标是构建一个简单易用的计算机视觉框架,以帮助开发人员更便捷地设计更复杂的计算机视觉相关的应用程序。可以从http://opencv.org获取。
我们要做的事是搭建一个【多目标跟踪】的简单的框架,你可以跟随我的步骤在PyCharm中一步步地把代码填进去,也可以直接复制完整代码,完整代码在本文最底部。
在开始之前,我们需要引入的库:
如果你没有以上的库,请自行安装。
import cv2
import numpy as np
from object_detection import ObjectDetection
import math
object_detection.py代码如下:
import cv2
import numpy as np
class ObjectDetection:
def __init__(self, weights_path="dnn_model/yolov4.weights", cfg_path="dnn_model/yolov4.cfg"):
print("Loading Object Detection")
print("Running opencv dnn with YOLOv4")
self.nmsThreshold = 0.4
self.confThreshold = 0.5
self.image_size = 608
# Load Network
net = cv2.dnn.readNet(weights_path, cfg_path)
# Enable GPU CUDA
net.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA)
net.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA)
self.model = cv2.dnn_DetectionModel(net)
self.classes = []
self.load_class_names()
self.colors = np.random.uniform(0, 255, size=(80, 3))
self.model.setInputParams(size=(self.image_size, self.image_size), scale=1/255)
def load_class_names(self, classes_path="dnn_model/classes.txt"):
with open(classes_path, "r") as file_object:
for class_name in file_object.readlines():
class_name = class_name.strip()
self.classes.append(class_name)
self.colors = np.random.uniform(0, 255, size=(80, 3))
return self.classes
def detect(self, frame):
return self.model.detect(frame, nmsThreshold=self.nmsThreshold, confThreshold=self.confThreshold)
在开始之前首先进行第一个测试,以确保后面所做的都是正确的。
cap = cv2.VideoCapture("los_angeles.mp4") # 加载视频片段
_, frame = cap.read() # 从视频中获取帧
cv2.imshow("Frame", frame) # 显示帧
cv2.waitKey(0) # 保持窗口打开
运行结果如下:
现在我们成功获取并加载显示出了视频的第一帧,这是一个好的开始。有了第一帧,接下来,继续加载整个视频。
什么是视频?
通俗的来讲,视频就是一个接一个的大量图像。如果你检查相机的规格,例如相机可以以30fps的速度录制,这就意味着相机每秒记录30帧,说明在一秒钟内有30张图像。
所以现在我们将获取帧的步骤放入循环中,在循环内一个接一个地获取帧。注:这里要判断视频是否播放完,即判断是否存在帧,如果不存在帧,则退出循环。
cap = cv2.VideoCapture("los_angeles.mp4") # 加载视频片段
while True: # 获取连续的帧
ret, frame = cap.read() # 从视频中获取帧
if not ret: # 是否存在帧
break # 如果不存在帧,则退出
cv2.imshow("Frame", frame) # 显示帧
key = cv2.waitKey(1) # 1:每帧延迟1ms
if key == 27: # 注:ESC键
break
cap.release() # 释放视频文件
cv2.destroyAllWindows() # 关闭所有窗口
运行结果如下(按下ESC键,退出循环):
(这里只能上传5M以内的图,所以上传的动图压缩了,略模糊)
已经确保了视频每帧都能成功获取,现在调用object_detection.py中目标检测函数,获取每帧中的包含的目标信息,代码如下:
cap = cv2.VideoCapture("los_angeles.mp4") # 加载视频片段
while True: # 获取连续的帧
ret, frame = cap.read() # 从视频中获取帧
if not ret: # 是否存在帧
break # 如果不存在帧,则退出
(class_ids, scores, boxes) = od.detect(frame) # 当前帧中检测目标中的信息
# class_id:what object is (car / track / person)
# score: how confident is about the detection and
# box: bounding box of the location of each object
for box in boxes: # 不区分类别,只画框
print(box) # 打印框,确保提取目标正确
cv2.imshow("Frame", frame) # 显示帧
key = cv2.waitKey(1) # 1:每帧延迟1ms
if key == 27: # 注:ESC键
break
cap.release() # 释放视频文件
cv2.destroyAllWindows() # 关闭所有窗口
运行结果如下(这里只放了部分):
[505 802 133 178]
[376 683 122 118]
[1671 603 159 64]
[727 605 68 87]
[972 610 92 74]
[898 508 61 52]
[826 531 61 69]
[592 457 40 32]
[861 457 39 32]
[1214 880 244 199]
[735 445 36 37]
[1100 424 37 25]
[1835 560 85 91]
以上结果中,每一行代表一个目标的信息,其中前两个数字为目标框的左上角点坐标(x,y),第三个数字为目标框的宽度,第四个数字为目标框的高度。知道这些信息后,我们就可以在每帧中画出矩形框,框出目标。代码如下:
(class_ids, scores, boxes) = od.detect(frame) # 当前帧中检测目标中的信息
# class_id:what object is (car / track / person)
# score: how confident is about the detection and
# box: bounding box of the location of each object
for box in boxes: # 不区分类别,只画框
(x, y, w, h) = box # 矩形的左上角坐标,及宽度和高度
cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2) # 在当前帧中根据box绘制矩形
运行结果:
接下来可以给每个跟踪目标分配ID,确保跟踪的为同一个目标。
for box in boxes: # 不区分类别,只画框
# print(box) # 打印框,确保提取目标正确
(x, y, w, h) = box # 矩形的左上角坐标,及宽度和高度
cx = int((x + x + w) / 2) # 中心点的x坐标
cy = int((y + y + h) / 2) # 中心点的y坐标
center_points_cur_frame.append((cx, cy)) # 添加新的中心点到数组中
# print("FRAME N°", count, " ", x, y, w, h) # 打印每一帧中的框
# cv2.circle(frame, (cx, cy), 5, (0, 0, 255), -1) # 在当前帧中以框的中心点为中心画圆,半径为5,红色,用所有颜色填充圆圈
cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2) # 在当前帧中根据box绘制矩形
# only at the beginning we compare previous and current frame
if count <= 2:
for pt in center_points_cur_frame:
# cv2.circle(frame, pt, 5, (0, 0, 255), -1) # 画出所有中心点
for pt2 in center_points_prev_frame:
distance = math.hypot(pt2[0] - pt[0], pt2[1] - pt[1]) # 当前帧与前一帧的目标中心点的距离
if distance < 20: # 距离小于20像素
tracking_objects[track_id] = pt # 当前帧中的目标中心
track_id += 1
else:
tracking_objects_copy = tracking_objects.copy() # 建立跟踪目标字典副本
center_points_cur_frame_copy = center_points_cur_frame.copy()
for object_id, pt2 in tracking_objects.copy().items(): # 首先遍历新的数组
object_exists = False # 首先假设当前帧中不存在目标
for pt in center_points_cur_frame: # 遍历当前帧中的目标
distance = math.hypot(pt2[0] - pt[0], pt2[1] - pt[1]) # 当前帧与前一帧的目标中心点的距离
# Update IDs position
if distance < 20:
tracking_objects[object_id] = pt # 更新目标位置
object_exists = True # 目标存在
if pt in center_points_cur_frame:
center_points_cur_frame.remove(pt)
continue # 继续下一帧
# Remove IDs lost
if not object_exists: # 目标不存在
tracking_objects.pop(object_id) # 则移除目标ID
# Add new IDs found
for pt in center_points_cur_frame:
tracking_objects[track_id] = pt
track_id += 1
for object_id, pt in tracking_objects.items():
cv2.circle(frame, pt, 5, (0, 0, 255), -1) # 画出所有中心点
# 在帧中显示目标id,文本位置,文本字体类型,字体大小,颜色,粗细(注:需>=0)
cv2.putText(frame, str(object_id), (pt[0], pt[1] - 7), 0, 1, (0, 0, 255), 0)
运行结果:
object_tracking.py完整代码如下:
import cv2
import numpy as np
from object_detection import ObjectDetection
import math
# Initialize Object Detection
od = ObjectDetection() # 加载目标
cap = cv2.VideoCapture("los_angeles.mp4") # 加载视频片段
frames = cap.get(cv2.CAP_PROP_FRAME_COUNT) # 通过属性获取帧数
# Initialize count
count = 0 # 用于计算视频的实际帧数
# center_points = [] # 用于储存所有的中心点
center_points_prev_frame = [] # 空数组用于储存第一帧前的空帧
tracking_objects = {} # 用于储存跟踪目标
track_id = 0 # 跟踪目标初始序号
print(cv2.__version__)
while True: # 获取连续的帧
ret, frame = cap.read() # 从视频中获取帧
count += 1
if not ret: # 是否存在帧
break # 如果不存在帧,则退出
# point current frame
center_points_cur_frame = [] # 用于储存当前帧的目标的中心点
# Detect objects on frame
(class_ids, scores, boxes) = od.detect(frame) # 当前帧中检测目标中的信息
# class_id:what object is (car / track / person)
# score: how confident is about the detection and
# box: bounding box of the location of each object
for box in boxes: # 不区分类别,只画框
# print(box) # 打印框,确保提取目标正确
(x, y, w, h) = box # 矩形的左上角坐标,及宽度和高度
cx = int((x + x + w) / 2) # 中心点的x坐标
cy = int((y + y + h) / 2) # 中心点的y坐标
center_points_cur_frame.append((cx, cy)) # 添加新的中心点到数组中
# print("FRAME N°", count, " ", x, y, w, h) # 打印每一帧中的框
# cv2.circle(frame, (cx, cy), 5, (0, 0, 255), -1) # 在当前帧中以框的中心点为中心画圆,半径为5,红色,用所有颜色填充圆圈
cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2) # 在当前帧中根据box绘制矩形
# only at the beginning we compare previous and current frame
if count <= 2:
for pt in center_points_cur_frame:
# cv2.circle(frame, pt, 5, (0, 0, 255), -1) # 画出所有中心点
for pt2 in center_points_prev_frame:
distance = math.hypot(pt2[0] - pt[0], pt2[1] - pt[1]) # 当前帧与前一帧的目标中心点的距离
if distance < 20: # 距离小于20像素
tracking_objects[track_id] = pt # 当前帧中的目标中心
track_id += 1
else:
tracking_objects_copy = tracking_objects.copy() # 建立跟踪目标字典副本
center_points_cur_frame_copy = center_points_cur_frame.copy()
for object_id, pt2 in tracking_objects.copy().items(): # 首先遍历新的数组
object_exists = False # 首先假设当前帧中不存在目标
for pt in center_points_cur_frame: # 遍历当前帧中的目标
distance = math.hypot(pt2[0] - pt[0], pt2[1] - pt[1]) # 当前帧与前一帧的目标中心点的距离
# Update IDs position
if distance < 20:
tracking_objects[object_id] = pt # 更新目标位置
object_exists = True # 目标存在
if pt in center_points_cur_frame:
center_points_cur_frame.remove(pt)
continue # 继续下一帧
# Remove IDs lost
if not object_exists: # 目标不存在
tracking_objects.pop(object_id) # 则移除目标ID
# Add new IDs found
for pt in center_points_cur_frame:
tracking_objects[track_id] = pt
track_id += 1
for object_id, pt in tracking_objects.items():
cv2.circle(frame, pt, 5, (0, 0, 255), -1) # 画出所有中心点
# 在帧中显示目标id,文本位置,文本字体类型,字体大小,颜色,粗细(注:需>=0)
cv2.putText(frame, str(object_id), (pt[0], pt[1] - 7), 0, 1, (0, 0, 255), 0)
print("Tracking Objects :")
print(tracking_objects)
print("CUR FRAME LEFT PTS :")
print(center_points_cur_frame)
# print("PREV FRAME :")
# print(center_points_prev_frame)
cv2.imshow("Frame", frame) # 显示帧
# Make a copy of the points
center_points_prev_frame = center_points_cur_frame.copy()
key = cv2.waitKey(0) # 1:每帧延迟1ms 0:保持当前帧不动
if key == 27: # 注:ESC键
break
cap.release() # 释放视频文件
cv2.destroyAllWindows() # 关闭所有窗口
print("通过属性获取的视频帧数 :", frames)
print("实际遍历整个视频的帧数 :", count-1)