K-means聚类算法原理与实现——笔记梳理

转载于:https://www.toutiao.com/a6690435044869145101/
转载于:https://blog.csdn.net/ten_sory/article/details/81016748
今天我们学习KNN算法的好兄弟K-means算法又叫作K均值算法。

K-means中心思想

事先确定常数K,常数K意味着最终的聚类类别数。首先随机选定初始点为质心,并通过计算每一个样本与质心之间的相似度(这里为欧式距离),将样本点归到最相似的类中。接着,重新计算每个类的质心(即为类中心),重复这样的过程,直到质心不再改变,最终就确定了每个样本所属的类别以及每个类的质心。

一,K-means算法原理

K-means算法是最常用的一种聚类算法。算法的输入为一个样本集(或者称为点集),通过该算法可以将样本进行聚类,具有相似特征的样本聚为一类。针对每个点,计算这个点距离所有中心点最近的那个中心点,然后将这个点归为这个中心点代表的簇。一次迭代结束之后,针对每个簇类,重新计算中心点,然后针对每个点,重新寻找距离自己最近的中心点。如此循环,直到前后两次迭代的簇类没有变化。

下面通过一个简单的例子,说明K-means算法的过程。如下图所示,目标是将样本点聚类成3个类别。
K-means聚类算法原理与实现——笔记梳理_第1张图片

二,k-means算法流程

1.选择聚类的个数k(kmeans算法传递超参数的时候,只需设置最大的K值)

2.任意产生k个聚类,然后确定聚类中心,或者直接生成k个中心。

3.对每个点确定其聚类中心点。

4.再计算其聚类新中心。

5.重复以上步骤直到满足收敛要求。(通常就是确定的中心点不再改变。)

上述步骤的关键两点是:

  1. 找到距离自己最近的中心点。

  2. 更新中心点。

三,k-means优点与缺点

优点:

1、原理简单(靠近中心点) ,实现容易

2、聚类效果中上(依赖K的选择)

3、空间复杂度o(N)时间复杂度o(IKN,N为样本点个数,K为中心点个数,I为迭代次数)

缺点:

1、对离群点, 噪声敏感 (中心点易偏移)

2、很难发现大小差别很大的簇及进行增量计算

3、结果不一定是全局最优,只能保证局部最优(与K的个数及初值选取有关)

四、初始中心点的选取

初始中心点的选取,对聚类的结果影响较大。可以验证,不同初始中心点,会导致聚类的效果不同。如何选择初始中心点呢?

一个原则是:初始中心点之间的间距应该较大。

因此,可以采取的策略是:

step1:计算所有样本点之间的距离,选择距离最大的一个点对(两个样本C1, C2)作为2个初始中心点,从样本点集中去掉这两个点。

step2:如果初始中心点个数达到k个,则终止。如果没有,在剩余的样本点中,选一个点C3,这个点优化的目标是:
K-means聚类算法原理与实现——笔记梳理_第2张图片
这是一个双目标优化问题,可以约束其中一个,极值化另外一个,这样可以选择一个合适的C3点,作为第3个初始中心点。

如果要寻找第4个初始中心点,思路和寻找第3个初始中心点是相同的。

五、误差平方和(Sum of Squared Error)

误差平法和,SSE,用于评价聚类的结果的好坏,SSE的定义如下。

K-means聚类算法原理与实现——笔记梳理_第3张图片

一般情况下,k越大,SSE越小。假设k=N=样本个数,那么每个点自成一类,那么每个类的中心点为这个类中的唯一一个点本身,那么SSE=0。

六、k值得确定

一般k不会很大,大概在2~10之间,因此可以作出这个范围内的SSE-k的曲线,再选择一个拐点,作为合适的k值。
K-means聚类算法原理与实现——笔记梳理_第4张图片
可以看到,k=5之后,SSE下降的变得很缓慢了,因此最佳的k值为5。

七,代码实现

下面给出它的Python实现,其中中心点的选取是手动选择的。在代码中随机产生了一个样本,用于测试K-means算法:

# K-means Algorithm is a clustering algorithm
import numpy as np
import matplotlib.pyplot as plt
import random
 
 
def get_distance(p1, p2):
    diff = [x-y for x, y in zip(p1, p2)]
    distance = np.sqrt(sum(map(lambda x: x**2, diff)))
    return distance
 
 
# 计算多个点的中心
# cluster = [[1,2,3], [-2,1,2], [9, 0 ,4], [2,10,4]]
def calc_center_point(cluster):
    N = len(cluster)
    m = np.matrix(cluster).transpose().tolist()
    center_point = [sum(x)/N for x in m]
    return center_point
 
 
# 检查两个点是否有差别
def check_center_diff(center, new_center):
    n = len(center)
    for c, nc in zip(center, new_center):
        if c != nc:
            return False
    return True
 
 
