k均值聚类python实现

K均值聚类(K-Means Clustering)是一种常用的无监督学习算法,用于将数据分成K个簇。以下是一个简单的Python实现K均值聚类的代码讲解,包括数据准备、初始化、迭代更新簇心和分配簇标签等步骤。
CSDN大礼包:《2025年最新全套学习资料包》免费分享
k均值聚类python实现_第1张图片

代码实现

import numpy as np
import matplotlib.pyplot as plt

# 生成示例数据
np.random.seed(42)
X = np.random.rand(100, 2)  # 生成100个二维数据点

# 设置簇的数量
k = 3

# 初始化簇心(随机选择k个点作为簇心)
centroids = X[np.random.choice(X.shape[0], k, replace=False)]

# 设置最大迭代次数和容忍度
max_iters = 100
tolerance = 1e-4

# 存储每次迭代的簇心变化量
centroid_shifts = []

# K均值聚类算法
for _ in range(max_iters):
    # 计算每个点到所有簇心的距离,并分配最近的簇心标签
    distances = np.linalg.norm(X[:, np.newaxis] - centroids, axis=2)
    labels = np.argmin(distances, axis=1)
    
    # 计算新的簇心(簇内所有点的均值)
    new_centroids = np.array([X[labels == i].mean(axis=0) for i in range(k)])
    
    # 检查簇心是否收敛(变化量是否小于容忍度)
    centroid_shift = np.linalg.norm(new_centroids - centroids)
    centroid_shifts.append(centroid_shift)
    
    # 更新簇心
    centroids = new_centroids
    
    # 如果簇心收敛,则提前退出循环
    if centroid_shift < tolerance:
        break

# 可视化结果
plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis')
plt.scatter(centroids[:, 0], centroids[:, 1], s=300, c='red', marker='X')  # 绘制簇心
plt.title('K-Means Clustering')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()

# 打印迭代次数和最终簇心变化量
print(f"Number of iterations: {_ + 1}")
print(f"Final centroid shift: {centroid_shift}")

代码讲解

  1. 数据准备

    • 使用np.random.rand(100, 2)生成100个二维数据点作为示例数据。
  2. 初始化

    • 设置簇的数量k=3
    • 使用np.random.choice随机选择k个点作为初始簇心。
  3. 迭代更新

    • 使用一个for循环进行最多max_iters=100次迭代。
    • 在每次迭代中,首先计算每个数据点到所有簇心的欧氏距离,并使用np.argmin找到最近的簇心标签。
    • 然后,计算新的簇心,即簇内所有点的均值。
    • 检查簇心的变化量是否小于容忍度tolerance=1e-4,如果是,则提前退出循环。
    • 更新簇心为新的簇心。
  4. 可视化结果

    • 使用matplotlib绘制数据点和簇心。数据点根据其簇标签着色,簇心用红色“X”标记。
  5. 打印信息

    • 打印实际迭代次数和最终的簇心变化量。

注意事项

  • 初始簇心的选择对K均值聚类的结果有很大影响。可以使用更复杂的初始化方法(如K均值++)来改善结果。
  • 簇的数量k是一个超参数,需要根据数据的实际情况和先验知识来选择。
  • K均值聚类对异常值(离群点)很敏感,因为簇心是簇内所有点的均值。
  • 在实际应用中,可能需要多次运行K均值聚类算法,每次使用不同的初始簇心,以找到更好的聚类结果。

你可能感兴趣的:(均值算法,聚类,python,开发语言,Python基础)