手工实现k-means算法

手工实现k-means算法

手工实现常见聚类算法–k-means(k均值)


手工实现,不使用现成的聚类库,如sklearn;但是使用了numpy和绘图工具类matplotlib???不知道算不算纯手工。

文章目录

  • 手工实现k-means算法
  • 一、k-means是什么?
  • 二、步骤
    • 1.引入库
    • 2.读入数据
    • 3.聚类处理
      • 初始化k个中心点
      • 判断该数据点属于哪个中心点
      • 重新计算中心点位置
      • #k-means
      • 测试用例
    • 4.用到的函数
      • np.loadtxt()
      • x[:,n]
      • np.random.randint()
      • np.zeros(len(X)).reshape( X.shape[0],-1)
      • power(x, y)
      • np.ravel(idx).tolist()
      • set()
      • np.arange()
      • a.ravel()
      • np.where()[0]
      • np.sum (a,axis = 0)
    • 绘图
      • mpl.colors.ListedColormap(['g', 'r', 'b'])
      • plt.scatter()
      • plt.show()
  • 总结


一、k-means是什么?

K-均值是最普及的聚类算法,算法接受一个未标记的数据集,然后将数据聚类成不同的组。

K-均值是一个迭代算法,假设我们想要将数据聚类成 n 个组,其方法为:

  • 首先选择个随机的点,称为聚类中心(cluster centroids);
  • 对于数据集中的每一个数据,按照距离个中心点的距离,将其与距离最近的中心点关联起来,与同一个中心点关联的所有点聚成一类。
  • 计算每一个组的平均值,将该组所关联的中心点移动到平均值的位置。
  • 重复步骤,直至中心点不再变化或达到最大迭代次数(本次使用)。

二、步骤

1.引入库

NumPy 是一个运行速度非常快的数学库,主要用于数组计算,包含:
一个强大的N维数组对象 ndarray
广播功能函数
整合 C/C++/Fortran 代码的工具
线性代数、傅里叶变换、随机数生成等功能

NumPy通常与 SciPy(Scientific Python)和 Matplotlib(绘图库)一起使用, 这种组合广泛用于替代 MatLab,是一个强大的科学计算环境,有助于我们通过 Python 学习数据科学或者机器学习。

SciPy 是一个开源的 Python 算法库和数学工具包。
SciPy 包含的模块有最优化、线性代数、积分、插值、特殊函数、快速傅里叶变换、信号处理和图像处理、常微分方程求解和其他科学与工程中常用的计算。

Matplotlib 是 Python 编程语言及其数值数学扩展包 NumPy 的可视化操作界面。它为利用通用的图形用户界面工具包,如 Tkinter, wxPython, Qt 或 GTK+ 向应用程序嵌入式绘图提供了应用程序接口(API)。

代码如下(示例):

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

2.读入数据

代码如下(示例):

def loaddata():
    data = np.loadtxt('./watermelon_4.txt',delimiter=',')
    return data

X = loaddata()
plt.scatter(X[:, 0], X[:, 1], s=20)
plt.show()

手工实现k-means算法_第1张图片


3.聚类处理

初始化k个中心点

#随机初始化中心点
def kMeansInitCentroids(X, k):
    #从X的数据中随机取k个作为中心点
    index=np.random.randint(0,len(X-1),k)
    return X[index]

判断该数据点属于哪个中心点

#计算数据点到中心点的距离,并判断该数据点属于哪个中心点
def findClosestCentroids(X, centroids):
    #idx中数据表明对应X的数据是属于哪一个中心点的,创建数组idx
    idx = np.zeros(len(X))
    for i in range(len(X)):
        #补充计算数据点到中心点的距离,并判断该数据点所属中心点
        #求最小值,初始化为正无穷
        minDistance=float('inf')
        index=0
        for k in range(len(centroids)):
            #直线距离的平方
            distance=np.sum(np.power(X[i]-centroids[k],2))
            if(distance<minDistance):
                minDistance=distance
                index=k
                #距离第k个中心点距离最小
        idx[i]=index
    return idx

重新计算中心点位置

#重新计算中心点位置
def computeCentroids(X, idx):
    k = set(np.ravel(idx).tolist()) #找到所有聚类中心索引
    k = list(k)
    #N 维数组对象 ndarray,它是一系列同类型数据的集合,以 0 下标为开始进行集合中元素的索引。创建一个ndarray对象:
    #len(k),X.shape[1]分别为行数和列数
    centroids = np.ndarray((len(k),X.shape[1]))
    
    for i in range(len(k)):
    #选择数据X中类别为k[i]的数据,按行选择
        data = X[np.where(idx==k[i])[0]]
        #重新计算聚类中心,axis=0为行压缩,将矩阵的每列相加形成一个一维数组,除以原矩阵行数得到聚类中心
        centroids[i] = np.sum(data,axis=0)/len(data) 
    return centroids

