Python Opencv-contrib Camshift kalman卡尔曼滤波 KCF算法 CSRT算法 目标跟踪实现

本文为原创文章,转载请注明出处。

本次课题实现目标跟踪一共用到了三个算法,分别是Camshift、Kalman、CSRT,基于Python语言的Tkinter模块实现GUI与接口设计,项目一共包含三个文件:

main.py:

# coding:utf-8
# 主模块


import Tkinter
import tkFileDialog
import cv2
import time
from PIL import ImageTk
# 导入自定义模块
import track
import utils


# Create the main window, 800x480
root = Tkinter.Tk()
root.title("基于视频的实时行人追踪")
root.geometry("800x480")

# Background: a canvas the size of the window with an image anchored at its top-left corner
canvas = Tkinter.Canvas(root, width=800, height=480, highlightthickness=0, borderwidth=0)
background_image = ImageTk.PhotoImage(file="background.jpg")  # background image, path relative to the project directory
canvas.create_image(0, 0, anchor="nw", image=background_image)
canvas.pack()

# Title label, embedded in the canvas at (400, 100)
label_a = Tkinter.Label(root, text="基于视频的实时行人追踪", font=("KaiTi", 20), height=2)
label_a.pack()
canvas.create_window(400, 100, height=25, window=label_a)

# Text variable holding the currently displayed path / status message
show_path = Tkinter.StringVar()
show_path.set("请选择一个文件夹")

# Label bound to show_path, embedded in the canvas at (400, 150)
label_b = Tkinter.Label(root, textvariable=show_path, font=("Times New Roman", 15), height=2)
label_b.pack()
canvas.create_window(400, 150, window=label_b)

# Shared store for the target box selected by the user
ROI = utils.ROI()
# Shared store for the paths of the selected sequence
path = utils.Path()


# Choose the image sequence
def hit_button_a():
    """Ask for the sequence folder and show the result in the path label.

    The chosen folder is handed to ``path.init``. When ``path.img_path`` is
    still empty afterwards, an error message is shown; otherwise the image
    directory and the number of frames are shown.
    """
    chosen_dir = tkFileDialog.askdirectory(title="Select Folder")
    path.init(chosen_dir)

    if path.img_path == "":
        show_path.set("路径错误!")
        return

    # img_path ends with "/", which is dropped for display.
    shown_dir = str(path.img_path)[:-1]
    show_path.set("文件路径:" + shown_dir + "\n序列总数:" + str(path.sum))


# "Choose sequence" button: runs hit_button_a, embedded in the canvas at (400, 200)
button_a = Tkinter.Button(root, text="选择序列", font=("KaiTi", 15), height=2, command=hit_button_a)
button_a.pack()
canvas.create_window(400, 200, height=20, window=button_a)


# Mark the target (ROI)
def hit_button_b():
    """Let the user draw the target box on the first frame of the sequence.

    The selection is stored in the shared ``ROI`` object. Nothing is stored
    when no sequence has been chosen, when the first frame cannot be read, or
    when the selection is empty (zero width or height); in the first two cases
    the path label shows the error message.
    """
    # Without a chosen sequence there is no first frame to show.
    if not path.pics_list:
        show_path.set("路径错误!")
        return

    first_image = cv2.imread(path.pics_list[0])
    # cv2.imread signals an unreadable file by returning None.
    if first_image is None:
        show_path.set("路径错误!")
        return

    try:
        selection = cv2.selectROI(windowName="ROI", img=first_image, showCrosshair=True, fromCenter=False)
    finally:
        # Close the selection window even if selectROI fails.
        cv2.destroyAllWindows()

    # A zero-size box means the selection was cancelled: keep the previous ROI.
    if selection[2] > 0 and selection[3] > 0:
        ROI.init_window(selection)


# "Mark target" button: runs hit_button_b, embedded in the canvas at (400, 250).
# The option is spelled out as ``height`` to match button_a.
button_b = Tkinter.Button(root, text="标记目标", font=("KaiTi", 15), height=2, command=hit_button_b)
button_b.pack()
canvas.create_window(400, 250, height=20, window=button_b)


