代码参照《机器学习实战》
现在就数据导入和画图进行分析:
import numpy as np
import matplotlib.pyplot as plt
import kNNtest1 as kNN
def file2matrix(filename):
fr = open(filename)
numberOfLines = len(fr.readlines())
returnMat = np.zeros((numberOfLines,3))
classLabelVector = []
fr = open(filename)
index = 0
for line in fr.readlines():
line = line.strip()
listFromLine = line.split('\t')
returnMat[index,:] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat,classLabelVector
if __name__ == '__main__':
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
#print(datingDataMat)
#print(datingLabels)
#图像显示
# ind1 = []
# ind2 = []
# ind3 = []
# nm = len(datingLabels)
# for i in range(nm):
# if datingLabels[i] == 1:
# ind1.append(i)
# if datingLabels[i] == 2:
# ind2.append(i)
# if datingLabels[i] == 3:
# ind3.append(i)
ind1 = np.where(np.array(datingLabels)==1)
ind2 = np.where(np.array(datingLabels)==2)
ind3 = np.where(np.array(datingLabels)==3)
fig = plt.figure(figsize=(9,6))
ax = fig.add_subplot(111)
plt.xlabel('percentage of time spent playing video games')
plt.ylabel('liters of cream consumed per week')
plt.title('Games VS cream')
ax.scatter(datingDataMat[ind1,1], datingDataMat[ind1,2], c="red",label="Did Not Like" ,s=20)
ax.scatter(datingDataMat[ind2,1], datingDataMat[ind2,2], c="blue",label="Liked in small dose" ,s=20)
ax.scatter(datingDataMat[ind3,1], datingDataMat[ind3,2], c="green",label="Liked in large dose" ,s=20)
ax.legend(loc='best')
plt.show()
结果是:
代码分析:
读写文件要生成文件流(filestream),fr就是文件流对象,readlines()生成的是一个迭代对象,它自动将文本按照行分析成一个列表,一行就是列表的一个元素,但要注意,readlines不是列表,是类列表的迭代对象,可通过迭代获得内部元素(像这样for i in fr.readlines():),这样可以节省内存,避免像read()函数那样,一次性将文本读入到内存。而readline()函数则只是按行读数据,运行一次,读出一行。
len(fr.readlines())可以获得文本的行数,这样也确定了数据的个数
returnMat = np.zeros((numberOfLines,3))
知道了数据的维度,那么就需要构造一个相同维度的数组来进行存储,numpy中的zeros()数组就是用来构建这样的数组外壳的。当刚开始学习时,直接将zeros知道怎么构建,但是不知道构建一个元素为0的数组有啥意义,现在知道了吧(可能只是我,刚开始学的时候是不知道的,嘿嘿),来个例子:
>>> np.zeros((3,4))
array([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
3. 文本读入的处理:
line = line.strip()
listFromLine = line.split('\t')
先strip()将前后的空格进行strip off,然后将该行的数据进行分割,因为该行数据原来就是分成几列的,但是文本读取时,编程一个字符串了,所以要分割,重新划分行列。
如果我们打开“datingTestSet2.txt”可以看到它是以tab键进行分割的,所以这里我们用“\t”
如果是以空格键进行分割的,用line.split(' ')
returnMat[index,:] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
文本中数据都是数值类型,但是文件流读得结果是字符类型,对于numpy库中的对象returnMat,listFromLine[0:3]中的字符串类型存入到returnMat中,会自动进行数值转化(前提是这些字符串是数字),但是python中的列表list不行,如果想要在classLabelVector中append的是数值型元素,必须强制转换,如果像这样:
classLabelVector.append(int(listFromLine[-1]))
则结果是classLabelVector的元素是字符串。
ind1 = np.where(np.array(datingLabels)==1)
我们想查找datingLabels列表中元素是1的索引值,可以用where函数,但是必须先让该列表转换成array数组。这样效率高,简洁,比我用判断进行查询
(代码中被#注释掉的部分)简洁得多。我知道了
datingLabels中元素是1(label)的索引,我也就知道了returnMat数组中标签(label)是1的索引。注意,返回的是对应值的索引构成的列表。
普通的方法,我们都知道,但是还有一个方法,就是不是连续的索引,及间隔不一致的索引构成的列表,就是我们在第5小节用where函数得到的索引列表ind1,结果是同样可以提取的:
datingDataMat[ind1,1]
结果就是返回datingDataMat第2列中的索引构成的一维数组,这里的索引值是ind1的元素。例子:
>>> a = np.random.random((5,4))
>>> a
array([[0.35403348, 0.32452352, 0.54559416, 0.33362996],
[0.93168307, 0.79195363, 0.16087671, 0.4803462 ],
[0.34011118, 0.479057 , 0.20447057, 0.90128244],
[0.05174083, 0.28380071, 0.86293482, 0.30875988],
[0.53247842, 0.37601026, 0.27587104, 0.86953997]])
>>> b = a[[0,3],1]
>>> b
array([0.32452352, 0.28380071])
这个很有用啊!这是因为我想通过标签将一个大的数组分成几个的数组,各个数组内部元素标签相同,各个数组标签不同,这样在画散点图时,就可以得到legend(图例)。
画图用到matplotlib库,import matplotlib.plot as plt
画图先要构建画板:
fig = plt.figure(figsize=(9,6))
构建了画板,大小是figsize=(9,6),这里我们必须写上figsize,可知,figsize是一个缺省值,即含有默认值,所以不写figsize=(9,6),就是构建了默认大小的画板
ax = fig.add_subplot(111)
add_subplot 构建了子图,子图是在画板里面的,这里的(111)可以写成(1,1,1),意思是将fig画板分成了1行1列,本子图的位置在第1位置。
ax.scatter(datingDataMat[ind1,1], datingDataMat[ind1,2], c="red",label="Did Not Like" ,s=20)
ax.scatter()是画散点图用的函数,前两个参数分别是横坐标和纵坐标,后面参数是默认值,c、label、s、mark分别第颜色、标签、大小、标志(默认是原点显示)。
注:mark可以自定义:mark=‘+’、mark=‘x’、mark=‘o’等
ax.legend(loc='best')
通过该命令可以让图显示出图例,参数loc='best',是由程序自己选择最合适的位置,也可以自定义其他位置:
loc=1、loc=2、loc=3、loc=4、loc=‘upper right’等