K-Means聚类算法的实现

转载地址: http://blog.csdn.net/lming_08/article/details/20778351 

K-Means算法简介

K-Means算法是一种常用的聚类算法,因其思想简单、容易实现而收到广泛的运用。其思想大概是从要聚类的样本中选取K个样本,然后遍历所有样本,对每个样本计算其与K个样本间的距离(可以为欧氏距离或余弦距离),然后将其类别归为距离最小的样本所属类别,这样的话,所有样本就都找到各自所属的类别;然后分别重新计算K个类别中样本的质心;之后返回第一步继续迭代执行,如此直到K个类别中样本的质心不再移动或移动的非常小。整个过程往往要不了几次就达到收敛。

基于PCL库对三维空间点的K-Means聚类算法的实现

在三维点云处理中我们经常要对点云进行聚类分割处理,如建筑物与地面、桌面与水杯等的分割,以便于我们可以在后续三维重建中得到更好的效果。这时比较好的聚类方法有欧式聚类和K-Means聚类。这里简要地介绍下基于PCL库对三维空间点的K-Means聚类算法的实现。

相关头文件common.h中部分内容

[cpp]  view plain  copy
 print ?
  1. //笛卡尔坐标系中三维点坐标  
  2. typedef struct st_pointxyz  
  3. {  
  4.     float x;  
  5.     float y;  
  6.     float z;  
  7. }st_pointxyz;  
  8. typedef struct st_point  
  9. {  
  10.     st_pointxyz pnt;  
  11.     int groupID;  
  12.     st_point()  
  13.     {  
  14.   
  15.     }  
  16.     st_point(st_pointxyz &p, int id)  
  17.     {  
  18.         pnt = p;  
  19.         groupID = id;  
  20.     }  
  21. }st_point;  
  22.   
  23. class KMeans  
  24. {  
  25. public:  
  26.     int m_k;  
  27.   
  28.     typedef std::vector VecPoint_t;  
  29.       
  30.     VecPoint_t mv_pntcloud;    //要聚类的点云  
  31.     std::vector m_grp_pntcloud;    //K类,每一类存储若干点  
  32.     std::vector mv_center;    //每个类的中心  
  33.   
  34.     KMeans()  
  35.     {  
  36.         m_k = 0;  
  37.     }  
  38.   
  39.     inline void SetK(int k_)  
  40.     {  
  41.         m_k = k_;  
  42.         m_grp_pntcloud.resize(m_k);  
  43.     }  
  44.     //设置输入点云  
  45.     bool SetInputCloud(PointCloud::Ptr pPntCloud);  
  46.   
  47.     //初始化最初的K个类的中心  
  48.     bool InitKCenter(st_pointxyz pc_arr[]);  
  49.   
  50.     //聚类  
  51.     bool Cluster();  
  52.   
  53.     //更新K类的中心  
  54.     bool UpdateGroupCenter(std::vector &grp_pntcloud, std::vector ¢er);  
  55.   
  56.     //计算两个点间的欧氏距离  
  57.     double DistBetweenPoints(st_pointxyz &p1, st_pointxyz &p2);  
  58.       
  59.     //是否存在中心点移动  
  60.     bool ExistCenterShift(std::vector &prev_center, std::vector &cur_center);  
  61.   
  62.     //将聚类的点分别存到各自的pcd文件中  
  63.     bool SaveFile(const char *prex_name);  
  64.     //将聚类的点分别存到各自的pcd文件中  
  65.     bool SaveFile(const char *dir_name, const char *prex_name);  
  66. };  

实现文件kmeans.cpp中内容为:

