# python plot kmeans 演示版本 (demo version)

# -*- coding: utf-8 -*-

"""

Created on Sun Dec 09 14:43:47 2012

 

@author: Administrator

"""

import random

import matplotlib.pyplot as plt
import numpy
import pylab

# Demo data set: twenty 2-D points forming two visually separate groups,
# eight towards the origin and twelve around (8, 8).
docs = numpy.array([
    [0, 0], [1, 0], [0, 1], [1, 1], [2, 1], [1, 2], [2, 2], [3, 2],
    [6, 6], [7, 6], [8, 6], [6, 7], [7, 7], [8, 7], [9, 7],
    [7, 8], [8, 8], [9, 8], [8, 9], [9, 9],
])

# Per-cluster drawing styles. Both lists are indexed modulo their length
# (see cluster_style), so they never run out.
markers = ["o", "x"]
colors = ["r", "g", "b", "c"] * 4

# Number of clusters (and of initial center points) to use.
k = 2


def squared_distance(p, q):
    """Return the squared Euclidean distance between 2-D points p and q."""
    return (p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2


def cluster_style(idx, color_offset=0):
    """Return the (marker, color) pair used to draw cluster number idx.

    color_offset shifts the position in the palette so that successive
    refinement steps are drawn in different colors. Both lookups wrap
    around, so any non-negative idx / color_offset is valid.
    """
    marker = markers[idx % len(markers)]
    color = colors[(color_offset + idx) % len(colors)]
    return marker, color


def pick_initial_centers(points, count):
    """Choose `count` distinct rows of `points` at random as start centers.

    Prints the chosen indices, the full shuffled index order and the number
    of points, then returns the chosen rows as a list of (x, y) tuples.
    """
    order = list(range(len(points)))  # list(): a bare range cannot be shuffled
    random.shuffle(order)
    print("%s %s" % (order[:count], order))
    print(len(points))
    return [tuple(numpy.asarray(points[i]).tolist()) for i in order[:count]]


def assign_points(points, centers):
    """Group point indices by their nearest center.

    `centers` is any iterable of (x, y) tuples (a dict keyed by center works
    too). Returns a dict mapping every center to the set of indices of the
    points closest to it; on a tie the center listed first wins. Each point
    is printed as it is processed.
    """
    clusters = {center: set() for center in centers}
    for idx, point in enumerate(points):
        print("%s %s" % (idx, point))
        nearest = min(clusters, key=lambda center: squared_distance(point, center))
        clusters[nearest].add(idx)
    return clusters


def update_centers(points, clusters):
    """Return the new centers: the mean of the points in each cluster.

    `clusters` maps center -> set of point indices. The result is a list of
    (x, y) tuples in the same order as the clusters. A cluster without any
    members keeps its previous center instead of dividing by zero.
    """
    points = numpy.asarray(points)
    new_centers = []
    for center, members in clusters.items():
        if members:
            mean = points[sorted(members)].mean(axis=0)
            new_centers.append(tuple(mean.tolist()))
        else:
            new_centers.append(center)
    return new_centers


def has_converged(old_clusters, new_clusters):
    """Return True when both clusterings group the points identically.

    Only the member sets are compared, not the center coordinates. Comparing
    the partitions directly (rather than counting matches up to k) also
    terminates when two clusters have collapsed onto the same center.
    """
    old_groups = {frozenset(members) for members in old_clusters.values()}
    new_groups = {frozenset(members) for members in new_clusters.values()}
    return old_groups == new_groups


def draw_centers(centers, color_offset=0):
    """Draw each center as a large marker in its cluster style."""
    for idx, (x, y) in enumerate(centers):
        marker, color = cluster_style(idx, color_offset)
        pylab.scatter(x, y, marker=marker, s=500, linewidths=2, c=color)


def draw_clusters(points, clusters, color_offset=0):
    """Draw all points in black, then overlay every cluster in its style.

    Centers are drawn large (s=500) and member points smaller (s=200).
    """
    plt.plot(points[:, 0], points[:, 1], 'ko', label='line 3', linewidth=2)
    for idx, (center, members) in enumerate(clusters.items()):
        marker, color = cluster_style(idx, color_offset)
        pylab.scatter(center[0], center[1], marker=marker, s=500,
                      linewidths=2, c=color)
        for member in members:
            a, b = points[member]
            pylab.scatter(a, b, marker=marker, s=200, linewidths=2, c=color)


def main():
    """Run the k-means demo, showing one figure per step."""
    # Step 0: raw data plus the k randomly chosen starting centers.
    plt.plot(docs[:, 0], docs[:, 1], 'ko', label='line 3', linewidth=2)
    start = pick_initial_centers(docs, k)
    for x, y in start:
        print("%s %s" % (x, y))
    draw_centers(start)
    plt.show()

    centers = {center: set() for center in start}
    print(centers)

    # Step 1: first assignment of every point to its nearest center.
    centers = assign_points(docs, centers)
    print(centers)
    draw_clusters(docs, centers)
    plt.show()

    # Refinement: move the centers to the cluster means and reassign until
    # the grouping of the points no longer changes.
    color_offset = 0
    while True:
        old_centers = centers
        centers = assign_points(docs, update_centers(docs, old_centers))
        print(centers)

        draw_clusters(docs, centers, color_offset)
        # Redraw the moved centers in the next colors of the palette so they
        # stand out from the points; the offset wraps instead of consuming
        # the palette, so long runs cannot exhaust it.
        color_offset += k
        draw_centers(centers, color_offset)
        plt.show()

        if has_converged(old_centers, centers):
            break

    print("===" * 10)
    print(list(centers.values()))
    print(list(old_centers.values()))


if __name__ == "__main__":
    main()

 

# 你可能感兴趣的:(python)