视频教学:http://v.youku.com/v_show/id_XNzQwMTMwOTQ0.html?spm=a2h0j.11185381.listitem_page1.5!3~A
卡尔曼滤波,比如二维图像,对于两个维度的数据都是有高斯噪声的,而且二维数据是一个线性的关系(曲线也行,协方差是一个传递过程,可以看成许多段线性的),就能使变的平滑。总之卡尔曼滤波是针对多维度的每个维度都有高斯噪声的模型的。
最近在研究语音增强算法,这两天正在看卡尔曼滤波,看到一个关于卡尔曼理论很好的帖子:How a Kalman filter works, in pictures,基本上把卡尔曼滤波的核心思想讲明白了,而且通俗易懂,特此推荐,本博客就不介绍公式了,只谈一下自己对卡尔曼滤波思想的理解,如果要看公式推导,建议直接看上述帖子。
卡尔曼滤波运用于具有不确定性的动态系统状态估计,该系统一般具有两个状态,一个是通过状态转移方程得到的预测状态,另一个是通过传感器得到的观察状态,卡尔曼滤波前提是假设这两个状态都符合高斯分布,并且当前状态只与上一时刻状态有关,与历史状态无关。然后组合这两个高斯分布,得到我们的最优估计,组合方式就是两个高斯分布相乘,得到一个新的高斯分布,新的高斯分布均值即是我们的最有估计。
根据上述思想,我们首先有预测部分,预测部分主要是根据前一时刻的状态,通过状态转移矩阵,实现对当前时刻的状态的估计,同时也会利用状态转移矩阵更新对协方差的更新,这样就得到了当前状态估计。当前估计状态通过传感器的映射矩阵映射,使当前估计状态映射到观察空间,也就是相当于预测的观察值,同时也利用映射矩阵更新协方差矩阵。
对于实际观察状态,符合高斯分布,均值为观察值,方差为传感器的不确定度。
得到两个状态后,利用高斯相乘计算出来的卡尔曼增益,进行状态的最有估计,大致思想原理是这样的,具体公式推导详见:How a Kalman filter works, in pictures。
绿色为测量到的鼠标坐标(位置)
红色为卡尔曼滤波器预测的鼠标坐标(位置)
import cv2
import numpy as np
import random
# 创建一个空帧,定义(700, 700, 3)画图区域
frame = np.zeros((700, 700, 3), np.uint8)
# 初始化测量坐标和鼠标运动预测的数组
last_measurement = current_measurement = np.array((2, 1), np.float32)
last_prediction = current_prediction = np.zeros((2, 1), np.float32)
# 定义鼠标回调函数,用来绘制跟踪结果
def mousemove(event, x, y, s, p):
global frame, current_measurement, measurements, last_measurement, current_prediction, last_prediction
x = x+random.randint(0, 100)
y = y + random.randint(0, 100)
last_prediction = current_prediction # 把当前预测存储为上一次预测
last_measurement = current_measurement # 把当前测量存储为上一次测量
current_measurement = np.array([[np.float32(x)], [np.float32(y)]]) # 当前测量
kalman.correct(current_measurement) # 用当前测量来校正卡尔曼滤波器
current_prediction = kalman.predict() # 计算卡尔曼预测值,作为当前预测
lmx, lmy = last_measurement[0], last_measurement[1] # 上一次测量坐标
cmx, cmy = current_measurement[0], current_measurement[1] # 当前测量坐标
lpx, lpy = last_prediction[0], last_prediction[1] # 上一次预测坐标
cpx, cpy = current_prediction[0], current_prediction[1] # 当前预测坐标
# 绘制从上一次测量到当前测量以及从上一次预测到当前预测的两条线
cv2.line(frame, (lmx, lmy), (cmx, cmy), (255, 0, 0)) # 蓝色线为测量值
cv2.line(frame, (lpx, lpy), (cpx, cpy), (255, 0, 255)) # 粉色线为预测值
# 窗口初始化
cv2.namedWindow("kalman_tracker")
# opencv采用setMouseCallback函数处理鼠标事件,具体事件必须由回调(事件)函数的第一个参数来处理,该参数确定触发事件的类型(点击、移动等)
cv2.setMouseCallback("kalman_tracker", mousemove)
kalman = cv2.KalmanFilter(4, 2) # 4:状态数,包括(x,y,dx,dy)坐标及速度(每次移动的距离);2:观测量,(能看到的是坐标值 接受的参数,需要的书籍的维度)
kalman.measurementMatrix = np.array([[1, 0, 0, 0], [0, 1, 0, 0]], np.float32) # 系统测量矩阵 (需要得到的数据,与上面的2对应)
kalman.transitionMatrix = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], np.float32) # 状态转移矩阵 (利用t时刻的协方差推测t+1时刻的协方差)
kalman.processNoiseCov = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], np.float32)*0.001
# 模型系统的噪声 (kalman.processNoiseCov为模型系统的噪声,噪声越大,预测结果越不稳定,越容易接近模型系统预测值,且单步变化越大,相反,若噪声小,则预测结果与上个计算结果相差不大。
while True:
cv2.imshow("kalman_tracker", frame)
key = cv2.waitKey(1) & 0xFF
if key == ord('q'):
break
cv2.destroyAllWindows()
用上面的方法就行,kalman.processNoiseCov表示的是恢复测量值的高斯分布,越大表示震动越大,越小越平稳。
import cv2
import numpy as np
import matplotlib.pyplot as plt
pos = np.array([
[10, 50],
[12, 49],
[11, 52],
[13, 52.2],
[12.9, 50]], np.float32)
'''
它有3个输入参数,dynam_params:状态空间的维数,这里为2;measure_param:测量值的维数,这里也为2; control_params:控制向量的维数,默认为0。由于这里该模型中并没有控制变量,因此也为0。
'''
kalman = cv2.KalmanFilter(2,2)
kalman.measurementMatrix = np.array([[1,0],[0,1]],np.float32)
kalman.transitionMatrix = np.array([[1,0],[0,1]], np.float32)
kalman.processNoiseCov = np.array([[1,0],[0,1]], np.float32) * 1e-3
kalman.measurementNoiseCov = np.array([[1,0],[0,1]], np.float32) * 0.01
'''
kalman.measurementNoiseCov为测量系统的协方差矩阵,方差越小,预测结果越接近测量值,
协方差矩阵的变化量,相信下一秒的协方差的变化也是很小的;定义状态转移协方差矩阵,这里我们把协方差设置的很小,因为觉得状态转移矩阵准确度高# ,
kalman.processNoiseCov为模型系统的噪声,噪声越大,预测结果越不稳定,越容易接近模型系统预测值,且单步变化越大,相反,若噪声小,则预测结果与上个计算结果相差不大。
'''
kalman.statePre = np.array([[6],[6]],np.float32)
for i in range(len(pos)):
mes = np.reshape(pos[i,:],(2,1))
x = kalman.correct(mes)
y = kalman.predict()
print (kalman.statePost[0],kalman.statePost[1])
print (kalman.statePre[0],kalman.statePre[1])
print ('measurement:\t',mes[0],mes[1])
print ('correct:\t',x[0],x[1])
print ('predict:\t',y[0],y[1])
print ('='*30)