机器学习-Mean shift算法详解和实现

 http://blog.csdn.net/jinshengtao/article/details/30258833

 这次将介绍基于MeanShift的目标跟踪算法,首先谈谈简介,然后给出算法实现流程,最后实现了一个单目标跟踪的MeanShift算法【matlab/c两个版本】

      csdn贴公式比较烦,原谅我直接截图了…

 

一、简介

     首先扯扯无参密度估计理论,无参密度估计也叫做非参数估计,属于数理统计的一个分支,和参数密度估计共同构成了概率密度估计方法。参数密度估计方法要求特征空间服从一个已知的概率密度函数,在实际的应用中这个条件很难达到。而无参数密度估计方法对先验知识要求最少,完全依靠训练数据进行估计,并且可以用于任意形状的密度估计。所以依靠无参密度估计方法,即不事先规定概率密度函数的结构形式,在某一连续点处的密度函数值可由该点邻域中的若干样本点估计得出。常用的无参密度估计方法有:直方图法、最近邻域法和核密度估计法。

     MeanShift算法正是属于核密度估计法,它不需要任何先验知识而完全依靠特征空间中样本点的计算其密度函数值。对于一组采样数据,直方图法通常把数据的值域分成若干相等的区间,数据按区间分成若干组,每组数据的个数与总参数个数的比率就是每个单元的概率值;核密度估计法的原理相似于直方图法,只是多了一个用于平滑数据的核函数。采用核函数估计法,在采样充分的情况下,能够渐进地收敛于任意的密度函数,即可以对服从任何分布的数据进行密度估计。

     然后谈谈MeanShift的基本思想及物理含义:

机器学习-Mean shift算法详解和实现_第1张图片

机器学习-Mean shift算法详解和实现_第2张图片

机器学习-Mean shift算法详解和实现_第3张图片

    此外,从公式1中可以看到,只要是落入Sh的采样点,无论其离中心x的远近,对最终的Mh(x)计算的贡献是一样的。然而在现实跟踪过程中,当跟踪目标出现遮挡等影响时,由于外层的像素值容易受遮挡或背景的影响,所以目标模型中心附近的像素比靠外的像素更可靠。因此,对于所有采样点,每个样本点的重要性应该是不同的,离中心点越远,其权值应该越小。故引入核函数和权重系数来提高跟踪算法的鲁棒性并增加搜索跟踪能力。

      接下来,谈谈核函数:

机器学习-Mean shift算法详解和实现_第4张图片

机器学习-Mean shift算法详解和实现_第5张图片

    核函数也叫窗口函数,在核估计中起到平滑的作用。常用的核函数有:Uniform,Epannechnikov,Gaussian等。本文算法只用到了Epannechnikov,它数序定义如下:

机器学习-Mean shift算法详解和实现_第6张图片

二、基于MeanShift的目标跟踪算法

     基于均值漂移的目标跟踪算法通过分别计算目标区域和候选区域内像素的特征值概率得到关于目标模型和候选模型的描述,然后利用相似函数度量初始帧目标模型和当前帧的候选模版的相似性,选择使相似函数最大的候选模型并得到关于目标模型的Meanshift向量,这个向量正是目标由初始位置向正确位置移动的向量。由于均值漂移算法的快速收敛性,通过不断迭代计算Meanshift向量,算法最终将收敛到目标的真实位置,达到跟踪的目的。

     下面通过图示直观的说明MeanShift跟踪算法的基本原理。如下图所示:目标跟踪开始于数据点xi0(空心圆点xi0,xi1,…,xiN表示的是中心点,上标表示的是的迭代次数,周围的黑色圆点表示不断移动中的窗口样本点,虚线圆圈代表的是密度估计窗口的大小)。箭头表示样本点相对于核函数中心点的漂移向量,平均的漂移向量会指向样本点最密集的方向,也就是梯度方向。因为 Meanshift 算法是收敛的,因此在当前帧中通过反复迭代搜索特征空间中样本点最密集的区域,搜索点沿着样本点密度增加的方向“漂移”到局部密度极大点点xiN,也就是被认为的目标位置,从而达到跟踪的目的,MeanShift 跟踪过程结束。

