【数学与算法】KMeans聚类代码

KMeans聚类是根据各点距离聚类中心的距离来把所有点分类到不同类别的无监督算法。

对于聚类,就是两点:

  • 1.分类所有样本点:遍历每个数据样本点,分别计算该样本点与K个聚类中心的距离,把该样本点的类别重新分类为距离最小的那一类。
  • 2.更新聚类中心:所有样本点都按第一步重新分类后,把各类别的点重新计算聚类中心(求平均值的方法),更新K个类别的聚类中心值。
  • 3.重复前面两步,直到聚类中心点更新幅度小于阈值,或者达到迭代次数,或者所有样本点的类别都不再改变,或者他们几者组合起来,就停止迭代。

它适合分类一堆一堆的点:见下图中左边的三堆点。

不适合对几条曲线组成的点进行分类,见下图的右边三条线

以一条曲线的起点和终点为例:一条曲线特别长,他的起点和终点之间的距离可能也会特别大,因此,通过欧氏距离进行聚类的话,会出现别的曲线上的点更接近他的起点和终点的情况,那么起点和终点可能会和其他曲线很靠近的点聚类成一类。因此最终的分类效果肯定是很差。
所以,理解了聚类的原理,就知道了他的适用范围,不会在不能使用聚类的地方尝试使用聚类方法。

【数学与算法】KMeans聚类代码_第1张图片


代码:

下面例子用kmeans分类一系列三维空间点。
头文件:

#pragma once
#include 
// #include 
#include 

struct Point_3D {
  float x;
  float y;
  float z;
  Point_3D operator=(Point_3D point) {
    x = point.x;
    y = point.y;
    z = point.z;
  }
};
typedef std::vector<Point_3D> Point3DVct;

class KMeans {
 public:
  int m_k;  // k个类别

  Point3DVct input_point3D_vct_;          //要聚类的点云
  std::vector<Point3DVct> k_points_vct_;  // K类,每一类存储若干点
  Point3DVct k_center_point_vct_;         //每个类的中心

  KMeans() { m_k = 0; }

  inline void SetK(int k_) {
    m_k = k_;
    k_points_vct_.resize(m_k);
  }
  //设置输入点
  bool SetInput(const Point3DVct &input_points, Point3DVct &o_points);

  //初始化最初的K个类的中心
  bool InitKCenter(Point3DVct &K_center_point_vct);

  //聚类
  bool Cluster(const Point3DVct &input_points,
               std::vector<Point3DVct> &k_points_vct);

  //更新K类的中心
  bool UpdateGroupCenter(std::vector<Point3DVct> &K_points_vct,
                         Point3DVct &centers);

  //计算两个点间的欧氏距离
  float DistBetweenPoints(const Point_3D &p1, const Point_3D &p2);

  //是否存在中心点移动,用来判断分类结果是否已收敛
  bool ExistCenterShift(Point3DVct &prev_center, Point3DVct &cur_center);
};

源文件:

#include "k_means.h"

#include 
// #include 
#include 
#include 

#include 

const float DELTA = 0.001;

bool KMeans::InitKCenter(Point3DVct &K_center_point_vct) {
  if (m_k == 0) {
    std::cout << "在此之前必须要调用setK()函数" << std::endl;
    return false;
  }

  k_center_point_vct_.resize(m_k);
  for (size_t i = 0; i < m_k; ++i) {
    k_center_point_vct_[i] = K_center_point_vct[i];
  }
  return true;
}

bool KMeans::SetInput(const Point3DVct &input_points, Point3DVct &o_points) {
  for (int i = 0; i < input_points.size(); ++i) {
    Point_3D p = input_points[i];
    o_points.push_back(p);
  }
  return true;
}

