K-means方法总结(附代码)

K-means方法总结(附代码)

这一周事情较多,不得已先放弃了验证码分割部分的卷积神经网络的学习,先写两篇关于聚类方法的内容,分别是k-means和混合高斯模型。因为之前的论文中有关于k-means方法的字符分割方法,所以就先学习k-menas的方法。下一篇是高斯混合模型的方法总结。

算法基本原理:

k-means方法较为简单,原理也特别的清晰。

  1. 随机选取k个点,并将作为第一轮迭代的中心点;
  2. 计算每一组数据到k个点中的距离,并将每一个点分配到k个簇中;
  3. 更新中心点:将k个簇中的每一个点的各属性的值进行加和,并计算均值,得到新的中心点;
  4. 多次迭代,直到中心点的值不再发生变化。

代码实现:

基本原理很简单,代码也同样的简单。废话少说,直接上代码:

# k-means算法的中心部分
def k_means(data_set, k, start_point):
    cluster = []
    end_point = []
    for t in range(100):
        # 保证最多的迭代次数为100,不论是否得到最后的结果,都不再继续运行
        for i in range(len(data_set)):
            distance = []
            # 计算每一条数据与选中的点的距离并分到距离最小的簇中
            for j in range(len(start_point)):
                a = pow((data_set[i][0]-start_point[j][0]), 2)
                b = pow((data_set[i][1]-start_point[j][1]), 2)
                distance.append(pow(a+b, 0.5))
                pass
            cluster.append(distance.index(min(distance)))
            distance.clear()
            # 得到更新以后的中心点
            end_point = update_start_point(data_set, cluster, k)
            pass
        # 如果中心点的数据保持不变,则结束循环
        if end_point == start_point:
            print("共迭代 "+str(t)+" 次")
            break
        else:
            if t == 99:
                break
            else:
                start_point = end_point
                cluster.clear()
        pass
    return cluster, end_point
    pass
# 中心点的更新函数:
def update_start_point(data_set, cluster, k):
    start_update = [[0, 0], [0, 0], [0, 0]]
    for i in range(k):
        num = 0
        for j in range(len(cluster)):
            if i == cluster[j]:
                # 得到每一个点所分的簇,并计算每一处簇中的所有点的均值来更新中心点
                num += 1
                start_update[i][0] += data_set[j][0]
                start_update[i][1] += data_set[j][1]
                pass
            pass
        if num == 0:
            start_update[i][0] = 0
            start_update[i][1] = 0
        else:
            start_update[i][0] = start_update[i][0]/num
            start_update[i][1] = start_update[i][1]/num
        pass
    return start_update
    pass

代码验证:

为了验证我们的代码,我们用两种方式进行验证:
1、用西瓜数据集进行验证(就是那个有密度和含糖量数据集)
2、用图片进行验证:用编写的程序对图片的内容进行分类,看最后的效果。
说干就干,首先是对西瓜数据集进行聚类的结果:
K-means方法总结(附代码)_第1张图片
这里的初始点我就简单的采用了开始的三个点,最后得到的图片聚类结果如上图所示。
接着我们用图片进行验证
在用图片进行验证之前,我们需要对图片进行一定的处理:

def get_file(img):
    filename = "data\\zebra.txt"
    with open(filename, "w") as f:
        h = img.shape[0]
        w = img.shape[1]
        for i in range(h):
            for j in range(w):
                f.write(str(img[i][j][0])+" "+str(img[i][j][1])+" "+str(img[i][j][2])+"\n")
    return filename

将图片中的像素点全部存放在txt文件中,接着就可以通过同样的方式对图片进行处理。
最后的结果如下图所示(应该一眼就能看到哪个是原图,哪个是聚类图吧):
K-means方法总结(附代码)_第2张图片
K-means方法总结(附代码)_第3张图片
最后附上源代码百度云盘地址:
k-means源代码、数据集及测试图片
提取码:p1vh
CSDN链接同样的将会在下一篇列出,而下一篇是高斯混合模型的使用,希望大家多多指正。

你可能感兴趣的:(机器学习)