机器学习-Mean shift算法详解和实现_第7张图片

 

 

机器学习-Mean shift算法详解和实现_第8张图片

机器学习-Mean shift算法详解和实现_第9张图片

机器学习-Mean shift算法详解和实现_第10张图片

机器学习-Mean shift算法详解和实现_第11张图片

机器学习-Mean shift算法详解和实现_第12张图片

运动目标的实现过程【具体算法】:

机器学习-Mean shift算法详解和实现_第13张图片

 

三、代码实现

说明:

1.       RGB颜色空间刨分,采用16*16*16的直方图

2.       目标模型和候选模型的概率密度计算公式参照上文

3.       opencv版本运行:按P停止,截取目标,再按P,进行单目标跟踪

4.       Matlab版本,将视频改为图片序列,第一帧停止,手工标定目标,双击目标区域,进行单目标跟踪。

 

matlab版本:

 

[plain]  view plain  copy
 
  1. function [] = select()  
  2. close all;  
  3. clear all;  
  4. %%%%%%%%%%%%%%%%%%根据一幅目标全可见的图像圈定跟踪目标%%%%%%%%%%%%%%%%%%%%%%%  
  5. I=imread('result72.jpg');  
  6. figure(1);  
  7. imshow(I);  
  8.   
  9.   
  10. [temp,rect]=imcrop(I);  
  11. [a,b,c]=size(temp);         %a:row,b:col  
  12.   
  13.   
  14. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%计算目标图像的权值矩阵%%%%%%%%%%%%%%%%%%%%%%%  
  15. y(1)=a/2;  
  16. y(2)=b/2;  
  17. tic_x=rect(1)+rect(3)/2;  
  18. tic_y=rect(2)+rect(4)/2;  
  19. m_wei=zeros(a,b);%权值矩阵  
  20. h=y(1)^2+y(2)^2 ;%带宽  
  21.   
  22.   
  23. for i=1:a  
  24.     for j=1:b  
  25.         dist=(i-y(1))^2+(j-y(2))^2;  
  26.         m_wei(i,j)=1-dist/h; %epanechnikov profile  
  27.     end  
  28. end  
  29. C=1/sum(sum(m_wei));%归一化系数  
  30.   
  31.   
  32. %计算目标权值直方图qu  
  33. %hist1=C*wei_hist(temp,m_wei,a,b);%target model  
  34. hist1=zeros(1,4096);  
  35. for i=1:a  
  36.     for j=1:b  
  37.         %rgb颜色空间量化为16*16*16 bins  
  38.         q_r=fix(double(temp(i,j,1))/16);  %fix为趋近0取整函数  
  39.         q_g=fix(double(temp(i,j,2))/16);  
  40.         q_b=fix(double(temp(i,j,3))/16);  
  41.         q_temp=q_r*256+q_g*16+q_b;            %设置每个像素点红色、绿色、蓝色分量所占比重  
  42.         hist1(q_temp+1)= hist1(q_temp+1)+m_wei(i,j);    %计算直方图统计中每个像素点占的权重  
  43.     end  
  44. end  
  45. hist1=hist1*C;  
  46. rect(3)=ceil(rect(3));  
  47. rect(4)=ceil(rect(4));  
  48.   
  49.   
  50.   
  51.   
  52. %%%%%%%%%%%%%%%%%%%%%%%%%读取序列图像  
  53. myfile=dir('D:\matlab7\work\mean shift\image\*.jpg');  
  54. lengthfile=length(myfile);  
  55.   
  56.   
  57. for l=1:lengthfile  
  58.     Im=imread(myfile(l).name);  
  59.     num=0;  
  60.     Y=[2,2];  
  61.       
  62.       
  63.     %%%%%%%mean shift迭代  
  64.     while((Y(1)^2+Y(2)^2>0.5)&num<20)   %迭代条件  
  65.         num=num+1;  
  66.         temp1=imcrop(Im,rect);  
  67.         %计算侯选区域直方图  
  68.         %hist2=C*wei_hist(temp1,m_wei,a,b);%target candidates pu  
  69.         hist2=zeros(1,4096);  
  70.         for i=1:a  
  71.             for j=1:b  
  72.                 q_r=fix(double(temp1(i,j,1))/16);  
  73.                 q_g=fix(double(temp1(i,j,2))/16);  
  74.                 q_b=fix(double(temp1(i,j,3))/16);  
  75.                 q_temp1(i,j)=q_r*256+q_g*16+q_b;  
  76.                 hist2(q_temp1(i,j)+1)= hist2(q_temp1(i,j)+1)+m_wei(i,j);  
  77.             end  
  78.         end  
  79.         hist2=hist2*C;  
  80.         figure(2);  
  81.         subplot(1,2,1);  
  82.         plot(hist2);  
  83.         hold on;  
  84.           
  85.         w=zeros(1,4096);  
  86.         for i=1:4096  
  87.             if(hist2(i)~=0) %不等于  
  88.                 w(i)=sqrt(hist1(i)/hist2(i));  
  89.             else  
  90.                 w(i)=0;  
  91.             end  
  92.         end  
  93.           
  94.           
  95.           
  96.         %变量初始化  
  97.         sum_w=0;  
  98.         xw=[0,0];  
  99.         for i=1:a;  
  100.             for j=1:b  
  101.                 sum_w=sum_w+w(uint32(q_temp1(i,j))+1);  
  102.                 xw=xw+w(uint32(q_temp1(i,j))+1)*[i-y(1)-0.5,j-y(2)-0.5];  
  103.             end  
  104.         end  
  105.         Y=xw/sum_w;  
  106.         %中心点位置更新  
  107.         rect(1)=rect(1)+Y(2);  
  108.         rect(2)=rect(2)+Y(1);  
  109.     end  
  110.       
  111.       
  112.     %%%跟踪轨迹矩阵%%%  
  113.     tic_x=[tic_x;rect(1)+rect(3)/2];  
  114.     tic_y=[tic_y;rect(2)+rect(4)/2];  
  115.       
  116.     v1=rect(1);  
  117.     v2=rect(2);  
  118.     v3=rect(3);  
  119.     v4=rect(4);  
  120.     %%%显示跟踪结果%%%  
  121.     subplot(1,2,2);  
  122.     imshow(uint8(Im));  
  123.     title('目标跟踪结果及其运动轨迹');  
  124.     hold on;  
  125.     plot([v1,v1+v3],[v2,v2],[v1,v1],[v2,v2+v4],[v1,v1+v3],[v2+v4,v2+v4],[v1+v3,v1+v3],[v2,v2+v4],'LineWidth',2,'Color','r');  
  126.     plot(tic_x,tic_y,'LineWidth',2,'Color','b');  
  127.       
  128.       
  129. end  


 运行结果:

