import random
import numpy as np
import matplotlib.pyplot as plt
# Number of samples drawn for each Gaussian cluster; the full data set
# therefore holds 3 * nn rows.
nn = 100
def get_clusters(n_points=None):
    """Draw three 2-D Gaussian blobs and stack them into a single array.

    The blobs are centred at (0, 0), (1.25, 1.25) and (-1.25, 1.25) and all
    share the diagonal covariance ``[[0.1, 0], [0, 0.1]]``.  Rows are ordered
    blob by blob, in the order the centres are listed above.

    Args:
        n_points: Number of samples to draw per blob.  When omitted, the
            module-level ``nn`` is used, so zero-argument calls behave as
            they always have.

    Returns:
        A float array of shape ``(3 * n_points, 2)`` with every value
        rounded to four decimals.
    """
    if n_points is None:
        n_points = nn

    means = ([0, 0], [1.25, 1.25], [-1.25, 1.25])
    cov = [[0.1, 0], [0, 0.1]]

    # Draw the blobs in a fixed order so that a seeded global NumPy generator
    # keeps producing exactly the same rows.
    blobs = [np.random.multivariate_normal(mean, cov, n_points) for mean in means]

    # One concatenate instead of repeated np.append, which copies every time.
    return np.round(np.concatenate(blobs, axis=0), 4)
# Generate the sample array once; the clustering and plotting code below
# all operate on this array.
data = get_clusters()
"""
def show_scatter(data):
x,y = data.T
plt.plot(x[:nn],y[:nn],'b+');
plt.plot(x[nn:2*nn],y[nn:2*nn],'r+');
plt.plot(x[2*nn:3*nn],y[2*nn:3*nn],'g+');
plt.axis()
plt.title("scatter")
plt.xlabel("x")
plt.ylabel("y")
show_scatter(data)
"""
# Number of clusters the algorithm is asked to find.
k = 4
# Initialise the centres with k distinct samples picked at random from the
# data set (row indices are drawn without replacement).
point_ind = random.sample(range(3 * nn), k)
center_point = data[point_ind]
# Lloyd's k-means iteration: alternately assign every sample to its nearest
# centre and move every centre to the mean of its samples, until each of the
# k centres moves by less than 0.1 in a single update.
while True:
    # Assignment step.  Broadcasting produces the (n_samples, k) matrix of
    # Euclidean distances between every sample and every centre.
    distance = np.linalg.norm(
        data[:, np.newaxis, :] - center_point[np.newaxis, :, :], axis=2
    )
    # argmin returns the first minimum, so ties go to the lowest centre
    # index.  The labels are kept as a column vector of shape (n_samples, 1).
    data_label = np.argmin(distance, axis=1).reshape(-1, 1)

    # Update step.
    flat_label = data_label.ravel()
    number_label = np.bincount(flat_label, minlength=k)
    new_center_point = np.zeros((k, data.shape[1]))
    m = 0  # how many centres moved by less than the tolerance
    for i in range(k):
        if number_label[i] > 0:
            new_center_point[i, :] = (
                data[flat_label == i].sum(axis=0) / number_label[i]
            )
        else:
            # An empty cluster has no mean.  Dividing by zero would turn the
            # centre into NaN, which never passes the convergence test and
            # would keep this loop running forever; leave the centre in place.
            new_center_point[i, :] = center_point[i, :]
        if np.linalg.norm(new_center_point[i, :] - center_point[i, :]) < 0.1:
            m += 1

    if m == k:
        # Converged: center_point keeps the centres that produced data_label.
        break
    center_point = new_center_point
# Plot the clustering result: every cluster in its own colour, with the
# cluster centres drawn as black dots.
plt.figure(1)
# Colours are reused cyclically so that clusters are still drawn when k is
# larger than the number of colours listed here.
cluster_colors = ['r', 'g', 'b', 'y']
# Accept the labels either as a column vector or as a flat array.
flat_labels = np.ravel(data_label)
for i in range(k):
    members = data[flat_labels == i]
    # One call per cluster instead of one call per sample.
    plt.plot(members[:, 0], members[:, 1],
             cluster_colors[i % len(cluster_colors)] + '*')
# Draw each centre exactly once, after the samples, so it stays on top.
plt.plot(center_point[:, 0], center_point[:, 1], 'ko')
plt.axis()
plt.title("scatter")
plt.xlabel("x")
plt.ylabel("y")
plt.show()
"""
from sklearn.cluster import KMeans # import k-means
km = KMeans(n_clusters=k)
# fit the model to the data
km.fit(data)
# predict the cluster label of every sample
y_predict = km.predict(data)
# get the cluster centres
center = km.cluster_centers_
"""
# Results
