from numpy import *
import operator
from os import listdir
import matplotlib.pyplot as plt

"""程序清单2-1 K近邻算法"""
def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5
    sortedDistIndicies = distances.argsort()
    classCount={}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createDataSet():
    group = array([[1.0, 1.1],[1.0, 1.0],[0,0],[0,0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels

if __name__ == "__main__":
    group, labels = createDataSet()
    #print(group)
    """[[1.  1.1]
         [1.  1. ]
         [0.  0. ]
         [0.  0.1]]"""
    #print(labels)
    """['A', 'A', 'B', 'B']"""
    #图2-2
    # plt.scatter(group[:,0], group[:,1])
    # plt.show()
    #预测
    """[0, 0]是测试向量
       group 训练集
       labels 标签
       3 K值
    """
    pre = classify0([0, 0], group, labels, 3)
    print(pre)
    """B"""