#k-means

def k_means(X, k, max_iters):
    initial_centroids = kMeansInitCentroids(X,k)
    #迭代
    for i in range(max_iters):
        
        if(i==0):
            centroids=initial_centroids
            #print(centroids)
            
        #计算样本到聚类中心的距离,并返回每个样本所属的聚类中心
        idx=findClosestCentroids(X,centroids)
        #重新计算聚类中心
        centroids=computeCentroids(X,idx)
    return idx,centroids

测试用例

#测试用例1
idx,centroids = k_means(X, 3, 8)
print(idx)
print(centroids)

手工实现k-means算法_第2张图片

#绘图
cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])
#绘制样本点
plt.scatter(X[:, 0], X[:, 1], c=np.ravel(idx), cmap=cm_dark, s=20)
#绘制中心点
plt.scatter(centroids[:, 0], centroids[:, 1], c=np.arange(len(centroids)), cmap=cm_dark, marker='*', s=500)
plt.show()

手工实现k-means算法_第3张图片

4.用到的函数

np.loadtxt()

从文本加载数据。文本文件中的每一行必须含有相同的数据。


x[:,n]

表示在全部数组(维)中取第n个数据,直观来说,x[:,n]就是取所有集合的第n个数据,

np.random.randint()

返回k个范围内的随机整数。

np.zeros(len(X)).reshape( X.shape[0],-1)

**numpy.zeros(shape,dtype=float,order = ‘C’)**返回给定形状和类型的新数组,用0填充。

https://blog.csdn.net/lens___/article/details/83927880

**reshape()**函数用于在不更改数据的情况下为数组赋予新形状。两个参数为几行几列,-1(负数)是模糊控制,根据另一个维度自动算好。

https://blog.csdn.net/weixin_43937759/article/details/106605680

X.shape[0]:numpy 创建的数组都有一个shape属性,它是一个元组,返回各个维度的维数

https://www.cnblogs.com/wanglinjie/p/11761779.html


power(x, y)

计算 x 的 y 次方。

np.ravel(idx).tolist()

ravel(idx)将多维数组转化成一维数组
tolist()将数组或矩阵转化成列表

set()

创建一个无序不重复元素集,可进行关系测试,删除重复数据,还可以计算交集、差集、并集等。

https://www.runoob.com/python/python-func-set.html

np.arange()

函数返回一个有终点和起点的固定步长的排列,如[1,2,3,4,5],起点是1,终点是6,步长为1。

https://blog.csdn.net/qq_41550480/article/details/89390579

a.ravel()

ravel()方法将数组维度拉成一维数组

np.where()[0]

np.where()[0] 表示行索引,np.where()[1]表示列索引

np.sum (a,axis = 0)

axis为0是压缩行,即将每一列的元素相加,将矩阵压缩为一行,输出:array ([34., 38., 42., 46., 50.]) np.sum (a,axis = 1) axis为1是压缩列,即将每一行的元素相加,将矩阵压缩为一列,输出:array ([15., 40., 65., 90.])

绘图

mpl.colors.ListedColormap([‘g’, ‘r’, ‘b’])

matplotlib.colors.ListedColormap类用于从颜色列表创建colarmap对象。
我们有时希望图表元素的颜色与数据集中某个变量的值相关,颜色随着该变量值的变化而变化,以反映数据变化趋势、数据的聚集、分析者对数据的理解等信息,这时,我们就要用到 matplotlib 的颜色映射(colormap)功能,即将数据映射到颜色。

https://blog.csdn.net/qq_38486203/article/details/80578260

plt.scatter()

绘制散点图。一系列参数需要注意。

https://blog.csdn.net/qiu931110/article/details/68130199

def scatter(x, y, #需要绘制的二维数据
s=None, #样本点大小
c=None, #颜色序列,可以为rgb颜色,也可以为一个序列(向量,不同数值即为不同颜色)
marker=None,#标记样式,如⭐,
 cmap=None, #colormap实例
 norm=None, vmin=None, vmax=None, 
alpha=None, linewidths=None, verts=None, edgecolors=None, hold=None, data=None, **kwargs)

plt.show()

显示图形。

https://www.itbaoku.cn/post/1606600/do

总结

numpy库和matplotlib库的很多功能需要多加学习!

你可能感兴趣的:(算法,kmeans,聚类)