# K-means算法的实现
def K_means(points, center_points):
 
    N = len(points)         # 样本个数
    n = len(points[0])      # 单个样本的维度
    k = len(center_points)  # k值大小
 
    tot = 0
    while True:             # 迭代
        temp_center_points = [] # 记录中心点
 
        clusters = []       # 记录聚类的结果
        for c in range(0, k):
            clusters.append([]) # 初始化
 
        # 针对每个点,寻找距离其最近的中心点(寻找组织)
        for i, data in enumerate(points):
            distances = []
            for center_point in center_points:
                distances.append(get_distance(data, center_point))
            index = distances.index(min(distances)) # 找到最小的距离的那个中心点的索引,
 
            clusters[index].append(data)    # 那么这个中心点代表的簇,里面增加一个样本
 
        tot += 1
        print(tot, '次迭代   ', clusters)
        k = len(clusters)
        colors = ['r.', 'g.', 'b.', 'k.', 'y.']  # 颜色和点的样式
        for i, cluster in enumerate(clusters):
            data = np.array(cluster)
            data_x = [x[0] for x in data]
            data_y = [x[1] for x in data]
            plt.subplot(2, 3, tot)
            plt.plot(data_x, data_y, colors[i])
            plt.axis([0, 1000, 0, 1000])
 
        # 重新计算中心点(该步骤可以与下面判断中心点是否发生变化这个步骤,调换顺序)
        for cluster in clusters:
            temp_center_points.append(calc_center_point(cluster))
 
        # 在计算中心点的时候,需要将原来的中心点算进去
        for j in range(0, k):
            if len(clusters[j]) == 0:
                temp_center_points[j] = center_points[j]
 
        # 判断中心点是否发生变化:即,判断聚类前后样本的类别是否发生变化
        for c, nc in zip(center_points, temp_center_points):
            if not check_center_diff(c, nc):
                center_points = temp_center_points[:]   # 复制一份
                break
        else:   # 如果没有变化,那么退出迭代,聚类结束
            break
 
    plt.show()
    return clusters # 返回聚类的结果
 
 
 
 
# 随机获取一个样本集,用于测试K-means算法
def get_test_data():
 
    N = 1000
 
    # 产生点的区域
    area_1 = [0, N / 4, N / 4, N / 2]
    area_2 = [N / 2, 3 * N / 4, 0, N / 4]
    area_3 = [N / 4, N / 2, N / 2, 3 * N / 4]
    area_4 = [3 * N / 4, N, 3 * N / 4, N]
    area_5 = [3 * N / 4, N, N / 4, N / 2]
 
    areas = [area_1, area_2, area_3, area_4, area_5]
    k = len(areas)
 
    # 在各个区域内,随机产生一些点
    points = []
    for area in areas:
        rnd_num_of_points = random.randint(50, 200)
        for r in range(0, rnd_num_of_points):
            rnd_add = random.randint(0, 100)
            rnd_x = random.randint(area[0] + rnd_add, area[1] - rnd_add)
            rnd_y = random.randint(area[2], area[3] - rnd_add)
            points.append([rnd_x, rnd_y])
 
    # 自定义中心点,目标聚类个数为5,因此选定5个中心点
    center_points = [[0, 250], [500, 500], [500, 250], [500, 250], [500, 750]]
 
    return points, center_points
 
 
if __name__ == '__main__':
 
    points, center_points = get_test_data()
    clusters = K_means(points, center_points)
    print('#######最终结果##########')
    for i, cluster in enumerate(clusters):
        print('cluster ', i, ' ', cluster)

由于样本点是随机产生的,所以每次运行的结果不相同。6次迭代得到的聚类结果分别如下图。
K-means聚类算法原理与实现——笔记梳理_第5张图片
控制台输出结果为:
在这里插入图片描述

八,k-means的经典案例与适用范围

1.文档分类器:根据标签、主题和文档内容将文档分为多个不同的类别。这是一个非常标准且经典的K-means算法分类问题。

2.物品传输优化:使用K-means算法的组合找到无人机最佳发射位置和遗传算法来解决旅行商的行车路线问题,优化无人机物品传输过程。

3.识别犯罪地点:使用城市中特定地区的相关犯罪数据,分析犯罪类别、犯罪地点以及两者之间的关联,可以对城市或区域中容易犯罪的地区做高质量的勘察。

4.客户分类:聚类能过帮助营销人员改善他们的客户群(在其目标区域内工作),并根据客户的购买历史、兴趣或活动监控来对客户类别做进一步细分。

5.球队状态分析:分析球员的状态一直都是体育界的一个关键要素。随着竞争越来愈激烈,机器学习在这个领域也扮演着至关重要的角色。如果你想创建一个优秀的队伍并且喜欢根据球员状态来识别类似的球员,那么K-means算法是一个很好的选择。

6.保险欺诈检测:利用以往欺诈性索赔的历史数据,根据它和欺诈性模式聚类的相似性来识别新的索赔。由于保险欺诈可能会对公司造成数百万美元的损失,因此欺诈检测对公司来说至关重要。

7.乘车数据分析:面向大众公开的Uber乘车信息的数据集,为我们提供了大量关于交通、运输时间、高峰乘车地点等有价值的数据集。分析这些数据不仅对Uber大有好处,而且有助于我们对城市的交通模式进行深入的了解,来帮助我们做城市未来规划。

8.网络分析犯罪分子:网络分析是从个人和团体中收集数据来识别二者之间的重要关系的过程。网络分析源自于犯罪档案,该档案提供了调查部门的信息,以对犯罪现场的罪犯进行分类。

9.呼叫记录详细分析:通话详细记录(CDR)是电信公司在对用户的通话、短信和网络活动信息的收集。将通话详细记录与客户个人资料结合在一起,这能够帮助电信公司对客户需求做更多的预测。

10.IT警报的自动化聚类:大型企业IT基础架构技术组件(如网络,存储或数据库)会生成大量的警报消息。由于警报消息可以指向具体的操作,因此必须对警报信息进行手动筛选,确保后续过程的优先级。对数据进行聚类可以对警报类别和平均修复时间做深入了解,有助于对未来故障进行预测。

你可能感兴趣的:(笔记梳理,机器学习,聚类,kmeans算法)