- Kmeans算法简介:
K-means 算法的基本思想是:先随机选定 K 个簇中心,按照距离最近的原则把每个待分类的样本点分到对应的簇,然后用簇内样本的平均值计算新的簇心(质心)。如此反复迭代,直到相邻两次迭代之间簇心的移动距离小于给定的阈值。
-
基本步骤:
1. 输入:簇心未知的数据集
2. 初始化簇心:随机选取 K 个样本作为初始簇心
3. 计算各个点到每个簇心的距离,并把该点归入离它最近的簇心所在的簇
4. 计算每个簇内所有点的平均值,并将这个平均值作为新的簇心
5. 重复步骤 3 和步骤 4,直到两次迭代之间簇心的移动距离小于阈值
-
数据:
- 数据来源于 CSIE 提供的 adult 数据集,各字段及其取值如下:
age: continuous.
workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
fnlwgt: continuous.
education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
education-num: continuous.
marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
sex: Female, Male.
capital-gain: continuous.
capital-loss: continuous.
hours-per-week: continuous.
native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.
- 需要先把类别型字段映射(mapping)为数值,才能计算样本之间的距离:
# Lookup tables that turn the Adult dataset's categorical fields into numbers
# so that Euclidean distances can be computed. Keys are bytes because the
# genfromtxt converters in read_file look values up as bytes. Values missing
# from a table fall back to the per-column default chosen in read_file.