# Target tracking
def hit_button_c():
    """Track the marked target through the sequence and display every frame.

    The first readable frame only initialises the trackers with
    ``ROI.window``. For every later frame the tracker is updated, the offset of
    its box relative to the initial ROI is passed through two Kalman filters
    (one for position, one for size), the pixel error (APE) and overlap rate
    (AOR) against the ground truth are computed, and the annotated frame is
    shown in the "Track" window. Space pauses until the next key press, ESC
    stops the tracking.
    """
    global camshift, kcf

    # Both a sequence and a target box are required before tracking can start.
    if path.sum == 0 or len(ROI.window) != 4:
        print("请先选择序列并标记目标!")
        return

    index = utils.index(path.groundtruth_path)  # load the ground truth
    firstframe = True
    kalman_xy = track.KalmanFilter()
    kalman_size = track.KalmanFilter()
    bbox = [0, 0, 0, 0]

    for i in range(0, path.sum):
        start = time.time()  # start timing this frame
        frame = cv2.imread(path.pics_list[i])
        if frame is None:
            # Unreadable file in the image folder: skip it.
            continue
        if firstframe:
            # camshift is only initialised here; the displayed result is
            # driven by kcf alone.
            camshift = track.Camshift(frame, ROI.window)
            kcf = track.KCFtracker(frame, ROI.window)
            firstframe = False
            continue

        ok = kcf.update(frame)
        if not ok:
            # Target lost: restart the tracker from the last filtered box.
            # Before any box has been produced bbox is still all zeros, so
            # fall back to the initial ROI instead of a zero-size box.
            if bbox[2] > 0 and bbox[3] > 0:
                mes = (bbox[0], bbox[1], bbox[2], bbox[3])
            else:
                mes = tuple(ROI.window)
            print(mes)
            kcf.tracker.init(frame, mes)
            kcf.update(frame)
        end = time.time()  # stop timing
        seconds = end - start  # processing time of this frame
        groundtruth = index.groundtruth(i)  # ground-truth box of frame i
        window = kcf.window

        # Kalman-filter the change of the box relative to the initial ROI.
        dx = window[0] - ROI.window[0]
        dy = window[1] - ROI.window[1]
        dw = window[2] - ROI.window[2]
        dh = window[3] - ROI.window[3]
        xy = kalman_xy.predict(dx, dy)
        size = kalman_size.predict(dw, dh)

        bbox[0] = int(ROI.window[0] + xy[0])
        bbox[1] = int(ROI.window[1] + xy[1])
        bbox[2] = int(ROI.window[2] + size[0])
        bbox[3] = int(ROI.window[3] + size[1])

        ape = index.APE(bbox, groundtruth)  # pixel error
        aor = index.AOR(bbox, groundtruth)  # overlap rate
        frame = utils.display(seconds, frame, bbox, ape, aor, groundtruth, truth=False)  # draw the box
        cv2.imshow("Track", frame)
        t = cv2.waitKey(20) & 0xff
        # Space pauses until any key is pressed.
        if t == ord(" "):
            cv2.waitKey(0)
        # ESC leaves the loop; the window is closed once below.
        if t == 27:
            break
    cv2.destroyAllWindows()
    print("跟踪结束!\n")


# "Start tracking" button: runs hit_button_c, embedded in the canvas at (400, 300).
# The option is spelled out as ``height`` to match button_a.
button_c = Tkinter.Button(root, text="开始追踪", font=("KaiTi", 15), height=2, command=hit_button_c)
button_c.pack()
canvas.create_window(400, 300, height=20, window=button_c)


# Enter the Tk event loop.
root.mainloop()

自定义跟踪器模块track.py:

# coding:utf-8
# 追踪器模块


import cv2
import numpy as np


# Centre of four points
def center(points):
    """Return the mean of the first four points as a float32 array ``[x, y]``.

    ``points`` is indexed as ``points[k][0]`` (x) and ``points[k][1]`` (y) for
    k = 0..3.
    """
    corner_ids = range(4)
    mean_x = sum(points[k][0] for k in corner_ids) / 4
    mean_y = sum(points[k][1] for k in corner_ids) / 4
    return np.array([np.float32(mean_x), np.float32(mean_y)], dtype=np.float32)


