Using scikit-learn's KMeans to cluster provincial consumption levels
The consumption data for every province are listed below; the program in this post reads them from data.csv with the loadData() function.
北京,2959.19,730.79,749.41,513.34,467.87,1141.82,478.42,457.64
天津,2459.77,495.47,697.33,302.87,284.19,735.97,570.84,305.08
河北,1495.63,515.90,362.37,285.32,272.95,540.58,364.91,188.63
山西,1406.33,477.77,290.15,208.57,201.50,414.72,281.84,212.10
内蒙古,1303.97,524.29,254.83,192.17,249.81,463.09,287.87,192.96
辽宁,1730.84,553.90,246.91,279.81,239.18,445.20,330.24,163.86
吉林,1561.86,492.42,200.49,218.36,220.69,459.62,360.48,147.76
黑龙江,1410.11,510.71,211.88,277.11,224.65,376.82,317.61,152.85
上海,3712.31,550.74,893.37,346.93,527.00,1034.98,720.33,462.03
江苏,2207.58,449.37,572.40,211.92,302.09,585.23,429.77,252.54
浙江,2629.16,557.32,689.73,435.69,514.66,795.87,575.76,323.36
安徽,1844.78,430.29,271.28,126.33,250.56,513.18,314.00,151.39
福建,2709.46,428.11,334.12,160.77,405.14,461.67,535.13,232.29
江西,1563.78,303.65,233.81,107.90,209.70,393.99,509.39,160.12
山东,1675.75,613.32,550.71,219.79,272.59,599.43,371.62,211.84
河南,1427.65,431.79,288.55,208.14,217.00,337.76,421.31,165.32
湖北,1783.43,511.88,282.84,201.01,237.60,617.74,523.52,182.52
湖南,1942.23,512.27,401.39,206.06,321.29,697.22,492.60,226.45
广东,3055.17,353.23,564.56,356.27,811.88,873.06,1082.82,420.81
广西,2033.87,300.82,338.65,157.78,329.06,621.74,587.02,218.27
海南,2057.86,186.44,202.72,171.79,329.65,477.17,312.93,279.19
重庆,2303.29,589.99,516.21,236.55,403.92,730.05,438.41,225.80
四川,1974.28,507.76,344.79,203.21,240.24,575.10,430.36,223.46
贵州,1673.82,437.75,461.61,153.32,254.66,445.59,346.11,191.48
云南,2194.25,537.01,369.07,249.54,290.84,561.91,407.70,330.95
西藏,2646.61,839.70,204.44,209.11,379.30,371.04,269.59,389.33
陕西,1472.95,390.89,447.95,259.51,230.61,490.90,469.10,191.34
甘肃,1525.57,472.98,328.90,219.86,206.65,449.69,249.66,228.19
青海,1654.69,437.77,258.78,303.00,244.93,479.53,288.56,236.51
宁夏,1375.46,480.89,273.84,317.32,251.08,424.75,228.73,195.93
新疆,1608.82,536.05,432.46,235.82,250.28,541.30,344.85,214.40
Implementation steps
- Load the data
- Call fit_predict() to cluster the data
- Get the cluster label of each province
- Put provinces with the same label into the same list
- Print the clustering result
Things to note:
- Since data.csv contains Chinese text and is UTF-8 encoded, the file must be opened with encoding='utf-8'; the Python source file should also be saved as UTF-8, and writing the clustering result (the province names) to a file likewise needs UTF-8 encoding.
- np.average(arr, axis) needs the axis set according to what you want: axis=0 averages down each column (one value per column), axis=1 averages across each row (one value per row), and omitting axis averages over all elements; a 2-D array only has axes 0 and 1. A small sketch follows this list.
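A minimal sketch of the axis behavior (the array values below are made up purely for illustration):

import numpy as np

arr = np.array([[1.0, 2.0, 3.0],
                [4.0, 5.0, 6.0]])
print(np.average(arr))           # 3.5            -> average of all six elements
print(np.average(arr, axis=0))   # [2.5 3.5 4.5]  -> one value per column
print(np.average(arr, axis=1))   # [2. 5.]        -> one value per row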
# coding:utf-8
import numpy as np
from sklearn.cluster import KMeans
def loadData(filePath):
    """
    Read the data file and return the consumption data and the corresponding province names.
    :param filePath: path of the data file
    :return: consumption data per province, province names
    """
    file = open(filePath, 'r', encoding='utf-8')  # mind the encoding when reading the file
    lines = file.readlines()
    fileData = []
    fileCityName = []
    for line in lines:
        items = line.strip().split(',')
        fileCityName.append(items[0])
        fileData.append([float(items[i]) for i in range(1, len(items))])
    file.close()
    return fileData, fileCityName
def saveData(filePath, data):
    """
    Save the output to the given path.
    :param filePath: destination file for the results
    :param data: result data
    :return:
    """
    file = open(filePath, 'w', encoding='utf-8')  # mind the encoding when writing
    file.write(str(data))
    file.close()
data, cityName = loadData('data.csv')
km = KMeans(n_clusters=3)  # split the provinces into 3 clusters
label = km.fit_predict(data)  # cluster label assigned to each province
# average of the 8 spending categories for each cluster center (axis=1: one value per row, i.e. per cluster)
avgExpenses = np.average(km.cluster_centers_, axis=1)
# group province names by cluster label
CityCluster = [[], [], []]
for i in range(len(cityName)):
    CityCluster[label[i]].append(cityName[i])
resultStr = ''  # holds the clustering result
# print the clustering result
for i in range(len(CityCluster)):
    print("Average expense: %.2f" % avgExpenses[i])
    print(CityCluster[i])
    # join provinces in the same cluster with commas
    resultStr = resultStr + ','.join(CityCluster[i]) + '\n'
# save the clustering result
saveData('result.csv', resultStr)
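As a side note on the fit_predict() step: for the same KMeans object, calling fit() and then reading the labels_ attribute yields the labels that fit_predict() would return, and cluster_centers_ holds one row per cluster. A minimal sketch, reusing the data loaded above:

km2 = KMeans(n_clusters=3)
km2.fit(data)
print(km2.labels_)                 # cluster label per province, same as km2.fit_predict(data) would give
print(km2.cluster_centers_.shape)  # (3, 8): 3 clusters, 8 spending categories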
Implementing KMeans by hand
KMeans steps
- Load the data and choose the number of clusters k
- Generate k random cluster centers
- Go through every point and find the cluster center nearest to it
- Recompute each cluster center as the mean of the points assigned to it, and repeat until no assignment changes
from numpy import *
def loadDataSet(fileName):
dataMat = []
fr = open(fileName)
for line in fr.readlines():
        curLine = line.strip().split()  # split on whitespace (works for tab- or space-separated files)
arrLine = list(map(float, curLine))
dataMat.append(arrLine)
fr.close()
return mat(dataMat)
def distEclud(vecA, vecB):
return sqrt(sum(power(vecA - vecB, 2))) # sqrt((x1-x2)^2+...)
def randCent(dataSet, k):
    n = shape(dataSet)[1]
    centroids = mat(zeros((k, n)))  # holds the k randomly generated cluster centers
    for j in range(n):  # generate random values column by column
        minJ = min(dataSet[:, j])  # minimum of the column
        rangeJ = float(max(dataSet[:, j]) - minJ)
        centroids[:, j] = minJ + rangeJ * random.rand(k, 1)  # k random values within the column's range
    return centroids
def kMeans(dataSet, k, distMeans=distEclud, createCent=randCent):
    m = shape(dataSet)[0]
    clusterAssment = mat(zeros((m, 2)))  # per row: assigned cluster index, distance to that center
    centroids = createCent(dataSet, k)  # the k cluster centers
    clusterChanged = True
    while clusterChanged:
        clusterChanged = False
        # assign every point to its nearest cluster center
        for i in range(m):
            minDist = inf
            minIndex = -1
            # among all centers, find the one closest to the current point
            for j in range(k):
                distJI = distMeans(centroids[j, :], dataSet[i, :])
                if distJI < minDist:
                    minDist = distJI
                    minIndex = j
            if clusterAssment[i, 0] != minIndex:
                clusterChanged = True
            clusterAssment[i, :] = minIndex, minDist
        for cent in range(k):
            # collect all points currently assigned to cluster cent
            lines = nonzero(clusterAssment[:, 0].A == cent)[0]  # row numbers whose cluster equals cent
            ptsInClust = dataSet[lines]
            # move the center to the mean of its points
            centroids[cent, :] = mean(ptsInClust, axis=0)
    return centroids, clusterAssment
myMat = loadDataSet("data.txt")
clusterNum = 3
centroids, clusterAssment = kMeans(myMat, clusterNum)
print('Cluster centers\n', centroids)
for label in range(clusterNum):
    print(label, "\n", myMat[nonzero(clusterAssment[:, 0].A == label)[0]])
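As a sanity check (my own suggestion, not part of the original write-up), you can run scikit-learn's KMeans on the same data.txt and compare its centers with the hand-written version; the cluster numbering will usually differ, and the centers may differ slightly, because both use random initialization:

from sklearn.cluster import KMeans
import numpy as np

X = np.asarray(loadDataSet("data.txt"))  # reuse the loader defined above
skCenters = KMeans(n_clusters=clusterNum).fit(X).cluster_centers_
print('sklearn centers\n', skCenters)    # compare with the `centroids` printed above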
Contents of data.txt
3.792121 5.135768
-4.786473 3.358547
2.624081 -3.260715
-4.009299 -2.978115
2.493525 1.963710
-2.513661 2.642162
1.864375 -3.176309
-3.171184 -3.572452
2.894220 2.489128
-2.562539 2.884438
3.491078 -3.947487
-2.565729 -2.012114
3.332948 3.983102
-1.616805 3.573188
2.280615 -2.559444
-2.651229 -3.103198
2.321395 3.154987
-1.685703 2.939697
3.031012 -3.620252
-4.599622 -2.185829
4.196223 1.126677
-2.133863 3.093686
4.668892 -2.562705
-2.793241 -2.149706
2.884105 3.043438
-2.967647 2.848696
4.479332 -1.764772
-4.905566 -2.911070
Output:
Cluster centers
[[-3.06875436 0.17342357]
[ 2.99334217 -3.18781867]
[ 3.29923363 2.39150475]]
0
[[-4.786473 3.358547]
[-4.009299 -2.978115]
[-2.513661 2.642162]
[-3.171184 -3.572452]
[-2.562539 2.884438]
[-2.565729 -2.012114]
[-1.616805 3.573188]
[-2.651229 -3.103198]
[-1.685703 2.939697]
[-4.599622 -2.185829]
[-2.133863 3.093686]
[-2.793241 -2.149706]
[-2.967647 2.848696]
[-4.905566 -2.91107 ]]
1
[[ 2.624081 -3.260715]
[ 1.864375 -3.176309]
[ 3.491078 -3.947487]
[ 2.280615 -2.559444]
[ 3.031012 -3.620252]
[ 4.668892 -2.562705]]
2
[[ 3.792121 5.135768]
[ 2.493525 1.96371 ]
[ 2.89422 2.489128]
[ 3.332948 3.983102]
[ 2.321395 3.154987]
[ 4.196223 1.126677]
[ 2.884105 3.043438]
[ 4.479332 -1.764772]]