# NOTE(review): b'workclass' / b"marital-status" look like header-row labels.
convert_work = {
    b'workclass': 0.0,
    b'Private': 1.0,
    b'Self-emp-not-inc': 2.0,
    b'Self-emp-inc': 3.0,
    b'Federal-gov': 4.0,
    b'Local-gov': 5.0,
    b'State-gov': 6.0,
    b'Without-pay': 7.0,
    b'Never-worked': 8.0,
}
marry_convert = {
    b"marital-status": 1.0,
    b"Married-civ-spouse": 2.0,
    b"Divorced": 3.0,
    b"Never-married": 4.0,
    b"Separated": 5.0,
    b"Widowed": 6.0,
    b"Married-spouse-absent": 7.0,
    b"Married-AF-spouse": 8.0,
}
occupation_convert = {
    b"Tech-support": 1.0,
    b"Craft-repair": 2.0,
    b"Other-service": 3.0,
    b"Sales": 4.0,
    b"Exec-managerial": 5.0,
    b"Prof-specialty": 6.0,
    b"Handlers-cleaners": 7.0,
    b"Machine-op-inspct": 8.0,
    b"Adm-clerical": 9.0,
    b"Farming-fishing": 10.0,
    b"Transport-moving": 11.0,
    b"Priv-house-serv": 12.0,
    b"Protective-serv": 13.0,
    b"Armed-Forces": 14.0,
}
relationship_convert = {
    b"Wife": 1.0,
    b"Own-child": 2.0,
    b"Husband": 3.0,
    b"Not-in-family": 4.0,
    b"Other-relative": 5.0,
    b"Unmarried": 6.0,
}
race_convert = {
    b"White": 1.0,
    b"Asian-Pac-Islander": 2.0,
    b"Amer-Indian-Eskimo": 3.0,
    b"Other": 4.0,
    b"Black": 5.0,
}
sex_convert = {
    b"Female": 0,
    b"Male": 1,
}
country_convert = {
    b"United-States": 1.0,
    b"Cambodia": 2.0,
    b"England": 3.0,
    b"Puerto-Rico": 4.0,
    b"Canada": 5.0,
    b"Germany": 6.0,
    b"Outlying-US(Guam-USVI-etc)": 7.0,
    b"India": 8.0,
    b"Japan": 9.0,
    b"Greece": 10.0,
    b"South": 11.0,
    b"China": 12.0,
    b"Cuba": 13.0,
    b"Iran": 14.0,
    b"Honduras": 15.0,
    b"Philippines": 16.0,
    b"Italy": 17.0,
    b"Poland": 18.0,
    b"Jamaica": 19.0,
    b"Vietnam": 20.0,
    b"Mexico": 21.0,
    b"Portugal": 22.0,
    b"Ireland": 23.0,
    b"France": 24.0,
    b"Dominican-Republic": 25.0,
    b"Laos": 26.0,
    b"Ecuador": 27.0,
    b"Taiwan": 28.0,
    b"Haiti": 29.0,
    b"Columbia": 30.0,
    b"Hungary": 31.0,
    b"Guatemala": 32.0,
    b"Nicaragua": 33.0,
    b"Scotland": 34.0,
    b"Thailand": 35.0,
    b"Yugoslavia": 36.0,
    b"El-Salvador": 37.0,
    b"Trinadad&Tobago": 38.0,
    b"Peru": 39.0,
    b"Hong": 40.0,
    # The trailing "." in the field listing is sentence punctuation, not part
    # of the value; the old key b"Holand-Netherlands." could never match.
    b"Holand-Netherlands": 41.0,
    b"": 1.0,  # empty field is treated like United-States
}
class_convert = {
    b">50K": 0.0,
    b"<=50K": 1.0,
    b"": 1.0,
}
convert_education = {
    b"Bachelors": 0.0,
    b"Some-college": 1.0,
    b"11th": 2.0,
    b"HS-grad": 3.0,
    b"Prof-school": 4.0,
    b"Assoc-acdm": 5.0,
    b"Assoc-voc": 6.0,
    b"9th": 7.0,
    b"7th-8th": 8.0,
    b"12th": 9.0,
    b"Masters": 10.0,
    b"1st-4th": 11.0,
    b"10th": 12.0,
    b"Doctorate": 13.0,
    b"5th-6th": 14.0,
    b"Preschool": 15.0,
    b"": 1.0,
}
- 读取文件:
def read_file(file, one_hot=False):
    """Load the comma-separated Adult data file into a float ndarray.

    Categorical columns (1, 3, 5, 6, 7, 8, 9, 13, 14) are mapped to numbers
    with the module-level lookup tables; a value missing from its table falls
    back to a fixed per-column default. The remaining columns are parsed as
    floats by ``np.genfromtxt`` (unparseable fields become NaN).

    Args:
        file: path (or file-like object) accepted by ``np.genfromtxt``.
        one_hot: reserved for one-hot encoding of categorical columns.

    Returns:
        The parsed 2-D array.

    Raises:
        NotImplementedError: if ``one_hot`` is true. Previously this case
            silently returned ``None``.
    """
    if one_hot:
        raise NotImplementedError("one_hot encoding is not implemented yet")

    def _lookup(table, default):
        # Build a converter for one column. Depending on the NumPy version /
        # encoding, genfromtxt may pass bytes or str, while the tables are
        # keyed by bytes, so normalise to bytes. Also strip surrounding
        # whitespace, in case fields are separated by ", " rather than ",".
        def convert(value):
            if isinstance(value, str):
                value = value.encode("utf-8")
            return table.get(value.strip(), default)
        return convert

    return np.genfromtxt(
        fname=file,
        delimiter=',',
        replace_space='',
        converters={
            1: _lookup(convert_work, 1.0),
            3: _lookup(convert_education, 7.0),
            5: _lookup(marry_convert, 4.0),
            6: _lookup(occupation_convert, 3.0),
            7: _lookup(relationship_convert, 5.0),
            8: _lookup(race_convert, 4.0),
            9: _lookup(sex_convert, 1.0),
            13: _lookup(country_convert, 1.0),
            14: _lookup(class_convert, 1.0),
        },
    )
- 算法代码如下所示:
def computeDis(x, y):
    """Return the Euclidean distance between ``x`` and ``y`` along the last axis.

    NumPy broadcasting applies, so typical uses are:
      * ``x`` of shape (n, d) and ``y`` of shape (d,) -> n distances to one point;
      * ``x`` and ``y`` both (k, d) -> k row-wise distances;
      * ``x`` and ``y`` both (d,) -> a scalar distance. The previous
        ``axis=1`` version raised on this case.
    """
    diff = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    return np.sqrt(np.sum(diff ** 2, axis=-1))
def initClusterPoint(data, k):
    """Pick ``k`` distinct rows of ``data`` at random as initial cluster centers.

    Args:
        data: 2-D array of samples (one sample per row).
        k: number of clusters.

    Returns:
        A new array of shape (k, data.shape[1]) holding the chosen rows.

    Raises:
        ValueError: if ``k`` is larger than the number of rows.
    """
    n_rows = data.shape[0]
    if k > n_rows:
        raise ValueError(
            "k=%d exceeds the number of samples (%d)" % (k, n_rows))
    # Sample row indices WITHOUT replacement. randint could pick the same row
    # twice, giving identical centers and hence an empty cluster later.
    # (Rows that are themselves duplicates in data can still coincide.)
    picked = np.random.choice(n_rows, size=k, replace=False)
    return data[picked]
