聚类算法(2)--Mean Shift

目录

mean shift 算法理论

基本的Mean Shift向量形式

sklearn参数

python—sklearn实例演示

PS:


Mean shift 算法是基于核密度估计的爬山算法,可用于聚类、图像分割、跟踪等,因为最近搞一个项目,涉及到这个算法的图像聚类实现,因此这里做下笔记。

mean shift 算法理论

 Mean-shift(即:均值迁移)的基本思想:在数据集中选定一个点,然后以这个点为圆心,r为半径,画一个圆(二维下是圆),求出这个点到所有点的向量的平均值,而圆心与向量均值的和为新的圆心,然后迭代此过程,直到满足一点的条件结束。(Fukunage在1975年提出)

Mean-Shift是一种基于滑动窗口的聚类算法。也可以说它是一种基于质心的算法,这意思是它是通过计算滑动窗口中的均值来更新中心点的候选框,以此达到找到每个簇中心点的目的。然后在剩下的处理阶段中,对这些候选窗口进行滤波以消除近似或重复的窗口,找到最终的中心点及其对应的簇。

整体的图例如下:

聚类算法(2)--Mean Shift_第1张图片

分步骤演示:

步骤1:在指定的区域内计算偏移均值(如下图的黄色的圈)

聚类算法(2)--Mean Shift_第2张图片

步骤2:移动该点到偏移均值点处

聚类算法(2)--Mean Shift_第3张图片

步骤3: 重复上述的过程(计算新的偏移均值,移动)

聚类算法(2)--Mean Shift_第4张图片

聚类算法(2)--Mean Shift_第5张图片

聚类算法(2)--Mean Shift_第6张图片

聚类算法(2)--Mean Shift_第7张图片

聚类算法(2)--Mean Shift_第8张图片

 

我们再用一个比较直观的实例看一下:

聚类算法(2)--Mean Shift_第9张图片

聚类算法(2)--Mean Shift_第10张图片

聚类算法(2)--Mean Shift_第11张图片

聚类算法(2)--Mean Shift_第12张图片

基本的Mean Shift向量形式

聚类算法(2)--Mean Shift_第13张图片

sklearn参数

[class sklearn.cluster.MeanShift]

bandwidth=None: float,高斯核函数的半径(或带宽),如果没有给定,则使用sklearn.cluster.estimate_bandwidth 自动估计带宽;

seeds=None: array, shape=[n_samples, n_features],我理解的 seeds 是初始化的质心,如果为 None 并且 bin_seeding=True,就用 clustering.get_bin_seeds 计算得到;

bin_seeding=False: boolean,在没有设置 seeds 时起作用,如果 bin_seeding=True,就用 clustering.get_bin_seeds 计算得到质心,如果 bin_seeding=False,则设置所有点为质心;

min_bin_freq=1: int,clustering.get_bin_seeds 的参数,设置的最少质心个数; 
以上三个参数要结合起来理解;

cluster_all=True: boolean,如果为 True,所有的点都会被聚类,包括不在任何核内的孤立点,其会选择一个离自己最近的核;如果为 False,孤立点的类标签为 -1;

n_jobs=1: int,多线程; 
                    -1:使用所有的cpu; 
                    1:不使用多线程; 
                  -2:如果 n_jobs<0,(n_cpus + 1 + n_jobs)个cpu被使用,所以 n_jobs=-2 时,所有的cpu中只有一块不被使用;
 

主要属性

          cluster_centers_ : 数组类型。计算出的聚类中心的坐标。

          labels_ :数组类型。每个数据点的分类标签。

python—sklearn实例演示

我们运用K-means里面的一个数据样本:

import numpy as np
cluster1 = np.random.uniform(0.5, 1.5, (2, 10))
cluster2 = np.random.uniform(3.5, 4.5, (2, 10))
X = np.hstack((cluster1, cluster2)).T#合并数据
plt.figure()
plt.axis([0, 5, 0, 5])
plt.grid(True)
plt.plot(X[:,0],X[:,1],'k.')

