# -*- coding: utf-8 -*-
"""
Created on Sun Dec 09 14:43:47 2012
@author: Administrator
"""
import pylab
import numpy
import matplotlib.pyplot as plt
# Sample data set: twenty 2-D integer points, one [x, y] pair per row.
# The rows fall into two visually separate groups, which the clustering
# code further down is expected to recover.
docs = numpy.array([
    # group near the origin (x in 0..3, y in 0..2)
    [0, 0], [1, 0], [0, 1], [1, 1],
    [2, 1], [1, 2], [2, 2], [3, 2],
    # group in the upper right (x and y in 6..9)
    [6, 6], [7, 6], [8, 6],
    [6, 7], [7, 7], [8, 7], [9, 7],
    [7, 8], [8, 8], [9, 8],
    [8, 9], [9, 9],
])
# Queue the raw samples for drawing as black circles ('k' = black, 'o' =
# circle marker).  Row 0 / row 1 of the transpose are the x / y columns.
plt.plot(docs.T[0], docs.T[1], 'ko', label='line 3', linewidth=2)
#plt.show()
#plt.plot([1,2,3], [1,2,3], 'go-', label='line 1', linewidth=2)
#plt.plot([1,2,3], [1,4,9], 'rs', label='line 2')
#plt.axis([0, 4, 0, 10])
#plt.legend()
#plt.show()
# Choose k starting centres.
#
# Module-level names defined here and used by the code that follows:
#   markers - marker style per cluster (cycled when k > len(markers))
#   colors  - colour per cluster (cycled when k > len(colors))
#   k       - number of clusters
#   centers - dict mapping a centre (x, y) tuple -> set of row indices of
#             ``docs`` assigned to that centre (empty sets at this point)
import random

markers = ["o", "x"]
colors = ["r", "g", "b", "c"] * 4
k = 2
centers = {}

# Shuffle all row indices and take the first k rows as starting centres.
# list(...) keeps this working where range() is not a mutable list.
li = list(range(len(docs)))
random.shuffle(li)
print("%s %s" % (li[:k], li))
print(len(docs))

for idx, elem in enumerate(li[:k]):
    x, y = docs[elem]
    centers[(x, y)] = set()
    print("%s %s" % (x, y))
    # Index modulo the list length so that raising k above the number of
    # available markers/colours reuses styles instead of raising IndexError.
    pylab.scatter(x, y,
                  marker=markers[idx % len(markers)],
                  s=500, linewidths=2,
                  c=colors[idx % len(colors)])

# NOTE(review): the reviewed copy of this script had lost its indentation;
# the figure is assumed to be shown once, after all k centres are drawn.
plt.show()
print(centers)
def _squared_distance(point, center):
    """Return the squared Euclidean distance between two 2-D positions."""
    dx = point[0] - center[0]
    dy = point[1] - center[1]
    return dx * dx + dy * dy


def _assign_to_nearest(points, centers):
    """Add every point's row index to the member set of its nearest centre.

    ``centers`` maps an (x, y) tuple to a set of indices into ``points`` and
    is updated in place.  When two centres are equally close, the one that
    comes first in the dict's iteration order wins.  Each index and point is
    printed as it is processed.
    """
    for idx, point in enumerate(points):
        print("%s %s" % (idx, point))
        nearest = min(centers,
                      key=lambda center: _squared_distance(point, center))
        centers[nearest].add(idx)


def _draw_clusters(points, centers, markers, colors, with_members=True):
    """Scatter each centre (size 500) and optionally its members (size 200).

    The i-th cluster in dict iteration order is drawn with ``markers[i]`` and
    ``colors[i]``; both lists are cycled, so having more clusters than styles
    reuses styles instead of raising IndexError.
    """
    for idx, (center, members) in enumerate(centers.items()):
        marker = markers[idx % len(markers)]
        color = colors[idx % len(colors)]
        cx, cy = center
        pylab.scatter(cx, cy, marker=marker, s=500, linewidths=2, c=color)
        if not with_members:
            continue
        for member in members:
            mx, my = points[member]
            pylab.scatter(mx, my, marker=marker, s=200, linewidths=2, c=color)


# First pass: assign every sample to the nearest starting centre and show it.
_assign_to_nearest(docs, centers)
print(centers)
plt.plot(docs[:, 0], docs[:, 1], 'ko', label='line 3', linewidth=2)
_draw_clusters(docs, centers, markers, colors)
plt.show()

# Iterate: move each centre to the mean of its members, reassign, redraw,
# and stop once the grouping no longer changes.
while True:
    old_centers = centers.copy()
    centers.clear()
    for old_center, members in old_centers.items():
        if members:
            new_center = tuple(docs[list(members)].mean(axis=0).tolist())
        else:
            # A cluster that lost all its members has no mean; keep its
            # previous position rather than dividing by zero (NaN centre).
            new_center = old_center
        centers[new_center] = set()

    _assign_to_nearest(docs, centers)
    print(centers)

    plt.plot(docs[:, 0], docs[:, 1], 'ko', label='line 3', linewidth=2)
    _draw_clusters(docs, centers, markers, colors)

    # Advance the palette by two so the freshly computed centres are drawn
    # over the clusters in the next pair of colours.  Rotating (rather than
    # slicing entries off) keeps the palette from running out after several
    # iterations.
    colors = colors[2:] + colors[:2]
    _draw_clusters(docs, centers, markers, colors, with_members=False)
    plt.show()

    # Converged when every previous member set reappears unchanged.  Testing
    # "all previous groups" instead of "exactly k groups" also terminates
    # when two clusters have collapsed onto the same centre.
    new_groups = list(centers.values())
    if all(group in new_groups for group in old_centers.values()):
        break

# NOTE(review): the reviewed copy of this script had lost its indentation;
# this summary is assumed to be printed once, after the loop has finished.
print("===" * 10)
print(list(centers.values()))
print(list(old_centers.values()))