A minimal Python implementation of the K-means algorithm, with graphical visualization

  • A brief introduction to K-means:

The basic idea of K-means is: start with K randomly chosen cluster centers, assign each sample point to its nearest center, then recompute each cluster's center as the mean of the points assigned to it. Iterate until the distance the centers move between two consecutive iterations falls below a chosen threshold.

  • Basic steps

    1. Start with a dataset whose cluster assignments are unknown
    2. Initialize the cluster centers by randomly selecting K data points
    3. Compute each point's distance to every center and assign the point to its nearest center
    4. Compute the mean of each cluster and take that mean as the cluster's new center
    5. Repeat steps 3 and 4 until the centers converge (the two repeated steps are formalized right after this list)
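In symbols (standard K-means notation, not from the original post), each iteration performs an assignment step followed by an update step:

$$c_i \leftarrow \arg\min_k \lVert x_i - \mu_k \rVert^2, \qquad \mu_k \leftarrow \frac{1}{\lvert S_k \rvert} \sum_{x_i \in S_k} x_i$$

where $S_k = \{\, x_i : c_i = k \,\}$ is the set of points currently assigned to center $\mu_k$.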
  • Data:

    1. The data is the adult dataset from CSIE; its main attributes are:
age: continuous.
workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
fnlwgt: continuous.
education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
education-num: continuous.
marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
sex: Female, Male.
capital-gain: continuous.
capital-loss: continuous.
hours-per-week: continuous.
native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.
    2. The categorical attributes need to be mapped to numeric codes so that distances can be computed (a quick sanity check follows the mappings):
convert_work = {
    b'workclass':0.0,
    b'Private':1.0,
    b'Self-emp-not-inc':2.0,
    b'Self-emp-inc':3.0,
    b'Federal-gov':4.0,
    b'Local-gov':5.0,
    b'State-gov':6.0,
    b'Without-pay':7.0,
    b'Never-worked':8.0,
}
marry_convert = {
    b"marital-status":1.0,
    b"Married-civ-spouse":2.0,
    b"Divorced":3.0,
    b"Never-married":4.0,
    b"Separated":5.0,
    b"Widowed":6.0,
    b"Married-spouse-absent":7.0,
    b"Married-AF-spouse":8.0,
}
occupation_convert={
    b"Tech-support":1.0,
    b"Craft-repair":2.0,
    b"Other-service":3.0,
    b"Sales":4.0,
    b"Exec-managerial":5.0,
    b"Prof-specialty":6.0,
    b"Handlers-cleaners":7.0,
    b"Machine-op-inspct":8.0,
    b"Adm-clerical":9.0,
    b"Farming-fishing":10.0,
    b"Transport-moving":11.0,
    b"Priv-house-serv":12.0,
    b"Protective-serv":13.0,
    b"Armed-Forces":14.0,
}

relationship_convert= {
    b"Wife":1.0,
    b"Own-child":2.0,
    b"Husband":3.0,
    b"Not-in-family":4.0,
    b"Other-relative":5.0,
    b"Unmarried" :6.0,
}

race_convert = {
    b"White":1.0,
    b"Asian-Pac-Islander":2.0,
    b"Amer-Indian-Eskimo":3.0,
    b"Other":4.0,
    b"Black":5.0,
}

sex_convert = {
    b"Female":0.0,
    b"Male":1.0,
}
country_convert = {
    b"United-States":1.0,
    b"Cambodia":2.0,
    b"England":3.0,
    b"Puerto-Rico":4.0,
    b"Canada":5.0,
    b"Germany":6.0,
    b"Outlying-US(Guam-USVI-etc)":7.0,
    b"India":8.0,
    b"Japan":9.0,
    b"Greece":10.0,
    b"South":11.0,
    b"China":12.0,
    b"Cuba":13.0,
    b"Iran":14.0,
    b"Honduras":15.0,
    b"Philippines":16.0,
    b"Italy":17.0,
    b"Poland":18.0,
    b"Jamaica":19.0,
    b"Vietnam":20.0,
    b"Mexico":21.0,
    b"Portugal":22.0,
    b"Ireland":23.0,
    b"France":24.0,
    b"Dominican-Republic":25.0,
    b"Laos":26.0,
    b"Ecuador":27.0,
    b"Taiwan":28.0,
    b"Haiti":29.0,
    b"Columbia":30.0,
    b"Hungary":31.0,
    b"Guatemala":32.0,
    b"Nicaragua":33.0,
    b"Scotland":34.0,
    b"Thailand":35.0,
    b"Yugoslavia":36.0,
    b"El-Salvador":37.0,
    b"Trinadad&Tobago":38.0,
    b"Peru":39.0,
    b"Hong":40.0,
    b"Holand-Netherlands.":41.0,
    b"":1.0
}
class_convert = {
    b">50K":0.0,
    b"<=50K":1.0,
    b"":1.0
}

convert_education={
    b"Bachelors":0.0,
    b"Some-college":1.0,
    b"11th":2.0,
    b"HS-grad":3.0,
    b"Prof-school":4.0,
    b"Assoc-acdm":5.0,
    b"Assoc-voc":6.0,
    b"9th":7.0,
    b"7th-8th":8.0,
    b"12th":9.0,
    b"Masters":10.0,
    b"1st-4th":11.0,
    b"10th":12.0,
    b"Doctorate":13.0,
    b"5th-6th":14.0,
    b"Preschool":15.0,
    b"":1.0
}
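A quick sanity check of the mappings (a minimal sketch; the byte-string keys match what np.genfromtxt passes to converter functions when it reads the file in binary mode):

print(convert_work.get(b'Private', 1.0))      # -> 1.0
print(country_convert.get(b'Atlantis', 1.0))  # -> 1.0, unseen values fall back to the default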
  • Reading the file:
import numpy as np

def read_file(file, one_hot=False):
    # autostrip=True removes the whitespace around the comma-separated fields
    # in adult.data; without it every value carries a leading space, none of
    # the byte-string keys above match, and every converter falls back to its
    # default value.
    np_file = np.genfromtxt(fname=file, delimiter=',', autostrip=True,
                            converters={
                                1: lambda x: convert_work.get(x, 1.0),
                                3: lambda x: convert_education.get(x, 7.0),
                                5: lambda x: marry_convert.get(x, 4.0),
                                6: lambda x: occupation_convert.get(x, 3.0),
                                7: lambda x: relationship_convert.get(x, 5.0),
                                8: lambda x: race_convert.get(x, 4.0),
                                9: lambda x: sex_convert.get(x, 1.0),
                                13: lambda x: country_convert.get(x, 1.0),
                                14: lambda x: class_convert.get(x, 1.0)})
    if one_hot:
        raise NotImplementedError('one-hot encoding is not implemented yet')
    return np_file
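For example (assuming ./adult.txt is the comma-separated UCI adult.data file with 15 columns, as used in the main block further below):

data = read_file('./adult.txt')
print(data.shape)  # (number_of_rows, 15)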
  • The algorithm code is shown below:
def computeDis(x, y):
    # Row-wise Euclidean distance between x and y (broadcasts over rows).
    sub_ = x - y
    dist = np.sqrt(np.sum(np.power(sub_, 2), axis=1))
    return dist

def initClusterPoint(data, k):
    # Randomly pick k distinct rows as the initial cluster centers;
    # replace=False avoids choosing the same point twice.
    centers = data[np.random.choice(data.shape[0], k, replace=False)]
    return centers

def getClusterPoint(data, k, centers):
    rows = data.shape[0]
    all_distance = np.empty([rows, k])
    for index, center in enumerate(centers):
        all_distance[:, index] = computeDis(data, center)
    # Assign each point to its nearest center.
    small_dis_index = all_distance.argmin(axis=1)
    # Update each center as the mean of the points assigned to it.
    new_centers = np.empty(centers.shape)
    for index, center in enumerate(centers):
        members = data[small_dis_index == index]
        if len(members) == 0:
            # An empty cluster keeps its old center instead of producing NaN.
            new_centers[index] = center
        else:
            new_centers[index] = members.mean(axis=0)
    return small_dis_index, new_centers
if __name__ == "__main__":
    data = read_file('./adult.txt')
    k = 5
    tempDist = 1.0
    convergeDist = 0.01
    init_center = initClusterPoint(data=data, k=k)
    data_cluster = np.empty(data.shape[0])
    while tempDist > convergeDist:
        data_cluster, centers = getClusterPoint(data=data, k=k, centers=init_center)
        # Converged once no center moved farther than convergeDist.
        tempDist = computeDis(init_center, centers).max()
        init_center = centers
    draw(Data=data, centers=init_center, index=data_cluster)
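Because the initial centers are chosen at random, the final clusters and the number of iterations vary between runs. Seeding NumPy's random generator before initClusterPoint (a minimal tweak, not in the original post) makes a run reproducible:

np.random.seed(42)  # any fixed seed works; makes the random center choice deterministic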
  • Plotting:
import itertools
import matplotlib.pyplot as plt
from sklearn import decomposition

# The original code used an undefined `palette`; a cycling iterator over
# matplotlib's tab10 colors is one reasonable definition.
palette = itertools.cycle(plt.cm.tab10.colors)

def draw(Data, centers, index=None):
    """
    The features are 15-dimensional, so PCA projects them to 2-D for display;
    the plot illustrates the clustering, not the actual data distribution.
    """
    fig, axes = plt.subplots(1, 2)
    fig.suptitle('Kmeans classifier of Adult')
    pca = decomposition.PCA(n_components=2)
    new_data = pca.fit_transform(Data)
    # Left panel: the raw projection; right panel: points colored by cluster.
    axes[0].scatter(new_data[:, 0], new_data[:, 1], marker='o', alpha=0.5)
    for center_index, center in enumerate(centers):
        cluster_points = new_data[index == center_index]
        axes[1].scatter(cluster_points[:, 0], cluster_points[:, 1],
                        marker='o', color=next(palette), alpha=0.5)
    plt.show()
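As a sanity check of the hand-rolled loop, the same plot can be produced with scikit-learn's KMeans (a sketch, assuming scikit-learn is installed; not part of the original post):

from sklearn.cluster import KMeans

km = KMeans(n_clusters=5, n_init=10).fit(data)
draw(Data=data, centers=km.cluster_centers_, index=km.labels_)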
  • The result:
[Figure: kmeans.png — left: 2-D PCA projection of the Adult data; right: the same points colored by K-means cluster]
  • The full code is at:
