【ML】Mean-Shift 原理 + 实践(基于sklearn)

【ML】Mean-Shift 原理 + 实践(基于sklearn)

  • 原理
  • 实践
    • 生成数据
    • 训练
    • 预测+评估

原理

  1. 取数据集中的一个点为X,以此点为中心画一个半径为R的圆,圆内共有点数量假设为K。
  2. 以此点为起点,其他圆内点为终点计算出所有向量并相加除以K得到meanshift向量M(x)。
  3. 令X=X+M(x),然后从第一步继续开始,然后迭代此过程直到中心点(质心)不变。
  4. 再取剩下的其他点,进行此过程,直到所有点都计算完成。
  5. 每个点计算质心时,迭代一定次数,当质心变化小于某个阈值时停止迭代,并搜索当前质心附近质心进行归类(小于一定阈值内的质心)。

演示:
【ML】Mean-Shift 原理 + 实践(基于sklearn)_第1张图片

实践

生成数据

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)

训练

bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_

labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

print(labels_unique)
print("number of estimated clusters : %d" % n_clusters_)

预测+评估

y_predict = ms.predict(X)
from matplotlib import pyplot as plt
plt.figure()
plt.scatter(X[:,0],X[:,1],c=y_predict)

【ML】Mean-Shift 原理 + 实践(基于sklearn)_第2张图片

你可能感兴趣的:(机器学习,算法,python,sklearn,python)