import sys
import math
class Item(object):
    """A named data point whose numeric coordinates are listed in ``field``.

    Equality and distance consider only the attributes named in ``field``;
    ``name`` is a display label and is ignored by both.
    """

    # Names of the numeric attributes used as coordinates.
    field = ["age", "h"]

    def __init__(self, name="", age=0.0, h=0.0):
        self.name = name
        self.age = age
        self.h = h

    def __eq__(self, other):
        """Return True if *other* is the same class and has equal coordinates."""
        if self.__class__ != other.__class__:
            return False
        return all(getattr(self, f, 0.0) == getattr(other, f, 0.0)
                   for f in self.field)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; define it explicitly.
        return not self == other

    def sub(self, other):
        """Return the Euclidean distance between self and *other*.

        The distance is computed over the attributes in ``field``; missing
        attributes count as 0.0.  The exact (untruncated) distance is
        returned so that nearest-centre comparisons are not distorted by ties.
        Objects of a different class are treated as infinitely far away.
        """
        if self.__class__ != other.__class__:
            return float("inf")
        return math.sqrt(sum(
            (getattr(self, f, 0.0) - getattr(other, f, 0.0)) ** 2
            for f in self.field))

    def __sub__(self, other):
        return self.sub(other)

    # Distance is symmetric, so the reflected operand order delegates to sub().
    __rsub__ = __sub__

    def __str__(self):
        """Render as ``name:field1:val1,field2:val2`` (name falls back to "None")."""
        coords = ",".join(["%s:%s" % (f, getattr(self, f)) for f in self.field])
        return (self.name or "None") + ":" + coords
class Kmeans(object):
    """K-means (Lloyd-style) clustering of ``objects`` into ``k`` groups.

    Objects must provide a ``field`` list of numeric attribute names,
    support ``a - b`` returning a distance, support ``==``, and be
    constructible with no arguments (used to build new centres).
    The first ``k`` objects are used as the initial centres.
    """

    def __init__(self, objects, k):
        self.objects = objects
        self.k = k
        # Slicing copies the list, so replacing a centre never mutates ``objects``.
        self.init_objects = objects[0: k]

    def com_put(self, max_iter=None):
        """Run k-means until the centres stop moving and return the clusters.

        :param max_iter: optional cap on assign/update rounds; ``None``
            (the default) iterates until convergence.  At least one round
            is always performed.
        :return: list of ``k`` lists; ``results[i]`` holds the objects that
            were assigned to centre ``i`` in the last round.
        :raises ValueError: if ``k`` exceeds the number of objects.
        """
        if self.k > len(self.init_objects):
            raise ValueError(
                "k=%d exceeds the number of objects (%d)"
                % (self.k, len(self.objects)))
        rounds = 0
        while True:
            results = self._assign()
            rounds += 1
            if not self._update_centers(results):
                return results
            if max_iter is not None and rounds >= max_iter:
                return results

    def _assign(self):
        """Group every object under the index of its nearest current centre."""
        results = [[] for _ in range(self.k)]
        for obj in self.objects:
            dists = {i: center - obj for i, center in enumerate(self.init_objects)}
            results[self.comput_order(dists)].append(obj)
        return results

    def _update_centers(self, results):
        """Move each centre to the mean of its cluster; return True if any moved."""
        changed = False
        for index, members in enumerate(results):
            if not members:
                # An empty cluster keeps its previous centre instead of
                # collapsing to a default-constructed (all-default) object.
                continue
            new_item = self.find_new_center(members)
            # ``not ==`` rather than ``!=``: objects may define only __eq__.
            if not new_item == self.init_objects[index]:
                changed = True
                self.init_objects[index] = new_item
        return changed

    def find_new_center(self, dists):
        """Return a new object whose coordinates are the mean of ``dists``.

        :param dists: the cluster members (a list of objects, despite the name).
        :return: a fresh object of the same class as ``self.objects[0]``;
            it keeps its default coordinates when ``dists`` is None or empty.
        """
        new_item = self.objects[0].__class__()
        if not dists:
            return new_item
        count = len(dists)
        for field in new_item.field:
            total = sum((getattr(item, field, 0.0) for item in dists), 0.0)
            setattr(new_item, field, total / count)
        return new_item

    def comput_order(self, dists):
        """Return the key of the shortest distance in ``dists``.

        :param dists: mapping of centre index -> distance.
        :return: the index with the smallest distance; ties go to the
            smallest index, and an empty mapping yields 0.
        """
        if not dists:
            return 0
        # Scanning keys in ascending order makes min() keep the lowest index
        # on ties, independent of dict iteration order or gaps in the keys.
        return min(sorted(dists), key=dists.get)
# Demo: cluster a handful of (age, h) points into k groups and print them.
# NOTE(review): every sample point is labelled "p1"; distinct labels
# (p1..p7) were probably intended — confirm before relying on the output.
l = [Item("p1", 5, 30), Item("p1", 30, 10), Item("p1", 21, 10),
     Item("p1", 25, 20), Item("p1", 66, 20.5), Item("p1", 15, 10),
     Item("p1", 21, 50)]
k = 3
results = Kmeans(l, k).com_put()
# Single-argument print(...) produces identical output on Python 2 and 3.
for i, x in enumerate(results):
    print("#####category(%s)#####" % i)
    for item in x:
        print(item)