程序运行时会跳出matplotlib窗体,并中断程序,关闭窗体后程序会继续执行。
# -*- coding:utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np
#下面两行,解决matplotlib中无法显示中文的问题
from pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei']
#数据定义,每一行为一个坐标,你也可以自行修改
data='''
1,1
2,1
4,5
6,6
5,4
3,3
2,2
3,2
5,6
1,3
3,1
6,5
''';
#整理数据
#[0,0]的意思在每个坐标后加两个值,每个样本数据均为4个值,[x,y,0,0]
#第1个为分类,
#第2个作用循环进行分类时为与前次分类比较,看是否发生变化,变化则为1,当所有样本均未变化时结束分类
data=[x.split(',')+[0,0] for x in data.split('\n')]
data=list(filter(lambda x: len(x)==4,data))
data=np.array(data).astype(np.float)
#print(data)
#取得中心点,选取前K个样本,每个样本分为一类
k=2 #在本中因为在matplotlib按类显示不同的点,只设了四个显示,分类数别太多了
centroids=data.copy()[:k,:-2]#-2表示只取样本的坐标
def draw(centroids,data,title):
plt.axis([round((np.min(data,axis=0)-1)[0])
,round((np.max(data,axis=0)+1)[0])
,round((np.min(data,axis=0)-1)[1])
,round((np.max(data,axis=0)+1)[1])]) # 用于定义X,Y轴的范围
plt.title(title)
for index,center in enumerate(centroids):
colorStr='rgby'[index:index+1] #在本中因为在matplotlib按类显示不同的点,只设了四个显示,分类数别太多了
centerData=np.array(list(filter(lambda x:x[-2]==index+1 ,data)))
if len(centerData)>0 :plt.scatter(centerData[:,0],centerData[:,1],c=colorStr)
plt.scatter(center[0],center[1],c=colorStr,marker='x')
plt.show()
runtimes=0
changePointLength=-1
while changePointLength!=0:
runtimes+=1
draw(centroids,data,'第 %d 次'%runtimes + (' 首次仅显示中心点,因为所有点的未分类' if runtimes==1 else ''))
#3.2 计算所有每个样本与中心点的距离,这样得到每个中心点有哪些样本
for dataItem in data:
distances=np.sqrt(((centroids-dataItem[:-2])**2).sum(axis=1))#计算每个点与每个中心点的距离
minDisType=np.argmin(distances)+1 #取得取小的距离的分类号
if dataItem[-2]==minDisType:
dataItem[-1]=0 #如果分类结果未发生变化
else :
dataItem[-1]=1 #如果分类结果发生变化
dataItem[-2]=minDisType
print(data)
#3.3 计算每个中心点所有样本的坐标的平均值,做为新的坐标
for index,center in enumerate(centroids):
centerData=np.array(list(filter(lambda x:x[-2]==index+1 ,data)))[:,:-2] #得到每中心点包含哪些点
centerData=centerData.mean(axis=0)
center[0]=centerData[0]
center[1]=centerData[1]
#print(data)
changePointLength=len(list(filter(lambda x:x[-1]==1 ,data))) #看有几个点的分类发生变化,如果为零,则退出循环