MeanShift, also known as the mean shift algorithm, uses kernel density estimation over color features to find a local optimum, which makes the tracking insensitive to target rotation and small partial occlusions.
At its core, MeanShift is an iterative procedure: within the density distribution of a set of data, it uses non-parametric density estimation to find a local maximum (no prior knowledge of the samples' probability density function is required; everything is computed from the sample points themselves).
In d-dimensional space, pick an arbitrary point and draw a hypersphere of radius h centered on it (it is a high-dimensional sphere because d may be greater than 2). Every sample falling inside the sphere defines a vector from the center to that sample. Summing all these vectors gives the MeanShift vector (shown as the yellow arrow in the figure):
Next, take the endpoint of this MeanShift vector as the new center and repeat the process, which yields another MeanShift vector:
Repeating this produces a sequence of MeanShift vectors joined head to tail, which finally converges to the point of maximum probability density:
In other words, the MeanShift algorithm walks step by step from the starting point to the density center of the sample feature points.
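The procedure above is easy to state in code. Here is a minimal NumPy sketch of one flat-kernel mean shift run on 2D samples; the point set, the radius h and the tolerance are illustrative assumptions, not part of the original text:
import numpy as np

def mean_shift(points, start, h=1.0, tol=1e-3, max_iter=100):
    """Shift `start` toward the densest region of `points` (flat kernel)."""
    center = np.asarray(start, dtype=float)
    for _ in range(max_iter):
        # All samples inside the radius-h ball around the current center.
        inside = points[np.linalg.norm(points - center, axis=1) <= h]
        if len(inside) == 0:
            break
        new_center = inside.mean(axis=0)               # mean of the in-ball samples
        if np.linalg.norm(new_center - center) < tol:  # shift vector is ~0: converged
            break
        center = new_center
    return center

# Two Gaussian clumps; starting near the denser one converges to its mode.
rng = np.random.default_rng(0)
pts = np.vstack([rng.normal(0, 0.3, (200, 2)), rng.normal(3, 0.3, (50, 2))])
print(mean_shift(pts, start=[0.5, 0.5]))  # ends up near (0, 0)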
1. Get the object to track
Obtain the initial bounding box (RoI) position (x, y, w, h) and crop the RoI image region:
# Initialize the RoI position
track_window = (c, r, w, h)
# Crop the RoI from the image
roi = img[r:r+h, c:c+w]
2. Convert the color space
Convert the BGR RoI image to HSV, then filter the HSV image to discard the low-brightness and low-saturation parts.
A specific color is much easier to express in HSV than in BGR. In OpenCV's HSV format, H (hue) ranges over [0, 179], S (saturation) over [0, 255], and V (value/brightness) over [0, 255].
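As a quick sanity check on those ranges, converting a single pure-blue BGR pixel shows how OpenCV scales H (a tiny sketch; the pixel value is just an example):
import numpy as np
import cv2

blue = np.uint8([[[255, 0, 0]]])              # one BGR pixel (pure blue)
print(cv2.cvtColor(blue, cv2.COLOR_BGR2HSV))  # [[[120 255 255]]]: H=120, not 240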
# Convert to HSV
hsv_roi = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
# Set the filtering thresholds
lower = np.array([0., 130., 32.])
upper = np.array([180., 255., 255.])
# Build the mask from the thresholds
mask = cv2.inRange(hsv_roi, lower, upper)
3. Compute the hue histogram
# Compute the hue histogram
roi_hist = cv2.calcHist([hsv_roi], [0], mask, [180], [0, 180])
# Normalize the histogram (here to the range [0, 180])
cv2.normalize(roi_hist, roi_hist, 0, 180, cv2.NORM_MINMAX)
The prototype of cv2.calcHist is:
cv2.calcHist(images, channels, mask, histSize, ranges[, hist[, accumulate]])
images: the image(s) to analyze, which must be wrapped in square brackets
channels: the channel(s) used to compute the histogram; here the hue channel
mask: the filtering mask
histSize: the number of bins the histogram is divided into
ranges: the range of pixel values covered by the histogram
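For instance, with histSize = [180] and ranges = [0, 180] the result is a 180x1 float32 array with one bin per hue value (a small sketch on a random image, just to show the output shape; the image itself is an assumption):
import numpy as np
import cv2

img = np.random.randint(0, 256, (64, 64, 3), np.uint8)   # dummy BGR image
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
hist = cv2.calcHist([hsv], [0], None, [180], [0, 180])   # no mask here
print(hist.shape, hist.dtype)                            # (180, 1) float32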
4. Locate the tracked object in a new frame
# Read the next frame
ret, frame = cap.read()
# Convert to HSV
hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
# Back-project the target histogram onto the frame
dst = cv2.calcBackProject([hsv], [0], roi_hist, [0, 180], 1)
# Define the termination criteria for the iteration
term_crit = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 1)
# Run MeanShift to get the iteration count and the new target window
ret, track_window = cv2.meanShift(dst, track_window, term_crit)
cv2.meanShift(probImage, window, criteria)
probImage: the input back-projection image
window: the search window (RoI) to be moved
criteria: the parameters controlling the meanshift iteration
The criteria parameter consists of:
type: the kind of termination condition:
COUNT: stop after a maximum number of iterations
EPS: stop once the shift falls below a convergence threshold
COUNT + EPS: stop as soon as either condition is met
maxCount: the maximum number of iterations
epsilon: the convergence threshold
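Putting step 4 together, a typical per-frame tracking loop looks like the following sketch; it assumes cap, roi_hist, track_window and term_crit were initialized as in the steps above:
while True:
    ret, frame = cap.read()
    if not ret:
        break
    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    dst = cv2.calcBackProject([hsv], [0], roi_hist, [0, 180], 1)
    # Shift the window onto the nearby density peak of the back projection.
    ret, track_window = cv2.meanShift(dst, track_window, term_crit)
    x, y, w, h = track_window
    cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
    cv2.imshow('tracking', frame)
    if cv2.waitKey(30) & 0xFF == 27:  # Esc to quit
        break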
The back projection is a probability map of the same size as the input image; each pixel value is the probability that the corresponding input pixel belongs to the target object. The brighter a pixel, the more likely it belongs to the target.
The tracked target:
The back projection of the tracked target in the next frame:
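Figures like these can be reproduced by displaying dst directly, since the back projection is a single-channel 8-bit image (a minimal sketch assuming hsv and roi_hist from the steps above):
dst = cv2.calcBackProject([hsv], [0], roi_hist, [0, 180], 1)
cv2.imshow('back projection', dst)  # brighter = higher target probability
cv2.waitKey(0)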
import numpy as np
import cv2

class MeanShiftTracer:
    def __init__(self, id):
        # Stop criteria for the iterative search algorithm.
        self._term_crit = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 1)
        self._roi_hist = None
        self.predict_count = 0
        self.frame = None
        self.frame_begin_id = id
        self.frame_end_id = id
        self.roi_xywh = None

    def _log_last_correct(self, frame, frame_id, xywh):
        x, y, w, h = xywh
        self.correct_box = (x, y, w, h)
        self.correct_img = frame[y:y + h, x:x + w]
        self.correct_id = frame_id

    def correct(self, frame, frame_id, xywh):
        self._log_last_correct(frame, frame_id, xywh)
        self._refresh_roi(frame, frame_id, xywh)
        self.predict_count = 0

    def predict(self, frame, frame_id):
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        dst = cv2.calcBackProject([hsv], [0], self._roi_hist, [0, 180], 1)
        ret, track_window = cv2.meanShift(dst, self.roi_xywh, self._term_crit)
        self._refresh_roi(frame, frame_id, track_window)
        self.predict_count += 1
        return track_window

    def _refresh_roi(self, frame, frame_id, xywh):
        x, y, w, h = xywh
        roi = frame[y:y + h, x:x + w]
        hsv_roi = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
        mask = cv2.inRange(hsv_roi, np.array((0., 60., 32.)), np.array((180., 255., 255.)))
        roi_hist = cv2.calcHist([hsv_roi], [0], mask, [180], [0, 180])
        cv2.normalize(roi_hist, roi_hist, 0, 180, cv2.NORM_MINMAX)
        self.roi_xywh = (x, y, w, h)
        self._roi_hist = roi_hist
        self.frame = frame
        self.frame_end_id = frame_id

    def get_roi_info(self):
        return {'correct_box': self.correct_box,
                'correct_img': self.correct_img,
                'correct_id': self.correct_id,
                'beginId': self.frame_begin_id,
                'endId': self.frame_end_id}
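A minimal usage sketch for this class (the video file and the initial detection box are assumptions for illustration):
cap = cv2.VideoCapture('test.mp4')            # hypothetical input video
ret, frame = cap.read()
tracer = MeanShiftTracer(id=0)
tracer.correct(frame, 0, (100, 100, 60, 60))  # seed with a detected box (x, y, w, h)
ret, frame = cap.read()
print(tracer.predict(frame, 1))               # tracked (x, y, w, h) in frame 1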
import numpy as np
import cv2

class TracerManager:
    def __init__(self, image_shape, trace_tool, trace_margin, max_predict):
        """
        :param image_shape: (height,width)
        :param trace_tool: MeanShiftTracer
        :param trace_margin: (0,0,30,50)(px)(left,top,right,bottom)
        :param max_predict: 3 (times)
        """
        self._tracers = []
        self._trace_tool = trace_tool
        self._max_predict = max_predict
        self._image_shape = image_shape
        self.trace_margin = trace_margin

    def _calc_iou(self, A, B):
        """
        :param A: [x1, y1, x2, y2]
        :param B: [x1, y1, x2, y2]
        :return: IoU
        """
        IoU = 0
        iw = min(A[2], B[2]) - max(A[0], B[0])
        if iw > 0:
            ih = min(A[3], B[3]) - max(A[1], B[1])
            if ih > 0:
                A_area = (A[2] - A[0]) * (A[3] - A[1])
                B_area = (B[2] - B[0]) * (B[3] - B[1])
                uAB = float(A_area + B_area - iw * ih)
                IoU = iw * ih / uAB
        return IoU

    def box_in_margin(self, box):
        in_bottom = (self._image_shape[0] - (box[1] + box[3])) < self.trace_margin[3]
        in_right = (self._image_shape[1] - (box[0] + box[2])) < self.trace_margin[2]
        return in_bottom or in_right

    def _get_box_tracer_iou(self, A, B):
        a = (A[0], A[1], A[0] + A[2], A[1] + A[3])
        b = (B[0], B[1], B[0] + B[2], B[1] + B[3])
        return self._calc_iou(a, b)

    def _check_over_trace(self):
        remove_tracer = []
        trace_info = []
        for t in self._tracers:
            if t.predict_count > self._max_predict:
                remove_tracer.append(t)
                if t.frame_end_id != t.frame_begin_id:
                    trace_info.append(t.get_roi_info())
        for t in remove_tracer:
            self._tracers.remove(t)
        return trace_info

    def _get_tracer(self, box):
        tracer = None
        maxIoU = 0
        for t in self._tracers:
            iou = self._get_box_tracer_iou(box, t.roi_xywh)
            if iou > maxIoU:
                tracer = t
                maxIoU = iou
        return tracer

    def update_tracer(self, frame, frame_id, boxes):
        trace_info = self._check_over_trace()
        for box in boxes:
            if self.box_in_margin(box):
                continue
            tracer = self._get_tracer(box)
            if tracer is not None:
                tracer.correct(frame, frame_id, box)
            else:
                tracer = self._trace_tool(frame_id)
                tracer.correct(frame, frame_id, box)
                self._tracers.append(tracer)
        return trace_info

    def trace(self, frame, frame_id):
        track_windows = []
        for t in self._tracers:
            window = t.predict(frame, frame_id)
            track_windows.append(window)
        return track_windows
Detection and tracking alternate at a 1:1 ratio: each detection frame is followed by one tracking frame (trace_rate = 1).
import cv2
import numpy as np
import os.path
import Tracer

class car_detector:
    def __init__(self, cascade_file):
        if not os.path.isfile(cascade_file):
            raise RuntimeError("%s: not found" % cascade_file)
        self._cascade = cv2.CascadeClassifier(cascade_file)

    def _detect_cars(self, image):
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        gray = cv2.equalizeHist(gray)
        cars = self._cascade.detectMultiScale(gray, scaleFactor=1.3, minNeighbors=15, minSize=(60, 60))
        return cars

    def _show_trace_object(self, infos):
        for info in infos:
            title = "%d - %d from frame: %d" % (info['beginId'], info['endId'], info['correct_id'])
            cv2.imshow(title, info['correct_img'])
            cv2.waitKey(1)

    def _get_area_invalid_mark(self, img_shape, margin):
        area = np.zeros(img_shape, np.uint8)
        h, w = img_shape[:2]
        disable_bg_color = (0, 0, 80)
        disable_fg_color = (0, 0, 255)
        cv2.rectangle(area, (0, h - margin[3]), (w, h), disable_bg_color, -1)
        cv2.putText(area, "Invalid Region", (w - 220, h - 20), cv2.FONT_HERSHEY_SIMPLEX, 1, disable_fg_color, 2)
        return area

    def _show_trace_state(self, image, id, tracer, state, boxes, mark):
        image = cv2.addWeighted(mark, 0.5, image, 1, 0)
        title = 'frame : %s [%s]' % (state, id)
        colors = {'detect': (0, 255, 0), 'trace': (255, 255, 0), 'invalid': (150, 150, 150), 'title_bg': (0, 0, 0)}
        for (x, y, w, h) in boxes:
            if tracer.box_in_margin((x, y, w, h)):
                cv2.rectangle(image, (x, y), (x + w, y + h), colors['invalid'], 2)
            else:
                cv2.rectangle(image, (x, y), (x + w, y + h), colors[state], 2)
        cv2.rectangle(image, (10, 20), (250, 50), colors['title_bg'], -1)
        cv2.putText(image, title, (30, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.6, colors[state], 2)
        cv2.imshow("result", image)
        cv2.waitKey(1)

    def trace_detect_video(self, video_path, trace_rate=1):
        cap = cv2.VideoCapture(video_path)
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        start_frame = 0
        invalid_margin = (0, 0, 0, 100)
        mark = self._get_area_invalid_mark((h, w, 3), invalid_margin)
        tracer = Tracer.TracerManager((h, w), Tracer.MeanShiftTracer, invalid_margin, trace_rate + 5)
        warm = False
        while True:
            ret, image = cap.read()
            start_frame += 1
            if not ret:
                return
            result = image.copy()
            if not warm or start_frame % (trace_rate + 1) == 0:
                warm = True
                cars = self._detect_cars(image)
                self._show_trace_state(result, start_frame, tracer, 'detect', cars, mark)
                trace_obj = tracer.update_tracer(image, start_frame, cars)
                self._show_trace_object(trace_obj)
            else:
                cars = tracer.trace(image, start_frame)
                self._show_trace_state(result, start_frame, tracer, 'trace', cars, mark)

if __name__ == "__main__":
    car_cascade_lbp_21 = './train/cascade_lbp_21/cascade.xml'
    video_path = "./test.mp4"
    detect = car_detector(car_cascade_lbp_21)
    detect.trace_detect_video(video_path)
Advantages:
The algorithm is computationally light; with the target region known, it can easily run in real time.
The kernel-weighted histogram model makes it insensitive to edge occlusion, target rotation, deformation, and background motion.
Disadvantages:
The window size stays fixed during tracking, so the tracked region does not grow or shrink as the target does;
Tracking degrades when the target moves quickly;
The histogram is a weak descriptor of the target's color features and lacks spatial information.