class Camshift:
    """CamShift tracker driven by the hue histogram of the initial target box."""

    def __init__(self, frame, ROI):
        """Build the hue histogram of the target.

        Args:
            frame: BGR image that contains the target.
            ROI: ``(x, y, w, h)`` box of the target inside ``frame``.
        """
        x, y, w, h = ROI
        self.window = ROI  # current tracking window, (x, y, w, h)
        roi = frame[y:y + h, x:x + w]  # crop the target
        hsv_roi = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
        # Keep only pixels with saturation >= 60 and value >= 32.
        mask = cv2.inRange(hsv_roi, np.array((0., 60., 32.)), np.array((180., 255., 255.)))
        # 180-bin histogram of the hue channel, scaled to 0..255.
        self.hist = cv2.calcHist([hsv_roi], [0], mask, [180], [0, 180])
        cv2.normalize(self.hist, self.hist, 0, 255, cv2.NORM_MINMAX)

    def update(self, frame):
        """Move ``self.window`` to the target position in ``frame`` (BGR image)."""
        # Criteria tuple is (type, max iterations, epsilon): stop after at
        # most 10 iterations or when the window moves by less than 1.
        term_crit = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 1)
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        dst = cv2.calcBackProject([hsv], [0], self.hist, [0, 180], 1)  # back projection of the hue histogram
        x, y, w, h = self.window
        ret, (x, y, w, h) = cv2.CamShift(dst, (x, y, w, h), term_crit)
        self.window = (x, y, w, h)


class MILtracker:
    """Wrapper around the tracker returned by ``cv2.TrackerMIL_create``."""

    def __init__(self, frame, ROI):
        """Start tracking the box ``ROI`` (x, y, w, h) in ``frame``."""
        self.window = ROI  # last box reported by the tracker
        self.tracker = cv2.TrackerMIL_create()
        self.tracker.init(frame, self.window)

    def update(self, frame):
        """Update ``self.window`` from ``frame`` and return the tracker's success flag.

        The flag is returned so callers can detect a lost target, as with
        ``KCFtracker.update``.
        """
        ok, self.window = self.tracker.update(frame)
        return ok


class KCFtracker:
    """Tracker wrapper used by the main module.

    NOTE: despite its name this class runs a CSRT tracker. The original code
    created a KCF tracker and immediately replaced it with a CSRT one; the
    unused KCF instance is no longer created.
    """

    def __init__(self, frame, ROI):
        """Start tracking the box ``ROI`` (x, y, w, h) in ``frame``."""
        self.window = ROI  # last box reported by the tracker
        self.tracker = cv2.TrackerCSRT_create()
        self.tracker.init(frame, self.window)

    def update(self, frame):
        """Update ``self.window`` from ``frame`` and return the tracker's success flag."""
        ok, self.window = self.tracker.update(frame)
        return ok

class CSRTtracker:
    """Wrapper around the tracker returned by ``cv2.TrackerCSRT_create``."""

    def __init__(self, frame, ROI):
        """Start tracking the box ``ROI`` (x, y, w, h) in ``frame``."""
        self.window = ROI  # last box reported by the tracker
        self.tracker = cv2.TrackerCSRT_create()
        self.tracker.init(frame, self.window)

    def update(self, frame):
        """Update ``self.window`` from ``frame`` and return the tracker's success flag.

        The flag is returned so callers can detect a lost target, as with
        ``KCFtracker.update``.
        """
        ok, self.window = self.tracker.update(frame)
        return ok


class KalmanFilter:
    """Kalman filter with a 4-component state and a 2-component measurement."""

    def __init__(self):
        """Configure the underlying ``cv2.KalmanFilter(4, 2)``."""
        kf = cv2.KalmanFilter(4, 2)
        # The measurement is the first two state components.
        kf.measurementMatrix = np.eye(2, 4, dtype=np.float32)
        # Each step adds state[2] to state[0] and state[3] to state[1].
        kf.transitionMatrix = np.eye(4, dtype=np.float32) + np.eye(4, k=2, dtype=np.float32)
        kf.processNoiseCov = np.eye(4, dtype=np.float32) * 0.003
        kf.measurementNoiseCov = np.eye(2, dtype=np.float32) * 0.01
        self.kalman = kf

    def predict(self, x, y):
        """Correct the filter with the measurement (x, y) and return its next prediction."""
        measurement = np.array([[np.float32(x)], [np.float32(y)]])
        self.kalman.correct(measurement)
        return self.kalman.predict()