机器学习-Mean shift算法详解和实现_第14张图片

 

 

 

opencv版本:

[cpp]  view plain  copy
 
  1. #include "stdafx.h"  
  2. #include "cv.h"  
  3. #include "highgui.h"  
  4. #define  u_char unsigned char  
  5. #define  DIST 0.5  
  6. #define  NUM 20  
  7.   
  8. //全局变量  
  9. bool pause = false;  
  10. bool is_tracking = false;  
  11. CvRect drawing_box;  
  12. IplImage *current;  
  13. double *hist1, *hist2;  
  14. double *m_wei;                                                                  //权值矩阵  
  15. double C = 0.0;                                                                //归一化系数  
  16.   
  17. void init_target(double *hist1, double *m_wei, IplImage *current)  
  18. {  
  19.     IplImage *pic_hist = 0;  
  20.     int t_h, t_w, t_x, t_y;  
  21.     double h, dist;  
  22.     int i, j;  
  23.     int q_r, q_g, q_b, q_temp;  
  24.       
  25.     t_h = drawing_box.height;  
  26.     t_w = drawing_box.width;  
  27.     t_x = drawing_box.x;  
  28.     t_y = drawing_box.y;  
  29.   
  30.     h = pow(((double)t_w)/2,2) + pow(((double)t_h)/2,2);            //带宽  
  31.     pic_hist = cvCreateImage(cvSize(300,200),IPL_DEPTH_8U,3);     //生成直方图图像  
  32.   
  33.     //初始化权值矩阵和目标直方图  
  34.     for (i = 0;i < t_w*t_h;i++)  
  35.     {  
  36.         m_wei[i] = 0.0;  
  37.     }  
  38.   
  39.     for (i=0;i<4096;i++)  
  40.     {  
  41.         hist1[i] = 0.0;  
  42.     }  
  43.   
  44.     for (i = 0;i < t_h; i++)  
  45.     {  
  46.         for (j = 0;j < t_w; j++)  
  47.         {  
  48.             dist = pow(i - (double)t_h/2,2) + pow(j - (double)t_w/2,2);  
  49.             m_wei[i * t_w + j] = 1 - dist / h;   
  50.             //printf("%f\n",m_wei[i * t_w + j]);  
  51.             C += m_wei[i * t_w + j] ;  
  52.         }  
  53.     }  
  54.   
  55.     //计算目标权值直方  
  56.     for (i = t_y;i < t_y + t_h; i++)  
  57.     {  
  58.         for (j = t_x;j < t_x + t_w; j++)  
  59.         {  
  60.             //rgb颜色空间量化为16*16*16 bins  
  61.             q_r = ((u_char)current->imageData[i * current->widthStep + j * 3 + 2]) / 16;  
  62.             q_g = ((u_char)current->imageData[i * current->widthStep + j * 3 + 1]) / 16;  
  63.             q_b = ((u_char)current->imageData[i * current->widthStep + j * 3 + 0]) / 16;  
  64.             q_temp = q_r * 256 + q_g * 16 + q_b;  
  65.             hist1[q_temp] =  hist1[q_temp] +  m_wei[(i - t_y) * t_w + (j - t_x)] ;  
  66.         }  
  67.     }  
  68.   
  69.     //归一化直方图  
  70.     for (i=0;i<4096;i++)  
  71.     {  
  72.         hist1[i] = hist1[i] / C;  
  73.         //printf("%f\n",hist1[i]);  
  74.     }  
  75.   
  76.     //生成目标直方图  
  77.     double temp_max=0.0;  
  78.   
  79.     for (i = 0;i < 4096;i++)         //求直方图最大值,为了归一化  
  80.     {  
  81.         //printf("%f\n",val_hist[i]);  
  82.         if (temp_max < hist1[i])  
  83.         {  
  84.             temp_max = hist1[i];  
  85.         }  
  86.     }  
  87.     //画直方图  
  88.     CvPoint p1,p2;  
  89.     double bin_width=(double)pic_hist->width/4096;  
  90.     double bin_unith=(double)pic_hist->height/temp_max;  
  91.   
  92.     for (i = 0;i < 4096; i++)  
  93.     {  
  94.         p1.x = i * bin_width;  
  95.         p1.y = pic_hist->height;  
  96.         p2.x = (i + 1)*bin_width;  
  97.         p2.y = pic_hist->height - hist1[i] * bin_unith;  
  98.         //printf("%d,%d,%d,%d\n",p1.x,p1.y,p2.x,p2.y);  
  99.         cvRectangle(pic_hist,p1,p2,cvScalar(0,255,0),-1,8,0);  
  100.     }  
  101.     cvSaveImage("hist1.jpg",pic_hist);  
  102.     cvReleaseImage(&pic_hist);  
  103. }  
  104.   
  105. void MeanShift_Tracking(IplImage *current)  
  106. {  
  107.     int num = 0, i = 0, j = 0;  
  108.     int t_w = 0, t_h = 0, t_x = 0, t_y = 0;  
  109.     double *w = 0, *hist2 = 0;  
  110.     double sum_w = 0, x1 = 0, x2 = 0,y1 = 2.0, y2 = 2.0;  
  111.     int q_r, q_g, q_b;  
  112.     int *q_temp;  
  113.     IplImage *pic_hist = 0;  
  114.   
  115.     t_w = drawing_box.width;  
  116.     t_h = drawing_box.height;  
  117.       
  118.     pic_hist = cvCreateImage(cvSize(300,200),IPL_DEPTH_8U,3);     //生成直方图图像  
  119.     hist2 = (double *)malloc(sizeof(double)*4096);  
  120.     w = (double *)malloc(sizeof(double)*4096);  
  121.     q_temp = (int *)malloc(sizeof(int)*t_w*t_h);  
  122.   
  123.     while ((pow(y2,2) + pow(y1,2) > 0.5)&& (num < NUM))  
  124.     {  
  125.         num++;  
  126.         t_x = drawing_box.x;  
  127.         t_y = drawing_box.y;  
  128.         memset(q_temp,0,sizeof(int)*t_w*t_h);  
  129.         for (i = 0;i<4096;i++)  
  130.         {  
  131.             w[i] = 0.0;  
  132.             hist2[i] = 0.0;  
  133.         }  
  134.   
  135.         for (i = t_y;i < t_h + t_y;i++)  
  136.         {  
  137.             for (j = t_x;j < t_w + t_x;j++)  
  138.             {  
  139.                 //rgb颜色空间量化为16*16*16 bins  
  140.                 q_r = ((u_char)current->imageData[i * current->widthStep + j * 3 + 2]) / 16;  
  141.                 q_g = ((u_char)current->imageData[i * current->widthStep + j * 3 + 1]) / 16;  
  142.                 q_b = ((u_char)current->imageData[i * current->widthStep + j * 3 + 0]) / 16;  
  143.                 q_temp[(i - t_y) *t_w + j - t_x] = q_r * 256 + q_g * 16 + q_b;  
  144.                 hist2[q_temp[(i - t_y) *t_w + j - t_x]] =  hist2[q_temp[(i - t_y) *t_w + j - t_x]] +  m_wei[(i - t_y) * t_w + j - t_x] ;  
  145.             }  
  146.         }  
  147.   
  148.         //归一化直方图  
  149.         for (i=0;i<4096;i++)  
  150.         {  
  151.             hist2[i] = hist2[i] / C;  
  152.             //printf("%f\n",hist2[i]);  
  153.         }  
  154.         //生成目标直方图  
  155.         double temp_max=0.0;  
  156.   
  157.         for (i=0;i<4096;i++)         //求直方图最大值,为了归一化  
  158.         {  
  159.             if (temp_max < hist2[i])  
  160.             {  
  161.                 temp_max = hist2[i];  
  162.             }  
  163.         }  
  164.         //画直方图  
  165.         CvPoint p1,p2;  
  166.         double bin_width=(double)pic_hist->width/(4368);  
  167.         double bin_unith=(double)pic_hist->height/temp_max;  
  168.   
  169.         for (i = 0;i < 4096; i++)  
  170.         {  
  171.             p1.x = i * bin_width;  
  172.             p1.y = pic_hist->height;  
  173.             p2.x = (i + 1)*bin_width;  
  174.             p2.y = pic_hist->height - hist2[i] * bin_unith;  
  175.             cvRectangle(pic_hist,p1,p2,cvScalar(0,255,0),-1,8,0);  
  176.         }  
  177.         cvSaveImage("hist2.jpg",pic_hist);  
  178.       
  179.         for (i = 0;i < 4096;i++)  
  180.         {  
  181.             if (hist2[i] != 0)  
  182.             {  
  183.                 w[i] = sqrt(hist1[i]/hist2[i]);  
  184.             }else  
  185.             {  
  186.                 w[i] = 0;  
  187.             }  
  188.         }  
  189.               
  190.         sum_w = 0.0;  
  191.         x1 = 0.0;  
  192.         x2 = 0.0;  
  193.   
  194.         for (i = 0;i < t_h; i++)  
  195.         {  
  196.             for (j = 0;j < t_w; j++)  
  197.             {  
  198.                 //printf("%d\n",q_temp[i * t_w + j]);  
  199.                 sum_w = sum_w + w[q_temp[i * t_w + j]];  
  200.                 x1 = x1 + w[q_temp[i * t_w + j]] * (i - t_h/2);  
  201.                 x2 = x2 + w[q_temp[i * t_w + j]] * (j - t_w/2);  
  202.             }  
  203.         }  
  204.         y1 = x1 / sum_w;  
  205.         y2 = x2 / sum_w;  
  206.           
  207.         //中心点位置更新  
  208.         drawing_box.x += y2;  
  209.         drawing_box.y += y1;  
  210.   
  211.         //printf("%d,%d\n",drawing_box.x,drawing_box.y);  
  212.     }  
  213.     free(hist2);  
  214.     free(w);  
  215.     free(q_temp);  
  216.     //显示跟踪结果  
  217.     cvRectangle(current,cvPoint(drawing_box.x,drawing_box.y),cvPoint(drawing_box.x+drawing_box.width,drawing_box.y+drawing_box.height),CV_RGB(255,0,0),2);  
  218.     cvShowImage("Meanshift",current);  
  219.     //cvSaveImage("result.jpg",current);  
  220.     cvReleaseImage(&pic_hist);  
  221. }  
  222.   
  223. void onMouse( int event, int x, int y, int flags, void *param )  
  224. {  
  225.     if (pause)  
  226.     {  
  227.         switch(event)  
  228.         {  
  229.         case CV_EVENT_LBUTTONDOWN:   
  230.             //the left up point of the rect  
  231.             drawing_box.x=x;  
  232.             drawing_box.y=y;  
  233.             break;  
  234.         case CV_EVENT_LBUTTONUP:  
  235.             //finish drawing the rect (use color green for finish)  
  236.             drawing_box.width=x-drawing_box.x;  
  237.             drawing_box.height=y-drawing_box.y;  
  238.             cvRectangle(current,cvPoint(drawing_box.x,drawing_box.y),cvPoint(drawing_box.x+drawing_box.width,drawing_box.y+drawing_box.height),CV_RGB(255,0,0),2);  
  239.             cvShowImage("Meanshift",current);  
  240.               
  241.             //目标初始化  
  242.             hist1 = (double *)malloc(sizeof(double)*16*16*16);  
  243.             m_wei =  (double *)malloc(sizeof(double)*drawing_box.height*drawing_box.width);  
  244.             init_target(hist1, m_wei, current);  
  245.             is_tracking = true;  
  246.             break;  
  247.         }  
  248.         return;  
  249.     }  
  250. }  
  251.   
  252.   
  253.   
  254. int _tmain(int argc, _TCHAR* argv[])  
  255. {  
  256.     CvCapture *capture=cvCreateFileCapture("test.avi");  
  257.     current = cvQueryFrame(capture);  
  258.     char res[20];  
  259.     int nframe = 0;  
  260.   
  261.     while (1)  
  262.     {     
  263.     /*  sprintf(res,"result%d.jpg",nframe); 
  264.         cvSaveImage(res,current); 
  265.         nframe++;*/  
  266.         if(is_tracking)  
  267.         {  
  268.             MeanShift_Tracking(current);  
  269.         }  
  270.   
  271.         int c=cvWaitKey(1);  
  272.         //暂停  
  273.         if(c == 'p')   
  274.         {  
  275.             pause = true;  
  276.             cvSetMouseCallback( "Meanshift", onMouse, 0 );  
  277.         }  
  278.         while(pause){  
  279.             if(cvWaitKey(0) == 'p')  
  280.                 pause = false;  
  281.         }  
  282.         cvShowImage("Meanshift",current);  
  283.         current = cvQueryFrame(capture); //抓取一帧  
  284.     }  
  285.   
  286.     cvNamedWindow("Meanshift",1);  
  287.     cvReleaseCapture(&capture);  
  288.     cvDestroyWindow("Meanshift");  
  289.     return 0;  
  290. }  

运行结果:

机器学习-Mean shift算法详解和实现_第15张图片 机器学习-Mean shift算法详解和实现_第16张图片

初始目标直方图:

机器学习-Mean shift算法详解和实现_第17张图片

候选目标直方图:

机器学习-Mean shift算法详解和实现_第18张图片

你可能感兴趣的:(算法,机器学习,shift,mean)