[cpp]  view plain  copy
 print ?
  1. #include "common.h"  
  2.   
  3. const float DIST_NEAR_ZERO = 0.001;  
  4.   
  5. extern char szFileName[256];  
  6.   
  7. bool KMeans::InitKCenter(st_pointxyz pnt_arr[])  
  8. {  
  9.     if (m_k == 0)  
  10.     {  
  11.         PCL_ERROR("在此之前必须要调用setK()函数\n");  
  12.         return false;  
  13.     }  
  14.   
  15.     mv_center.resize(m_k);  
  16.     for (size_t i = 0; i < m_k; ++i)  
  17.     {  
  18.         mv_center[i] = pnt_arr[i];  
  19.     }  
  20.     return true;  
  21. }  
  22.   
  23. bool KMeans::SetInputCloud(PointCloud::Ptr pPntCloud)  
  24. {  
  25.     size_t pntCount = (size_t)pPntCloud->points.size();  
  26.     //mv_pntcloud.resize(pntCount);  
  27.     for (size_t i = 0; i < pntCount; ++i)  
  28.     {  
  29.         st_point point;  
  30.         point.pnt.x = pPntCloud->points[i].x;  
  31.         point.pnt.y = pPntCloud->points[i].y;  
  32.         point.pnt.z = pPntCloud->points[i].z;  
  33.         point.groupID = 0;  
  34.   
  35.         mv_pntcloud.push_back(point);  
  36.     }  
  37.   
  38.     return true;  
  39. }  
  40.   
  41. bool KMeans::Cluster()  
  42. {  
  43.     std::vector v_center(mv_center.size());  
  44.   
  45.     do  
  46.     {  
  47.         for (size_t i = 0, pntCount = mv_pntcloud.size(); i < pntCount; ++i)  
  48.         {  
  49.             double min_dist = DBL_MAX;  
  50.             int pnt_grp = 0;  
  51.             for (size_t j = 0; j < m_k; ++j)  
  52.             {  
  53.                 double dist = DistBetweenPoints(mv_pntcloud[i].pnt, mv_center[j]);  
  54.                 if (min_dist - dist > 0.000001)  
  55.                 {  
  56.                     min_dist = dist;  
  57.                     pnt_grp = j;  
  58.                 }  
  59.             }  
  60.             m_grp_pntcloud[pnt_grp].push_back(st_point(mv_pntcloud[i].pnt, pnt_grp));  
  61.         }  
  62.   
  63.         //保存上一次迭代的中心点  
  64.         for (size_t i = 0; i < mv_center.size(); ++i)  
  65.         {  
  66.             v_center[i] = mv_center[i];  
  67.         }  
  68.   
  69.         if (!UpdateGroupCenter(m_grp_pntcloud, mv_center))  
  70.         {  
  71.             return false;  
  72.         }  
  73.         if ( !ExistCenterShift(v_center, mv_center))  
  74.         {  
  75.             break;  
  76.         }  
  77.         for (size_t i = 0; i < m_k; ++i){  
  78.             m_grp_pntcloud[i].clear();  
  79.         }  
  80.   
  81.     }while(true);  
  82.   
  83.     return true;  
  84. }  
  85.   
  86. double KMeans::DistBetweenPoints(st_pointxyz &p1, st_pointxyz &p2)  
  87. {  
  88.     double dist = 0;  
  89.     double x_diff = 0, y_diff = 0, z_diff = 0;  
  90.   
  91.     x_diff = p1.x - p2.x;  
  92.     y_diff = p1.y - p2.y;  
  93.     z_diff = p1.z - p2.z;  
  94.     dist = sqrt(x_diff * x_diff + y_diff * y_diff + z_diff * z_diff);  
  95.       
  96.     return dist;  
  97. }  
  98.   
  99. bool KMeans::UpdateGroupCenter(std::vector &grp_pntcloud, std::vector ¢er)  
  100. {  
  101.     if (center.size() != m_k)  
  102.     {  
  103.         PCL_ERROR("类别的个数不为K\n");  
  104.         return false;  
  105.     }  
  106.   
  107.     for (size_t i = 0; i < m_k; ++i)  
  108.     {  
  109.         float x = 0, y = 0, z = 0;  
  110.         size_t pnt_num_in_grp = grp_pntcloud[i].size();  
  111.   
  112.         for (size_t j = 0; j < pnt_num_in_grp; ++j)  
  113.         {             
  114.             x += grp_pntcloud[i][j].pnt.x;  
  115.             y += grp_pntcloud[i][j].pnt.y;  
  116.             z += grp_pntcloud[i][j].pnt.z;  
  117.         }  
  118.         x /= pnt_num_in_grp;  
  119.         y /= pnt_num_in_grp;  
  120.         z /= pnt_num_in_grp;  
  121.         center[i].x = x;  
  122.         center[i].y = y;  
  123.         center[i].z = z;  
  124.     }  
  125.     return true;  
  126. }  
  127.   
  128. //是否存在中心点移动  
  129. bool KMeans::ExistCenterShift(std::vector &prev_center, std::vector &cur_center)  
  130. {  
  131.     for (size_t i = 0; i < m_k; ++i)  
  132.     {  
  133.         double dist = DistBetweenPoints(prev_center[i], cur_center[i]);  
  134.         if (dist > DIST_NEAR_ZERO)  
  135.         {  
  136.             return true;  
  137.         }  
  138.     }  
  139.   
  140.     return false;  
  141. }  
  142.   
  143. //将聚类的点分别存到各自的pcd文件中  
  144. bool KMeans::SaveFile(const char *prex_name)  
  145. {  
  146.     for (size_t i = 0; i < m_k; ++i)  
  147.     {  
  148.         pcl::PointCloud::Ptr p_pnt_cloud(new pcl::PointCloud ());  
  149.   
  150.         for (size_t j = 0, grp_pnt_count = m_grp_pntcloud[i].size(); j < grp_pnt_count; ++j)  
  151.         {  
  152.             pcl::PointXYZ pt;  
  153.             pt.x = m_grp_pntcloud[i][j].pnt.x;  
  154.             pt.y = m_grp_pntcloud[i][j].pnt.y;  
  155.             pt.z = m_grp_pntcloud[i][j].pnt.z;  
  156.   
  157.             p_pnt_cloud->points.push_back(pt);  
  158.         }  
  159.   
  160.         p_pnt_cloud->width = (int)m_grp_pntcloud[i].size();  
  161.         p_pnt_cloud->height = 1;  
  162.   
  163.         char newFileName[256] = {0};  
  164.         char indexStr[16] = {0};  
  165.   
  166.         strcat(newFileName, szFileName);  
  167.         strcat(newFileName, "-");  
  168.         strcat(newFileName, prex_name);  
  169.         strcat(newFileName, "-");  
  170.         sprintf(indexStr, "%d", i + 1);  
  171.         strcat(newFileName, indexStr);  
  172.         strcat(newFileName, ".pcd");  
  173.         savePCDFileASCII(newFileName, *p_pnt_cloud);  
  174.     }  
  175.       
  176.     return true;  
  177. }  
  178.   
  179. bool KMeans::SaveFile(const char *dir_name, const char *prex_name)  
  180. {  
  181.     for (size_t i = 0; i < m_k; ++i)  
  182.     {  
  183.         pcl::PointCloud::Ptr p_pnt_cloud(new pcl::PointCloud ());  
  184.   
  185.         for (size_t j = 0, grp_pnt_count = m_grp_pntcloud[i].size(); j < grp_pnt_count; ++j)  
  186.         {  
  187.             pcl::PointXYZ pt;  
  188.             pt.x = m_grp_pntcloud[i][j].pnt.x;  
  189.             pt.y = m_grp_pntcloud[i][j].pnt.y;  
  190.             pt.z = m_grp_pntcloud[i][j].pnt.z;  
  191.   
  192.             p_pnt_cloud->points.push_back(pt);  
  193.         }  
  194.   
  195.         p_pnt_cloud->width = (int)m_grp_pntcloud[i].size();  
  196.         p_pnt_cloud->height = 1;  
  197.   
  198.         char newFileName[256] = {0};  
  199.         char indexStr[16] = {0};  
  200.   
  201.         strcat(newFileName, dir_name);  
  202.         strcat(newFileName, "/");  
  203.         strcat(newFileName, prex_name);  
  204.         strcat(newFileName, "-");  
  205.         sprintf(indexStr, "%d", i + 1);  
  206.         strcat(newFileName, indexStr);  
  207.         strcat(newFileName, ".pcd");  
  208.         savePCDFileASCII(newFileName, *p_pnt_cloud);  
  209.     }  
  210.   
  211.     return true;  
  212. }  