自定义的工具模块utils.py:

# coding:utf-8
# 工具模块


import cv2
import os
import re


# Sequence path store
class Path:
    """Paths of one image sequence and of its ground-truth file."""

    def __init__(self):
        self._reset()

    def _reset(self):
        """Return to the empty state (no sequence selected)."""
        self.img_path = ""          # "<folder>/img/"
        self.groundtruth_path = ""  # "<folder>/groundtruth.txt"
        self.inpics_list = []       # sorted file names inside img_path
        self.pics_list = []         # img_path + file name, same order
        self.sum = 0                # number of entries in pics_list

    def init(self, path):
        """Load the sequence stored in ``path``.

        ``path`` is the folder that contains the ``img`` sub-folder and
        ``groundtruth.txt``. An empty ``path`` (cancelled dialog) leaves the
        current state untouched. When ``path`` has no readable ``img``
        sub-folder the object is reset to the empty state, so callers can
        detect the failure through ``img_path == ""``.
        """
        if not path:
            return

        img_path = path + "/img/"
        if not os.path.isdir(img_path):
            self._reset()
            return
        try:
            names = sorted(os.listdir(img_path))
        except OSError:
            # Folder exists but cannot be listed (e.g. permissions).
            self._reset()
            return

        # Assign everything only after the listing succeeded, so the fields
        # never describe a half-loaded sequence.
        self.img_path = img_path
        self.groundtruth_path = path + "/groundtruth.txt"
        self.inpics_list = names
        self.sum = len(names)
        self.pics_list = [img_path + name for name in names]


# Target box store
class ROI:
    """Holds one box both as separate fields and as an ``(x, y, w, h)`` tuple."""

    def __init__(self):
        """Start with an all-zero box and an empty ``window``."""
        self.x = self.y = self.width = self.height = 0
        self.window = []

    def init(self, x, y, width, height):
        """Store the box given as four separate values."""
        self.x, self.y = x, y
        self.width, self.height = width, height
        self.window = (x, y, width, height)

    def init_window(self, window):
        """Store the box given as a sequence; only its first four items are used."""
        x, y, width, height = window[0], window[1], window[2], window[3]
        self.init(x, y, width, height)


# Evaluation metrics
class index:
    """Tracking metrics measured against a ground-truth file.

    The ground-truth file holds one box per line as four numbers
    ``x y w h`` separated by commas and/or whitespace.
    """

    def __init__(self, path):
        """Read all ground-truth lines from the file at ``path``."""
        self.fps = []
        self.ape = []
        self.aor = []
        self.n = []
        # Close the file as soon as the lines are loaded.
        with open(path) as truth_file:
            self.lines = truth_file.readlines()

    def groundtruth(self, i):
        """Return the ground-truth box of line ``i`` as ``[x, y, w, h]`` (ints).

        Raises:
            IndexError: ``i`` is outside the loaded lines.
            ValueError: the line has fewer than four numeric fields.
        """
        # Split on any run of commas/whitespace so "1, 2, 3, 4" also parses.
        fields = [tok for tok in re.split(r"[,\s]+", self.lines[i].strip()) if tok]
        if len(fields) < 4:
            raise ValueError("ground-truth line %d has fewer than 4 fields" % i)
        # int(float(...)) also accepts values written as "12.0".
        return [int(float(tok)) for tok in fields[:4]]

    @staticmethod
    def APE(window, bbox):
        """Return the distance between the two box centres, rounded to 2 decimals.

        Both arguments are ``(x, y, w, h)`` boxes. The centres are computed in
        floating point (``/ 2.0``) so the result does not depend on integer
        division, and the distance is the square root of the summed squares.
        """
        x1, y1, w1, h1 = window
        x2, y2, w2, h2 = bbox
        dx = (x1 + w1 / 2.0) - (x2 + w2 / 2.0)
        dy = (y1 + h1 / 2.0) - (y2 + h2 / 2.0)
        return round((dx * dx + dy * dy) ** 0.5, 2)

    @staticmethod
    def AOR(window, bbox):
        """Return intersection-over-union of two boxes in percent, rounded to 2 decimals.

        Both arguments are ``(x, y, w, h)`` boxes. Disjoint boxes give 0.0;
        two empty boxes (zero union) also give 0.0.
        """
        x1, y1, w1, h1 = window
        x2, y2, w2, h2 = bbox
        # Clamp at 0: without it two negative extents multiply to a positive
        # "intersection" for boxes that do not overlap at all.
        col = max(0, min(x1 + w1, x2 + w2) - max(x1, x2))
        row = max(0, min(y1 + h1, y2 + h2) - max(y1, y2))
        intersection = col * row
        union = w1 * h1 + w2 * h2 - intersection
        if union <= 0:
            return 0.0
        return round(intersection * 1.0 / union * 100, 2)

    def draw(self, fps, ape, aor, number):
        """Append one sample of each metric to the history lists.

        Despite its name this method only records the values; it draws nothing.
        """
        self.fps.append(fps)
        self.ape.append(ape)
        self.aor.append(aor)
        self.n.append(number)