聚类算法(2)--Mean Shift_第14张图片

导入相关包

from sklearn.cluster import MeanShift, estimate_bandwidth
bandwidth1 = estimate_bandwidth(X, quantile=0.2)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_

labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

print("number of estimated clusters : %d" % n_clusters_)
number of estimated clusters : 2
import matplotlib.pyplot as plt
from itertools import cycle

plt.figure(1)
plt.clf()

colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
for k, col in zip(range(n_clusters_), colors):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    plt.plot(X[my_members, 0], X[my_members, 1], col + '.')
    plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col,
             markeredgecolor='k', markersize=14)
plt.title('Estimated number of clusters: %d' % n_clusters_)
plt.show()

聚类算法(2)--Mean Shift_第15张图片

PS:

bandwidth=None: float,高斯核函数的半径(或带宽),如果没有给定,则使用sklearn.cluster.estimate_bandwidth 自动估计带宽

我们上面就是运用了sklearn.cluster.estimate_bandwidth计算了bandwidth

对于bandwidth我们简单解释下

estimate_bandwidth(Xquantile=0.3n_samples=Nonerandom_state=0n_jobs=None)

X : array-like, shape=[n_samples, n_features]

Input points.

quantile : float, default 0.3

should be between [0, 1] 0.5 means that the median of all pairwise distances is used.

n_samples : int, optional

The number of samples to use. If not given, all samples are used.

random_state : int, RandomState instance or None (default)

The generator used to randomly select the samples from input points for bandwidth estimation. Use an int to make the randomness deterministic. See Glossary.

n_jobs : int or None, optional (default=None)

The number of parallel jobs to run for neighbors search. None means 1 unless in a joblib.parallel_backend context. -1 means using all processors. See Glossary for more details.

 举个例子对quantile的影响来看一下:

import numpy as np
X0 = np.array([7, 5, 7, 3, 4, 1, 0, 2, 8, 6, 5, 3])
X1 = np.array([5, 7, 7, 3, 6, 4, 0, 2, 7, 8, 5, 7])
X00=np.array(list(zip(X0, X1))).reshape(len(X0), 2)#组合数据
import matplotlib.pyplot as plt 
plt.figure()
plt.scatter(X00[:, 0], X00[:, 1],c='b')#原始数据散点图

聚类算法(2)--Mean Shift_第16张图片

from sklearn.cluster import MeanShift, estimate_bandwidth
for a in list(range(2,10,1)):
    b=a/10
    bandwidth = estimate_bandwidth(X00, quantile=b)
    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    ms.fit(X00)
    labels = ms.labels_
    cluster_centers = ms.cluster_centers_

    labels_unique = np.unique(labels)
    n_clusters_ = len(labels_unique)

    print("number of estimated clusters : %d" % n_clusters_)
number of estimated clusters : 5
number of estimated clusters : 5
number of estimated clusters : 3
number of estimated clusters : 2
number of estimated clusters : 2
number of estimated clusters : 1
number of estimated clusters : 1
number of estimated clusters : 1

我们可以看到随着 quantile的增大,聚类的数量随之减少。


for a in list(range(2,10,1)):
    b=a/10
    
    bandwidth = estimate_bandwidth(X00, quantile=b)
    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    ms.fit(X00)
    labels = ms.labels_
    cluster_centers = ms.cluster_centers_

    labels_unique = np.unique(labels)
    n_clusters_ = len(labels_unique)

    plt.figure(figsize=(16, 16))
    colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
    for k, col in zip(range(n_clusters_), colors):
        my_members = labels == k
        cluster_center = cluster_centers[k]
        plt.plot(X00[my_members, 0], X00[my_members, 1], col + '.')
        plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col,
             markeredgecolor='k', markersize=14)
    plt.title('Estimated number of clusters: %d' % n_clusters_)
    plt.show()
plt.savefig('test1')

聚类算法(2)--Mean Shift_第17张图片

聚类算法(2)--Mean Shift_第18张图片

你可能感兴趣的:(聚类算法)