https://github.com/qqwweee/keras-yolo3
https://github.com/Qidian213/deep_sort_yolov3
keras_yolov3:yolov3.weights转换为yolo.h5
基于YOLOv3和deep_sort的多目标跟踪
首先需要搞清楚我们的工程文件夹结构,文件结构如下:
工程路径目录
再次一试目标跟踪
1. deep_sort_yolov3
2. keras_yolo3
3. deep_sort
4. tools
5. model_data
6. 多人目标跟踪.py
7. yolo.py
本博客采用keras框架,需要先将yoloV3版本的权重文件和配置文件转化成keras版本的.h5文件。
本博客采用yoloV3版本的权重和配置文件,首先下载keras版本的yolo代码,执行
cd 再试一次目标跟踪
git clone https://github.com/qqwweee/keras-yolo3
然后便会在再试一次目标跟踪文件夹下面生成keras-yolo3文件夹。
3.1 其中权重的下载链接为yolov3.weights,并将该下载的yolov3.weights文件放到keras-yolo3文件夹下面
3.2 配置文件本身便位于keras_yolo3/yolov3.cfg
3.3 在keras-yolo3文件夹下面,执行
python convert.py yolov3.cfg yolov3.weights model_data/yolo.h5
则会在当前keras-yolo3文件夹下面的model_data文件夹下面生成一个yolo.h5文件,一定要看生成的信息提示,这一步要注意yolo.h5一定要正确生成,否则后面运行肯定会报错!!!
3.4 将生成的yolo.h5文件放到“再试一次目标跟踪”文件夹下面的model_data文件夹下面,执行
cp -r keras-yolo3/model_data/yolo.h5 ./model_data/
4.1 接下来比较重要的一步是需要自己生成mars-small128.pb,这个是表观特征,跟踪时做特征匹配用的(也因为TensorFlow版本问题,需要自己生成)。这一步参考了上面提到的链接基于YOLOv3和deep_sort的多目标跟踪。具体自己生成mars-small128.pb文件的方法需要用到https://github.com/nwojke/deep_sort。下载该链接,放到“再次一试目标跟踪”文件夹下面,执行
cd deep_sort/tools
切换到deep_sort/tools文件夹下面,执行
python tools/freeze_model.py
这样便会生成mars-small128.pb,将生成的文件放到“再次一试目标跟踪”文件夹下面。
4.2 删除deep_sort文件夹,首先cd到“再次一试目标跟踪”路径下面,执行
rm -rf deep_sort
4.3 如果您无法生成上面的mars-small128.pb,则可以使用该github链接https://github.com/Qidian213/deep_sort_yolov3,该文件夹下面有一个model_data文件夹,文件夹下面有一个mars-small128.pb,您可以将这个拷贝过来,直接放到“再次一试目标跟踪”文件夹下面。
执行“再次一试目标跟踪”文件下面的“多人目标跟踪.py”文件,在执行该文件之前需要设置好文件的依赖项,步骤如下
5.1 执行
cd 再试一次目标跟踪
cp -r deep_sort_yolov3/deep_sort/ .
5.2 执行
cd 再试一次目标跟踪
cp -r deep_sort_yolov3/tools/ .
5.3 执行
cd 再试一次目标跟踪
cp -r deep_sort_yolov3/model_data/voc_classes.txt model_data/
cp -r deep_sort_yolov3/model_data/yolo_anchors.txt model_data/
5.4 “再次一试目标跟踪”yolo.py文件
yolo.py文件的内容如下:
#! /usr/bin/env python
# -*- coding: utf-8 -*-
"""
Run a YOLO_v3 style detection model on test images.
"""
import colorsys
import os
import random
import numpy as np
from keras import backend as K
from keras.models import load_model
from yolo3.model import yolo_eval
from yolo3.utils import letterbox_image
class YOLO(object):
def __init__(self):
self.model_path = 'model_data/yolo.h5'
self.anchors_path = 'model_data/yolo_anchors.txt'
self.classes_path = 'model_data/coco_classes.txt'
self.score = 0.5
self.iou = 0.5
self.class_names = self._get_class()
self.anchors = self._get_anchors()
self.sess = K.get_session()
self.model_image_size = (416, 416) # fixed size or (None, None)
self.is_fixed_size = self.model_image_size != (None, None)
self.boxes, self.scores, self.classes = self.generate()
def _get_class(self):
classes_path = os.path.expanduser(self.classes_path)
with open(classes_path) as f:
class_names = f.readlines()
class_names = [c.strip() for c in class_names]
return class_names
def _get_anchors(self):
anchors_path = os.path.expanduser(self.anchors_path)
with open(anchors_path) as f:
anchors = f.readline()
anchors = [float(x) for x in anchors.split(',')]
anchors = np.array(anchors).reshape(-1, 2)
return anchors
def generate(self):
model_path = os.path.expanduser(self.model_path)
assert model_path.endswith('.h5'), 'Keras model must be a .h5 file.'
self.yolo_model = load_model(model_path, compile=False)
print('{} model, anchors, and classes loaded.'.format(model_path))
# Generate colors for drawing bounding boxes.
hsv_tuples = [(x / len(self.class_names), 1., 1.)
for x in range(len(self.class_names))]
self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
self.colors = list(
map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
self.colors))
random.seed(10101) # Fixed seed for consistent colors across runs.
random.shuffle(self.colors) # Shuffle colors to decorrelate adjacent classes.
random.seed(None) # Reset seed to default.
# Generate output tensor targets for filtered bounding boxes.
self.input_image_shape = K.placeholder(shape=(2, ))
boxes, scores, classes = yolo_eval(self.yolo_model.output, self.anchors,
len(self.class_names), self.input_image_shape,
score_threshold=self.score, iou_threshold=self.iou)
return boxes, scores, classes
def detect_image(self, image):
if self.is_fixed_size:
assert self.model_image_size[0]%32 == 0, 'Multiples of 32 required'
assert self.model_image_size[1]%32 == 0, 'Multiples of 32 required'
boxed_image = letterbox_image(
image,
tuple(reversed(self.model_image_size)))
else:
new_image_size = (image.width - (image.width % 32),
image.height - (image.height % 32))
boxed_image = letterbox_image(image, new_image_size)
image_data = np.array(boxed_image, dtype='float32')
#print(image_data.shape)
image_data /= 255.
image_data = np.expand_dims(image_data, 0) # Add batch dimension.
out_boxes, out_scores, out_classes = self.sess.run(
[self.boxes, self.scores, self.classes],
feed_dict={
self.yolo_model.input: image_data,
self.input_image_shape: [image.size[1], image.size[0]],
K.learning_phase(): 0
})
return_boxs = []
for i, c in reversed(list(enumerate(out_classes))):
predicted_class = self.class_names[c]
if predicted_class != 'person' :
continue
box = out_boxes[i]
# score = out_scores[i]
x = int(box[1])
y = int(box[0])
w = int(box[3]-box[1])
h = int(box[2]-box[0])
if x < 0 :
w = w + x
x = 0
if y < 0 :
h = h + y
y = 0
return_boxs.append([x,y,w,h])
return return_boxs
def close_session(self):
self.sess.close()
5.5 多人目标跟踪.py文件
多人目标跟踪.py文件的内容如下:
#! /usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
from timeit import time
import warnings
import cv2
import numpy as np
from PIL import Image
from yolo import YOLO
from deep_sort import preprocessing
from deep_sort import nn_matching
from deep_sort.detection import Detection
from deep_sort.tracker import Tracker
from tools import generate_detections as gdet
from deep_sort.detection import Detection as ddet
warnings.filterwarnings('ignore')
def main(yolo):
# 参数定义
max_cosine_distance = 0.3
nn_budget = None
nms_max_overlap = 1.0
# deep_sort 目标追踪算法
model_filename = 'model_data/mars-small128.pb'
encoder = gdet.create_box_encoder(model_filename,batch_size=1)
metric = nn_matching.NearestNeighborDistanceMetric(
"cosine", max_cosine_distance, nn_budget)
tracker = Tracker(metric)
writeVideo_flag = True
video_capture = cv2.VideoCapture(0)
if writeVideo_flag:
# Define the codec and create VideoWriter object
w = int(video_capture.get(3))
h = int(video_capture.get(4))
fourcc = cv2.VideoWriter_fourcc(*'MJPG')
out = cv2.VideoWriter('output.avi', fourcc, 15, (w, h))
list_file = open('detection.txt', 'w')
frame_index = -1
fps = 0.0
while True:
ret, frame = video_capture.read() # frame shape 640*480*3
if ret != True:
break
t1 = time.time()
image = Image.fromarray(frame)
boxs = yolo.detect_image(image)
# print("box_num",len(boxs))
features = encoder(frame,boxs)
# score to 1.0 here).
detections = [Detection(bbox, 1.0, feature) for
bbox, feature in zip(boxs, features)]
# Run non-maxima suppression.
boxes = np.array([d.tlwh for d in detections])
scores = np.array([d.confidence for d in detections])
indices = preprocessing.non_max_suppression(boxes, nms_max_overlap, scores)
detections = [detections[i] for i in indices]
# Call the tracker
tracker.predict()
tracker.update(detections)
for track in tracker.tracks:
if not track.is_confirmed() or track.time_since_update > 1:
continue
bbox = track.to_tlbr()
cv2.rectangle(frame,
(int(bbox[0]), int(bbox[1])),
(int(bbox[2]), int(bbox[3])),
(255,255,255),
2)
cv2.putText(frame,
str(track.track_id),
(int(bbox[0]), int(bbox[1])),
0, 5e-3 * 200, (0,255,0),2)
for det in detections:
bbox = det.to_tlbr()
cv2.rectangle(frame,
(int(bbox[0]), int(bbox[1])),
(int(bbox[2]), int(bbox[3])),
(255,0,0),
2)
cv2.imshow('', frame)
if writeVideo_flag:
# save a frame
out.write(frame)
frame_index = frame_index + 1
list_file.write(str(frame_index)+' ')
if len(boxs) != 0:
for i in range(0,len(boxs)):
list_file.write(str(boxs[i][0]) + ' '+
str(boxs[i][1]) + ' '+
str(boxs[i][2]) + ' '+
str(boxs[i][3]) + ' ')
list_file.write('\n')
fps = ( fps + (1./(time.time()-t1)) ) / 2
print("fps= %f"%(fps))
# Press Q to stop!
if cv2.waitKey(1) & 0xFF == ord('q'):
break
video_capture.release()
if writeVideo_flag:
out.release()
list_file.close()
cv2.destroyAllWindows()
if __name__ == '__main__':
main(YOLO())
5.6 执行多人目标跟踪.py文件
执行
python 多人目标跟踪.py
5.7 运行结果如下图所示