bool KMeans::Cluster(const Point3DVct &input_points,
                     std::vector<Point3DVct> &k_points_vct) {
  Point3DVct input_point3D_vct;
  SetInput(input_points, input_point3D_vct);

  Point3DVct v_center(k_center_point_vct_.size());

  do {
    for (size_t i = 0, pntCount = input_point3D_vct.size(); i < pntCount; ++i) {
      float min_dist = 10000000000;
      int point_class = 0;
      for (size_t j = 0; j < m_k; ++j) {
        float dist =
            DistBetweenPoints(input_point3D_vct[i], k_center_point_vct_[j]);
        if (min_dist - dist > 0.000001) {
          min_dist = dist;
          point_class = j;
        }
      }
      k_points_vct_[point_class].push_back(input_point3D_vct[i]);
    }

    //保存上一次迭代的中心点
    for (size_t i = 0; i < k_center_point_vct_.size(); ++i) {
      v_center[i] = k_center_point_vct_[i];
    }

    if (!UpdateGroupCenter(k_points_vct_, k_center_point_vct_)) {
      return false;
    }
    if (!ExistCenterShift(v_center, k_center_point_vct_)) {
      k_points_vct = k_points_vct_;
      break;
    }

    for (size_t i = 0; i < m_k; ++i) {
      for (int j = 0; j < k_points_vct_[i].size(); ++j) {
        const Point_3D &p = k_points_vct_[i][j];
        std::cout << "x= " << p.x << ",   y= " << p.y << ",   z= " << p.z
                  << " ,class: " << i << std::endl;
      }
    }
    std::cout << "--------------------- " << std::endl;
    for (size_t i = 0; i < m_k; ++i) {
      k_points_vct_[i].clear();
    }

  } while (true);

  return true;
}

// 计算两个点之间的距离
float KMeans::DistBetweenPoints(const Point_3D &p1, const Point_3D &p2) {
  float dist = 0;
  float x_diff = 0, y_diff = 0, z_diff = 0;

  x_diff = p1.x - p2.x;
  y_diff = p1.y - p2.y;
  z_diff = p1.z - p2.z;
  dist = sqrt(x_diff * x_diff + y_diff * y_diff + z_diff * z_diff);

  return dist;
}

bool KMeans::UpdateGroupCenter(std::vector<Point3DVct> &K_points_vct,
                               Point3DVct &centers) {
  if (centers.size() != m_k) {
    std::cout << "类别的个数不为K" << std::endl;
    return false;
  }

  for (size_t i = 0; i < m_k; ++i) {
    float x = 0, y = 0, z = 0;
    size_t point_num_in_this_class = K_points_vct[i].size();

    // 遍历每个类别的数据,每次遍历都把一类数据的x全加起来,求平均数,赋值给该类别的中心x;
    // y全加起来,求平均数,赋值给该类别的中心y;
    // z全加起来,求平均数,赋值给该类别的中心z
    for (size_t j = 0; j < point_num_in_this_class; ++j) {
      x += K_points_vct[i][j].x;
      y += K_points_vct[i][j].y;
      z += K_points_vct[i][j].z;
    }
    x /= point_num_in_this_class;
    y /= point_num_in_this_class;
    z /= point_num_in_this_class;
    centers[i].x = x;
    centers[i].y = y;
    centers[i].z = z;
  }
  return true;
}

//是否存在中心点移动
// 就是说遍历K个类别的中心点,若上一次和本次更新的中心点距离变化大于一定值就表示正在更新更新了;
// 否则,就表示不再更新迭代停止;
// 只要有一个返回值大于阈值,就表示有数据更新,不能停止迭代。如果所有个类别的中心距离都小于某阈值,就表示更新停止.
bool KMeans::ExistCenterShift(Point3DVct &prev_center, Point3DVct &cur_center) {
  for (size_t i = 0; i < m_k; ++i) {
    float dist = DistBetweenPoints(prev_center[i], cur_center[i]);
    if (dist > DELTA) {
      return true;
    }
  }
  return false;
}

你可能感兴趣的:(数学和算法,聚类)