大家自行百度吧,我懒得码字了
推荐一篇讲解原理的文章:https://blog.csdn.net/jinshengtao/article/details/30258833
直接上代码了,看不懂,就参照一下原理
# author: wdq
# contact: [email protected]
# datetime:2022/3/15 17:40
# software: PyCharm
import random
from collections import Counter
from typing import List
import numpy as np
from numpy import ndarray
class MeanShift:
    """Mean-shift clustering using a Gaussian-weighted window.

    Starting from a randomly chosen unvisited sample, the window centre is
    repeatedly moved to the Gaussian-weighted mean of the samples lying within
    ``band_width`` of it, until the move is shorter than
    ``1e-4 * band_width``.  Every sample seen along the way votes for the
    resulting cluster; modes that end up closer than ``band_width / 2`` to an
    existing centre are merged into it.  Finally each sample is assigned to
    the cluster it voted for most often.
    """

    def __init__(self, nums: ndarray, band_width: float):
        """
        :param nums: samples to cluster, one sample per row
                     (expected layout: ``(n_samples, n_features)``)
        :param band_width: window radius, also used as the Gaussian kernel
                           bandwidth; must be positive
        :raises ValueError: if ``band_width`` is not a positive number
        """
        # A non-positive (or NaN) bandwidth gives an empty window and a stop
        # threshold that can never be reached, i.e. an endless loop.
        if not band_width > 0:
            raise ValueError("band_width must be a positive number")
        # Samples to cluster
        self.__nums = np.asarray(nums)
        # Window radius
        self.__band_width = band_width
        # A shift shorter than this counts as converged
        self.__stop_band_width = 10 ** -4 * self.__band_width
        # Per-sample "has been inside some window" flags
        self.__is_visited = [False] * self.__nums.shape[0]
        # Cluster centres (lists of floats)
        self.__cluster_centers = []
        # Per-cluster vote lists: sample indices, repeated once per visit
        self.__cluster = []

    def mean_shift(self) -> List[List[List[float]]]:
        """Run the clustering.

        :return: one list per non-empty cluster, each holding the member
                 samples as plain Python lists; cluster order depends on the
                 random start points
        """
        # Keep seeding new windows until every sample has been visited
        while not self.__is_all_visited():
            unvisited = [i for i, seen in enumerate(self.__is_visited) if not seen]
            start_point = random.choice(unvisited)
            my_mean, my_member = self.__seek_mode(start_point)
            self.__register_cluster(my_mean, my_member)
        return self.__get_result()

    def __seek_mode(self, start_point: int) -> tuple:
        """Shift a window from sample ``start_point`` until it converges.

        :param start_point: index of the sample the window starts on
        :return: ``(mode, members)`` - the converged window centre and the
                 indices of all samples seen in the window (with repeats)
        """
        members = []
        # Work in floating point so integer input is not truncated
        mean = self.__nums[start_point].astype(float)
        while True:
            distance, gaussian = self.__get_shift(mean)
            in_the_area = self.__find__points(distance)
            old_mean = mean
            mean = self.__get_new_mean(gaussian, in_the_area)
            # Every sample inside the window votes for this cluster
            members.extend(in_the_area)
            self.__update_visited(in_the_area)
            if self.__get_distance(old_mean, mean) < self.__stop_band_width:
                return mean, members

    def __register_cluster(self, mean: ndarray, members: List[int]) -> None:
        """Merge a converged mode into a nearby cluster or store it as new.

        :param mean: converged window centre
        :param members: sample indices that voted for this mode
        """
        for idx, center in enumerate(self.__cluster_centers):
            # Centres closer than half a bandwidth are treated as one cluster
            if self.__get_distance(mean, center) < self.__band_width / 2:
                self.__cluster_centers[idx] = self.__get_new_center(mean, center)
                self.__cluster[idx].extend(members)
                return
        self.__cluster_centers.append(mean.tolist())
        self.__cluster.append(members)

    def __is_all_visited(self) -> bool:
        """
        :return: True when every sample has been visited
        """
        return all(self.__is_visited)

    def __get_distance(self, start, end) -> float:
        """
        :param start: first point (ndarray or sequence of numbers)
        :param end: second point (ndarray or sequence of numbers)
        :return: Euclidean distance between the two points
        """
        return float(np.linalg.norm(np.asarray(start) - np.asarray(end)))

    def __get_shift(self, start: ndarray) -> (ndarray, ndarray):
        """
        :param start: current window centre
        :return: ``(distance, gaussian)`` - for every sample, its Euclidean
                 distance to ``start`` and the matching kernel weight
        """
        distance = np.linalg.norm(self.__nums - start, axis=1)
        gaussian = self.__gaussian_kernel(distance, self.__band_width)
        return distance, gaussian

    def __gaussian_kernel(self, distance, bandwidth: float):
        """
        Gaussian kernel.

        :param distance: distance(s) to the window centre
        :param bandwidth: kernel bandwidth
        :return: kernel weight(s), same shape as ``distance``
        """
        return (1 / (bandwidth * np.sqrt(2 * np.pi))) * np.exp(-0.5 * (distance / bandwidth) ** 2)

    def __get_new_mean(self, gaussian: ndarray, in_the_area: List[int]) -> ndarray:
        """
        :param gaussian: kernel weight of every sample
        :param in_the_area: indices of the samples inside the window
        :return: weighted mean of the samples inside the window (the
                 unnormalised weighted sum if the total weight is zero)
        """
        idx = np.asarray(in_the_area, dtype=int)
        # Cast to float: writing weighted values into an integer array would
        # silently truncate them.
        points = self.__nums[idx].astype(float)
        weights = gaussian[idx]
        total = weights.sum()
        weighted_sum = weights @ points
        return weighted_sum / total if total != 0 else weighted_sum

    def __find__points(self, distance: ndarray) -> List[int]:
        """
        :param distance: Euclidean distance of every sample to the centre
        :return: indices of the samples strictly inside the window
        """
        # ``distance`` is not squared, so compare with the bandwidth itself
        # (comparing with band_width ** 2 used the wrong window radius).
        return np.flatnonzero(distance < self.__band_width).tolist()

    def __update_visited(self, in_the_area: List[int]) -> None:
        """
        Mark samples as visited.

        :param in_the_area: indices of the samples inside the window
        """
        for i in in_the_area:
            self.__is_visited[i] = True

    def __get_new_center(self, mymean: ndarray, old_center: List[float]) -> List[float]:
        """
        Merge two cluster centres.

        :param mymean: newly converged centre
        :param old_center: existing centre
        :return: component-wise midpoint of the two centres
        """
        return [(i + j) / 2 for i, j in zip(mymean.tolist(), old_center)]

    def __get_result(self) -> List[List[List[float]]]:
        """
        Assign every sample to the cluster it voted for most often.

        Ties go to the cluster that was created first.

        :return: non-empty clusters as lists of samples
        """
        # votes[c][i] = how many times sample i was seen by cluster c
        votes = [Counter(members) for members in self.__cluster]
        result = [[] for _ in votes]
        for idx, point in enumerate(self.__nums):
            # max() returns the first maximum, so ties keep the lowest index
            best = max(range(len(votes)), key=lambda c: votes[c][idx])
            result[best].append(point.tolist())
        # Drop clusters that ended up with no samples
        return [group for group in result if group]
