kNN算法代码实现简单分析2之——数据导入及画图

代码参照《机器学习实战》

现在就数据导入和画图进行分析:

import numpy as np
import matplotlib.pyplot as plt
import kNNtest1 as kNN

def file2matrix(filename):
    fr = open(filename)
    numberOfLines = len(fr.readlines())
    returnMat = np.zeros((numberOfLines,3))
    classLabelVector = []
    fr = open(filename)
    index = 0
    for line in fr.readlines():
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))
        index += 1

    return returnMat,classLabelVector

if __name__ == '__main__':
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
    #print(datingDataMat)
    #print(datingLabels)

    #图像显示
    # ind1 = []
    # ind2 = []
    # ind3 = []
    # nm = len(datingLabels)
    # for i in range(nm):
    #     if datingLabels[i] == 1:
    #         ind1.append(i)
    #     if datingLabels[i] == 2:
    #         ind2.append(i)
    #     if datingLabels[i] == 3:
    #         ind3.append(i)

    ind1 = np.where(np.array(datingLabels)==1)
    ind2 = np.where(np.array(datingLabels)==2)
    ind3 = np.where(np.array(datingLabels)==3)


    fig = plt.figure(figsize=(9,6))
    ax = fig.add_subplot(111)
    plt.xlabel('percentage of time spent playing video games')
    plt.ylabel('liters of cream consumed per week')
    plt.title('Games VS cream')


    ax.scatter(datingDataMat[ind1,1], datingDataMat[ind1,2], c="red",label="Did Not Like" ,s=20)
    ax.scatter(datingDataMat[ind2,1], datingDataMat[ind2,2], c="blue",label="Liked in small dose" ,s=20)
    ax.scatter(datingDataMat[ind3,1], datingDataMat[ind3,2], c="green",label="Liked in large dose" ,s=20)

    ax.legend(loc='best')
    plt.show()

结果是:

kNN算法代码实现简单分析2之——数据导入及画图_第1张图片

代码分析:

1. fr.readlines()

读写文件要生成文件流(filestream),fr就是文件流对象,readlines()生成的是一个迭代对象,它自动将文本按照行分析成一个列表,一行就是列表的一个元素,但要注意,readlines不是列表,是类列表的迭代对象,可通过迭代获得内部元素(像这样for i in fr.readlines():),这样可以节省内存,避免像read()函数那样,一次性将文本读入到内存。而readline()函数则只是按行读数据,运行一次,读出一行。

len(fr.readlines())可以获得文本的行数,这样也确定了数据的个数


2.  numpy.zeros(shape())
returnMat = np.zeros((numberOfLines,3))

知道了数据的维度,那么就需要构造一个相同维度的数组来进行存储,numpy中的zeros()数组就是用来构建这样的数组外壳的。当刚开始学习时,直接将zeros知道怎么构建,但是不知道构建一个元素为0的数组有啥意义,现在知道了吧(可能只是我,刚开始学的时候是不知道的,嘿嘿),来个例子:

>>> np.zeros((3,4))
array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]])

3. 文本读入的处理:

        line = line.strip()
        listFromLine = line.split('\t')

先strip()将前后的空格进行strip off,然后将该行的数据进行分割,因为该行数据原来就是分成几列的,但是文本读取时,编程一个字符串了,所以要分割,重新划分行列。

如果我们打开“datingTestSet2.txt”可以看到它是以tab键进行分割的,所以这里我们用“\t”

如果是以空格键进行分割的,用line.split(' ')


4. 文本读取是的结果字符串类型,要进行数值转化
        returnMat[index,:] = listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))

文本中数据都是数值类型,但是文件流读得结果是字符类型,对于numpy库中的对象returnMat,listFromLine[0:3]中的字符串类型存入到returnMat中,会自动进行数值转化(前提是这些字符串是数字),但是python中的列表list不行,如果想要在classLabelVector中append的是数值型元素,必须强制转换,如果像这样:

classLabelVector.append(int(listFromLine[-1]))

则结果是classLabelVector的元素是字符串。


5. where函数:寻找array数组中某个特定值对应
ind1 = np.where(np.array(datingLabels)==1)
我们想查找datingLabels列表中元素是1的索引值,可以用where函数,但是必须先让该列表转换成array数组。这样效率高,简洁,比我用判断进行查询 (代码中被#注释掉的部分)简洁得多。我知道了 datingLabels中元素是1(label)的索引,我也就知道了returnMat数组中标签(label)是1的索引。注意,返回的是对应值的索引构成的列表。

6. array数组提取元素的方法:

普通的方法,我们都知道,但是还有一个方法,就是不是连续的索引,及间隔不一致的索引构成的列表,就是我们在第5小节用where函数得到的索引列表ind1,结果是同样可以提取的:

datingDataMat[ind1,1]

结果就是返回datingDataMat第2列中的索引构成的一维数组,这里的索引值是ind1的元素。例子:

>>> a = np.random.random((5,4))
>>> a
array([[0.35403348, 0.32452352, 0.54559416, 0.33362996],
       [0.93168307, 0.79195363, 0.16087671, 0.4803462 ],
       [0.34011118, 0.479057  , 0.20447057, 0.90128244],
       [0.05174083, 0.28380071, 0.86293482, 0.30875988],
       [0.53247842, 0.37601026, 0.27587104, 0.86953997]])
>>> b = a[[0,3],1]
>>> b
array([0.32452352, 0.28380071])

这个很有用啊!这是因为我想通过标签将一个大的数组分成几个的数组,各个数组内部元素标签相同,各个数组标签不同,这样在画散点图时,就可以得到legend(图例)。

7. 讲述代码中画散点图的知识:

画图用到matplotlib库,import matplotlib.plot as plt

画图先要构建画板:

fig = plt.figure(figsize=(9,6))

构建了画板,大小是figsize=(9,6),这里我们必须写上figsize,可知,figsize是一个缺省值,即含有默认值,所以不写figsize=(9,6),就是构建了默认大小的画板

ax = fig.add_subplot(111)

add_subplot 构建了子图,子图是在画板里面的,这里的(111)可以写成(1,1,1),意思是将fig画板分成了1行1列,本子图的位置在第1位置。

ax.scatter(datingDataMat[ind1,1], datingDataMat[ind1,2], c="red",label="Did Not Like" ,s=20)

ax.scatter()是画散点图用的函数,前两个参数分别是横坐标和纵坐标,后面参数是默认值,c、label、s、mark分别第颜色、标签、大小、标志(默认是原点显示)。

注:mark可以自定义:mark=‘+’、mark=‘x’、mark=‘o’等

ax.legend(loc='best')

通过该命令可以让图显示出图例,参数loc='best',是由程序自己选择最合适的位置,也可以自定义其他位置:

loc=1、loc=2、loc=3、loc=4、loc=‘upper right’等



你可能感兴趣的:(numpy,kNN)