下面编写测试用例,测试效果:

构造一个以(0, 0, 0)为球心,半径为2的球体;一个左下角坐标为(2.5, 2.5, 2.5),棱长为2的正方体;一个圆心为(1, 1, -3),半径为1的圆。然后初始类的中心分别为上述三个体的中心,并执行K-Means聚类算法,将聚类后的点云数据分别保存到对应的文件中。代码如下:

[cpp]  view plain  copy
 print ?
  1. void test_kmeans_manual_consdata()  
  2. {  
  3.     //构造球体  
  4.     float radius = 2;    
  5.     for (float r = 0; r < radius; r += 0.1)  
  6.     {  
  7.         for (float angle1 = 0.0; angle1 <= 180.0; angle1 += 5.0)    
  8.         {    
  9.             for (float angle2 = 0.0; angle2 <= 360.0; angle2 += 5.0)    
  10.             {    
  11.                 pcl::PointXYZ basic_point;    
  12.   
  13.                 basic_point.x = radius * sinf(pcl::deg2rad(angle1)) * cosf(pcl::deg2rad(angle2));    
  14.                 basic_point.y = radius * sinf(pcl::deg2rad(angle1)) * sinf(pcl::deg2rad(angle2));    
  15.                 basic_point.z = radius * cosf(pcl::deg2rad(angle1));    
  16.                 cloud->points.push_back(basic_point);    
  17.             }    
  18.         }  
  19.     }  
  20.   
  21.     //构造立方体  
  22.     float cube_len = 2;  
  23.     for (float x = 0; x < cube_len; x += 0.1)  
  24.     {  
  25.         for (float y = 0; y < cube_len; y += 0.1)  
  26.         {  
  27.             for (float z = 0; z < cube_len; z += 0.1)  
  28.             {  
  29.                 pcl::PointXYZ basic_point;    
  30.   
  31.                 //沿着向量(2.5, 2.5, 2.5)平移  
  32.                 basic_point.x = x + 2.5;    
  33.                 basic_point.y = y + 2.5;    
  34.                 basic_point.z = z + 2.5;    
  35.                 cloud->points.push_back(basic_point);    
  36.             }  
  37.         }  
  38.     }  
  39.   
  40.     //构造圆形平面  
  41.     float R = 1;  
  42.     for (float radius = 0; radius < R; radius += 0.05)  
  43.     {  
  44.         for (float r = 0; r < radius; r += 0.05)  
  45.         {  
  46.             for (float ang = 0; ang <= 360.0; ang += 5.0)  
  47.             {  
  48.                 pcl::PointXYZ basic_point;    
  49.   
  50.                 basic_point.x = radius * sinf(pcl::deg2rad(ang)) +3;    
  51.                 basic_point.y = radius * cosf(pcl::deg2rad(ang)) + 3;    
  52.                 basic_point.z = -3;    
  53.                 cloud->points.push_back(basic_point);    
  54.             }  
  55.         }  
  56.     }  
  57.   
  58.     cloud->width = (int)cloud->points.size();    
  59.     cloud->height = 1;  
  60.   
  61.     //开始KMeans聚类  
  62.     KMeans kmeans;  
  63.     st_pointxyz center_arr[3] = {  
  64.         {0, 0, 0},  
  65.         {2.5, 2.5, 2.5},   
  66.         {3, 3, -3}  
  67.     };  
  68.   
  69.     kmeans.SetInputCloud(cloud);  
  70.     kmeans.SetK(3);  
  71.     kmeans.InitKCenter(center_arr);  
  72.     kmeans.Cluster();  
  73.     kmeans.SaveFile(".""k3");  
  74. }  
执行完后可以看到生成了三个文件k3-1.pcd、k3-2.pcd、k3-3.pcd,用pcd_viewer_release.exe工具打开这三个文件得到:

可以看到聚类效果还是不错的。

但是以上实现的K-Means聚类算法有时候效果就不是很好,例如,将上述圆的位置移到圆心为(1, 1, -3)处时,得到的效果却是这样的:

在这几天在工作中也碰到了K-Means聚类效果不太好的情况,点云为某教学楼前的一个环形路面,聚类之前的空间三维点分布情况如图所示:

选取K = 10后,聚类后的效果如下所示:

可以看到效果与期望值相差的有些离谱。

从以上两个例子中可以看到效果不太好的原因就是期望的一类A所形成的体积较大,且A类边缘点到中心的距离较大,如果其中A类的旁边(距离较近)存在另一类B且B类的体积较小,那么期望的一类A将会被分割,造成效果不好。

总之,对于具体的数据,我们要选取恰当的方法来聚类。


文章参考于:http://www.cnblogs.com/jerrylead/archive/2011/04/06/2006910.html

1

你可能感兴趣的:(算法与数据结构)