测试代码
import matplotlib
from matplotlib import pyplot as plt
from sklearn import datasets
from MeanShift import MeanShift
# Use a CJK-capable font so the Chinese axis labels render, and keep the
# minus sign displayable with that font.
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus'] = False
iris = datasets.load_iris()  # load the iris data set
# Re-run if the clustering looks poor; try a few times.
# band_width=1.34 is the value the author found to work well on iris
# (no further rationale is given).
mean_shift = MeanShift(nums=iris.data, band_width=1.34)
colors = ['red', 'green', 'blue', 'black', 'yellow']
clusters = mean_shift.mean_shift()
for i, cluster in enumerate(clusters):
    # Plot the first two features of every member in one call per cluster.
    xs = [point[0] for point in cluster]
    ys = [point[1] for point in cluster]
    # Cycle through the palette so more than len(colors) clusters cannot
    # raise IndexError.
    plt.scatter(xs, ys, c=colors[i % len(colors)])
plt.title("Mean-Shift")
plt.xlabel('萼片长度')
plt.ylabel('萼片宽度')
plt.show()
"""
___________.__ __ _____ _____ .__ ._.
\__ ___/| |__ _____ ____ | | __ ______ _/ ____\___________ / \ |__| ____| |
| | | | \\__ \ / \| |/ / / ___/ \ __\/ _ \_ __ \ / \ / \| |/ \ |
| | | Y \/ __ \| | \ < \___ \ | | ( <_> ) | \/ / Y \ | | \|
|____| |___| (____ /___| /__|_ \/____ > |__| \____/|__| \____|__ /__|___| /_
\/ \/ \/ \/ \/ \/ \/\/
"""
运行结果
标准答案
MeanShift算法
仅供学习和借鉴,如有错误,欢迎批评指正!