西瓜书《机器学习》课后答案——chapter9 _9.4

编程实现 k 均值算法，设置三组不同的 k 值、三组不同的初始中心点，在西瓜数据集 4.0 上进行实验比较，并讨论什么样的初始中心有助于得到好结果。

下面所有图中的横坐标表示密度,纵坐标表示含糖率。

首先看看4.0数据集:
西瓜书《机器学习》课后答案——chapter9 _9.4_第1张图片

代码

#-*- coding:utf-8 -*-
""" @Author: Victoria @Date: 2017.10.24 12:00 """
import random
import matplotlib.pyplot as plt
import xlrd
from copy import deepcopy

# Marker characters, one per cluster index, used to build matplotlib
# format strings such as "*k".
style = "*+o."
# Single-letter matplotlib colour codes, one per cluster index.  Red ('r')
# is deliberately absent: plot_clusters uses it to highlight samples that
# were chosen as initial centers.  The last entry used to be 'p', which is
# not a matplotlib colour abbreviation; 'm' (magenta) is the valid code.
color = "kgybm"

class KMeans(object):
    """k-means clustering on samples given as lists of numbers.

    The class alternates between assigning every sample to its nearest
    center and moving every center to the mean of its samples, until the
    centers stop moving.  Figures are written below ``figures/`` through
    the module-level ``plt``; marker and colour are taken from the
    module-level ``style`` and ``color`` strings.
    """

    def __init__(self, k):
        """Store the number of clusters ``k``."""
        self.k = k

    def train(self, X, init_centers=None):
        """Cluster ``X`` and store the result in ``self.clusters_X``.

        Args:
            X: non-empty sequence of samples; each sample is a sequence of
                numbers and all samples have the length of ``X[0]``.
            init_centers: optional sequence of ``k`` starting centers.  When
                omitted, ``k`` distinct samples are drawn at random by
                ``init()``.

        Raises:
            ValueError: if ``X`` is empty, if ``init_centers`` does not hold
                exactly ``k`` centers, or (from ``random.sample``) if ``k``
                exceeds the number of samples.
        """
        if not X:
            raise ValueError("X must contain at least one sample")
        self.N = len(X)
        self.d = len(X[0])
        self.X = X

        if init_centers is None:
            self.init()
        else:
            if len(init_centers) != self.k:
                raise ValueError(
                    "expected {} initial centers, got {}".format(
                        self.k, len(init_centers)))
            # Copy so that the caller's lists are never shared or modified.
            self.centers = [list(center) for center in init_centers]
        # Remembered so plot_clusters can highlight the starting points.
        self.init_centers = deepcopy(self.centers)
        print(self.centers)

        Js = []
        iteration = 0
        while True:
            iteration += 1
            print("iteration: {}".format(iteration))

            clusters_X = self._assign_clusters()
            self.plot_clusters(iteration, clusters_X)

            old_centers = deepcopy(self.centers)
            self._update_centers(clusters_X)
            Js.append(self.J_cost(clusters_X))

            # Total squared movement of all centers during this iteration.
            diff = sum(self.dist(old, new)
                       for old, new in zip(old_centers, self.centers))
            if diff < 1e-3:
                self.clusters_X = clusters_X
                break

        print("iter:  {}".format(iteration))
        # The converged assignment was already plotted inside the loop for
        # this iteration, so only the cost curve remains to be drawn.
        self.plot_J(Js)

    def _assign_clusters(self):
        """Return ``{cluster index: [samples]}`` for the current centers."""
        clusters_X = dict((k, []) for k in range(self.k))
        for x in self.X:
            clusters_X[self.cluster_sample(x)].append(x)
        return clusters_X

    def _update_centers(self, clusters_X):
        """Move every center to the mean of the samples assigned to it.

        A cluster that received no samples keeps its previous center
        instead of triggering a division by zero in ``compute_center``.
        """
        for k in range(self.k):
            if clusters_X[k]:
                self.centers[k] = self.compute_center(clusters_X[k])

    def plot_J(self, Js):
        """Save the per-iteration cost curve to ``figures/k=<k>_cost.png``."""
        plt.figure()
        plt.plot(list(range(len(Js))), Js)
        plt.savefig("figures/k={}_cost.png".format(self.k))
        # Release the figure; otherwise every call leaves one open.
        plt.close()

    def plot_clusters(self, iter, clusters_X):
        """Save a scatter plot of the clusters for iteration ``iter``.

        Only the first two coordinates of each sample are drawn.  Samples
        equal to one of the initial centers are drawn in red; all others
        use the colour of their cluster.
        """
        plt.figure()
        for k in range(self.k):
            # Cycle through the available markers/colours so that k may be
            # larger than the number of entries in ``style``.
            marker = style[k % len(style)]
            regular = marker + color[k % len(color)]
            for x in clusters_X[k]:
                fmt = (marker + 'r') if x in self.init_centers else regular
                plt.plot(x[0], x[1], fmt)
        plt.savefig("figures/k={}_cluster_iter{}.png".format(self.k, iter))
        # Release the figure; train() calls this once per iteration.
        plt.close()

    def predict(self):
        """Placeholder; prediction is not implemented."""
        pass

    def init(self):
        """Pick ``k`` distinct samples of ``self.X`` as the initial centers.

        Sampling without replacement prevents the same sample from being
        chosen twice, which would leave one cluster empty from the start.
        """
        self.centers = random.sample(self.X, self.k)

    def compute_center(self, X):
        """Return the coordinate-wise mean of the samples in ``X``.

        Raises ZeroDivisionError if ``X`` is empty.
        """
        count = float(len(X))
        return [sum(x[i] for x in X) / count for i in range(self.d)]

    def cluster_sample(self, x):
        """Return the index of the center closest to ``x``.

        Ties are resolved in favour of the lowest index.
        """
        return min(range(self.k),
                   key=lambda k: self.dist(x, self.centers[k]))

    def dist(self, x, y):
        """Return the squared Euclidean distance between ``x`` and ``y``.

        No square root is taken; only the first ``self.d`` coordinates are
        compared.
        """
        return sum((x[i] - y[i]) ** 2 for i in range(self.d))

    def J_cost(self, clusters_X):
        """Return the sum of squared distances of samples to their center."""
        return sum(self.dist(x, self.centers[k])
                   for k in range(self.k)
                   for x in clusters_X[k])

def main(k=2, data_path="4.0.xlsx", n_samples=30):
    """Load the data set, plot the raw samples and run k-means on them.

    Args:
        k: number of clusters handed to ``KMeans``.
        data_path: workbook to read; it must contain a sheet named
            ``Sheet1`` in which every column is one sample and the first
            two rows are the two features.
        n_samples: number of columns (samples) to read.
    """
    import os

    # NOTE(review): xlrd 2.0+ reads only .xls files; opening this .xlsx
    # requires xlrd < 2.0 -- confirm the installed version.
    workbook = xlrd.open_workbook(data_path)
    sheet = workbook.sheet_by_name('Sheet1')
    X = [sheet.col_values(i)[0:2] for i in range(n_samples)]

    # savefig does not create missing directories.
    if not os.path.isdir("figures"):
        os.makedirs("figures")

    plt.figure()
    for x in X:
        plt.plot(x[0], x[1], 'k.')
    plt.savefig("figures/samples.png")
    plt.close()

    k_means = KMeans(k=k)
    k_means.train(X)

# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()

k=2时:(图中红色表示初始中心点)
西瓜书《机器学习》课后答案——chapter9 _9.4_第2张图片

k=3时:用书中的初始化方法,但是得到的结果有一个点不一样(欢迎大家来找茬)
西瓜书《机器学习》课后答案——chapter9 _9.4_第3张图片

k=4时:
西瓜书《机器学习》课后答案——chapter9 _9.4_第4张图片

问题:
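如果某次迭代时，某个类簇为空怎么办？
常见的处理办法有两种：保留该簇上一轮的中心不变；或者重新选取一个样本（例如距离其所属簇中心最远的样本）作为该簇的新中心。另外，初始化时选取 k 个互不相同的样本作为初始中心，可以避免因初始中心重复而在第一轮就出现空簇。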

你可能感兴趣的:(机器学习)