def getClusterPoint(data, k, centers):
    """Perform one k-means iteration: assign points, then recompute centers.

    Args:
        data: 2-D array of samples (one sample per row).
        k: number of clusters; must equal ``len(centers)``.
        centers: 2-D array of the current cluster centers.

    Returns:
        (labels, new_centers): ``labels[i]`` is the index of the center
        nearest to ``data[i]``. ``new_centers[j]`` is the mean of the points
        assigned to cluster j, or the unchanged ``centers[j]`` if no point was
        assigned to it.
    """
    rows = data.shape[0]
    all_distance = np.empty([rows, k])
    for index, center in enumerate(centers):
        all_distance[:, index] = computeDis(data, center)
    labels = all_distance.argmin(axis=1)

    # Start from a float copy of the old centers so that an empty cluster
    # keeps its position. Previously it computed 0/0 = NaN, which made every
    # later distance NaN and stopped the caller's convergence loop early.
    new_centers = np.array(centers, dtype=float)
    for index in range(len(centers)):
        members = data[labels == index]
        if members.shape[0] > 0:
            new_centers[index] = members.mean(axis=0)
    return labels, new_centers
if __name__ == "__main__":
    data = read_file('./adult.txt')
    # Rows that fail to parse come back with NaN in numeric columns; this is
    # presumably the header line, judging by the b'workclass' /
    # b"marital-status" keys in the lookup tables. A single NaN row turns its
    # cluster's center into NaN, so drop such rows before clustering.
    data = data[~np.isnan(data).any(axis=1)]
    # NOTE(review): features are not scaled, so large-magnitude columns
    # (presumably fnlwgt) dominate the Euclidean distance. Column 14 (the
    # class label) is also clustered on as a feature; consider standardising
    # and excluding it.
    k = 5
    convergeDist = 0.01
    max_iter = 300  # safety cap so a non-converging run cannot loop forever
    init_centor = initClusterPoint(data=data, k=k)
    data_cluster = np.empty(data.shape[0])
    for _ in range(max_iter):
        data_cluster, centers = getClusterPoint(data=data, k=k, centers=init_centor)
        # Largest distance any center moved during this iteration.
        tempDist = computeDis(init_centor, centers).max()
        init_centor = centers
        # "not >" rather than "<=" so a NaN movement also stops the loop,
        # as the original while-condition did.
        if not tempDist > convergeDist:
            break
    draw(Data=data, centers=init_centor, index=data_cluster)
- 绘图:
def draw(Data, centers, index=None):
    """Visualise the clustering in 2-D via PCA.

    The data has 15 features, so points and centers are projected onto the
    first two principal components for display. The projection only
    illustrates the clustering; it does not represent the real distribution
    of the data.

    Args:
        Data: 2-D array of the samples that were clustered.
        centers: 2-D array of cluster centers, in the same feature space as
            ``Data``.
        index: optional per-sample cluster labels (values 0..len(centers)-1).
            When omitted, only the raw projection is drawn.
    """
    fig, axes = plt.subplots(1, 2)
    # Title the figure just created. Calling plt.title before subplots placed
    # the title on a separate, stray figure.
    fig.suptitle('Kmeans classifier of Adult')
    pca = decomposition.PCA(n_components=2)
    new_data = pca.fit_transform(Data)

    # Left panel: all samples, uncoloured.
    axes[0].scatter(new_data[:, 0], new_data[:, 1], marker='o', alpha=0.5)

    # Right panel: samples coloured by cluster, plus the projected centers.
    if index is not None:
        labels = np.asarray(index)
        for center_index in range(len(centers)):
            members = new_data[labels == center_index]
            # Matplotlib's default color cycle "C0".."C9", so this function
            # needs no external palette iterator.
            axes[1].scatter(members[:, 0], members[:, 1], marker='o',
                            color="C{}".format(center_index % 10), alpha=0.5)
        projected = pca.transform(np.asarray(centers))
        axes[1].scatter(projected[:, 0], projected[:, 1], marker='x', color='k')
    plt.show()
- 效果如下:
- 完整代码在: