想搞点有意思的东西，参考[K-means聚类算法](https://www.cnblogs.com/jerrylead/archive/2011/04/06/2006910.html)做了一些实现，记录如下。
# coding=UTF-8
import numpy as np
import math
import matplotlib.pyplot as plt
import copy
class Group(object):
    ''' A single cluster: a centre point plus the data points assigned to it.

    All attributes are public and are read/written directly by K_MEANS:
      center_ar -- the group's current centre (1-D numpy array)
      members   -- list of 1-D numpy arrays currently assigned to the group
    '''

    def __init__(self, center_ar=None, members=None):
        # Per-instance state. Class-level ``members = []`` would be a single
        # list shared by every Group object.
        self.center_ar = np.array([0, 0]) if center_ar is None else center_ar
        self.members = [] if members is None else members

    def UpCenter(self):
        ''' Move the centre to the member closest to the members' mean.

        The new centre is always an existing point: the current centre is
        kept unless some member is strictly closer to the mean. A group
        with no members keeps its centre unchanged (there is no mean).
        The resulting centre is printed.
        '''
        if len(self.members) > 0:
            mean = np.mean(self.members, axis=0)
            best = self.center_ar
            best_dist = Group.distance(self.center_ar, mean)
            for member in self.members:
                dist = Group.distance(member, mean)
                # strict '<': on ties the earlier candidate is kept
                if dist < best_dist:
                    best, best_dist = member, dist
            self.center_ar = best
        print(self.center_ar)

    @staticmethod
    def distance(ar_x, ar_y):
        ''' Return the Euclidean distance between two 1-D arrays.

        Raises ValueError if either input is not 1-D or if the two inputs
        have different lengths (a sentinel "distance" could be smaller than
        real distances and silently mis-assign points).
        '''
        ar_x = np.asarray(ar_x)
        ar_y = np.asarray(ar_y)
        if ar_x.ndim != 1 or ar_y.ndim != 1:
            raise ValueError("distance() expects 1-D arrays")
        if ar_x.shape != ar_y.shape:
            raise ValueError(
                "distance() got arrays of different length: %d vs %d"
                % (ar_x.shape[0], ar_y.shape[0]))
        diff = ar_x - ar_y
        return np.sqrt(np.sum(diff ** 2))
class K_MEANS(object):
    ''' Cluster points into a fixed number of groups by repeatedly assigning
    each point to its nearest group centre and then updating the centres.

    After construction:
      group_num   -- number of groups
      groups      -- list of Group objects (their members and centres)
      data_arrays -- the input points
    '''

    def __init__(self, groups=1, data_ars=None, max_times=10):
        ''' Create the groups and train for at most ``max_times`` rounds.

        groups    -- number of groups, must be >= 1
        data_ars  -- non-empty sequence of equal-length 1-D numpy arrays
        max_times -- maximum number of assign/update rounds

        Initial centres are the first ``groups`` points, cycling through the
        data when there are fewer points than groups. Training stops early
        once no centre changes between two consecutive rounds.

        Raises ValueError if ``groups`` < 1 or ``data_ars`` is empty.
        '''
        if data_ars is None:
            data_ars = []
        if groups < 1:
            raise ValueError("groups must be at least 1")
        if len(data_ars) == 0:
            raise ValueError("data_ars must not be empty")
        # Instance attributes: class-level lists would be shared by every
        # K_MEANS object, so a second instance would inherit earlier groups.
        self.group_num = groups
        self.groups = []
        self.data_arrays = data_ars
        for cnt in range(self.group_num):
            this_group = Group()
            # wrap around so we never index past the end of the data
            this_group.center_ar = data_ars[cnt % len(data_ars)]
            self.groups.append(this_group)
        now_centers = [group.center_ar for group in self.groups]
        for time in range(max_times):
            last_centers = copy.deepcopy(now_centers)
            self.UpdateOnce()
            now_centers = [group.center_ar for group in self.groups]
            if not K_MEANS.is_changed(now_centers, last_centers):
                # ``time`` is 0-based: time + 1 rounds have been run
                print("trained %d times" % (time + 1))
                break

    @staticmethod
    def is_changed(ars1, ars2):
        ''' Return True if the two lists of group centres differ.

        Centres are compared element-wise. Summing signed differences is not
        enough, because a move such as (+1, -1) sums to zero. Lists of
        different length count as changed.
        '''
        if len(ars1) != len(ars2):
            return True
        return any(not np.array_equal(a, b) for a, b in zip(ars1, ars2))

    def UpdateOnce(self):
        ''' Run one round: assign every point to its nearest group centre,
        then let each non-empty group recompute its centre.
        '''
        for group in self.groups:
            group.members = []
        for data_ar in self.data_arrays:
            # Start from the last group with a strict '<' below, so a point
            # that ties with the last group's distance stays in that group.
            short_group = self.groups[-1]
            short_dist = Group.distance(data_ar, short_group.center_ar)
            for group in self.groups:
                this_dist = Group.distance(data_ar, group.center_ar)
                if this_dist < short_dist:
                    short_dist, short_group = this_dist, group
            short_group.members.append(data_ar)
        for group in self.groups:
            # an empty group has no mean to move towards; keep its centre
            if len(group.members) > 0:
                group.UpCenter()
def _split_xy(points):
    ''' Return the first and second coordinates of ``points`` as two lists. '''
    return [p[0] for p in points], [p[1] for p in points]


# Demo data: x covers [0, 100) and [150, 250); the gap between the two
# ranges is what should separate the points into two groups.
x = list(range(100)) + list(range(150, 250))
y = []
ars = []
for t in x:
    # Seeding with the x value makes each point's noise reproducible.
    np.random.seed(t)
    scale = t % 10 + 1
    y.append(scale * np.exp(np.random.random()))
    ars.append(np.array([t, y[-1]]))

# Cluster the points into two groups.
k_mean = K_MEANS(2, ars)

first, second = k_mean.groups[0], k_mean.groups[1]
first_x, first_y = _split_xy(first.members)
second_x, second_y = _split_xy(second.members)

# Figure 1: clustered points. Each group's centre is drawn in the other
# group's colour so it stands out against its own members.
plt.figure(1)
plt.plot(
    first_x, first_y, 'b.',
    [first.center_ar[0]], [first.center_ar[1]], 'go',
    second_x, second_y, 'g.',
    [second.center_ar[0]], [second.center_ar[1]], 'bo',
)
# Figure 2: the raw, unclustered data for comparison.
plt.figure(2)
plt.plot(x, y, '.')
plt.show()