# Tracking overlay
def display(seconds, img, window, ape, aor, groundtruth, truth=False):
    """Draw the tracking box, its centre and the metric texts onto ``img``.

    Args:
        seconds: processing time of the frame; FPS is shown as ``1 / seconds``
            (0 when ``seconds`` is not positive).
        img: image to draw on.
        window: tracked box ``(x, y, w, h)``; values are truncated to int.
        ape: pixel error to display.
        aor: overlap rate to display, in percent.
        groundtruth: ground-truth box ``(x, y, w, h)``; drawn only if ``truth``.
        truth: whether to draw the ground-truth box as well.

    Returns:
        The annotated image.
    """
    x, y, w, h = [int(v) for v in window]
    # Tracking box
    img = cv2.rectangle(img, (x, y), (x + w, y + h), (255, 0, 0), 2)
    if truth:
        a, b, c, d = groundtruth
        img = cv2.rectangle(img, (a, b), (a + c, b + d), (0, 255, 0), 2)
    # Centre point; floor division keeps the coordinates integral, which the
    # drawing call needs, regardless of the interpreter's "/" semantics.
    xc = x + w // 2
    yc = y + h // 2
    cv2.circle(img, (xc, yc), 3, (255, 0, 0), -1)

    # A frame processed within the timer resolution reports 0 seconds.
    fps = 1.0 / seconds if seconds > 0 else 0

    font = cv2.FONT_HERSHEY_COMPLEX_SMALL
    size = 1
    # (text, baseline y, colour) for each overlay line
    overlay = [
        ('X=' + str(xc), 20, (0, 0, 255)),
        ('Y=' + str(yc), 50, (0, 0, 255)),
        ('FPS = ' + str(int(fps)), 80, (0, 255, 0)),
        ('APE = ' + str(ape) + 'pixels', 110, (0, 255, 255)),
        ('AOR = ' + str(aor) + '%', 140, (255, 0, 255)),
    ]
    for caption, baseline, colour in overlay:
        cv2.putText(img, caption, (10, baseline), font, size, colour, 1, cv2.LINE_AA)
    return img

def dis(window, img):
    """Draw the box ``window`` (x, y, w, h, truncated to int) on ``img`` and return the image."""
    left, top, box_w, box_h = list(map(int, window))
    top_left = (left, top)
    bottom_right = (left + box_w, top + box_h)
    return cv2.rectangle(img, top_left, bottom_right, (255, 0, 0), 2)

注:

1.在项目目录下保存一张GUI界面的背景图像background.jpg。

2.选择样本序列时,所选文件夹内需包含:子文件夹img(按0001.jpg、0002.jpg…的顺序保存全部序列帧)以及真值文件groundtruth.txt。

3.务必使用较低版本的Opencv-contrib(具体原因未知),否则可能无法使用CSRT跟踪器。

TBD.

你可能感兴趣的:(Python Opencv-contrib Camshift kalman卡尔曼滤波 KCF算法 CSRT算法 目标跟踪实现)