机器学习一:邻近算法【K最近邻(KNN,k-NearestNeighbor)分类算法】python代码实现KNN

理论内容请参看博客:https://blog.csdn.net/weixin_41676798/article/details/90454618

"""
数据分类-knn算法:
1)计算测试数据与各个训练数据之间的距离;
2)按照距离的递增关系进行排序;
3)选取距离最小的K个点;
4)确定前K个点所在类别的出现频率;
5)返回前K个点中出现频率最高的类别作为测试数据的预测分类。
本代码场景:根据打斗次数与接吻次数预测电影属于动作片还是爱情片
"""

import numpy as np
import operator
import pandas as pd


# a = np.array([1, 5, 6, 4])
# np.tile(a, 2)       # 在列的方向上扩展:在列的方向上扩展2次,即4列变为8列。[1,5,6,4,1,5,6,4]
# np.tile(a, (2, 1))  # 在行的方向上扩展:在行的方向上扩展2次,列的方向上扩展1次,即1行变为2行。[[1,5,6,4],[1,5,6,4]]

"""
函数说明:KNN算法
Parameters:
    k - 分类
    testdata - 测试集
    traindata - 训练集
    lables - 分类标签
Returns:
    sortcount[0][0] - 分类结果
"""


def knn(k, testdata, traindata, lables):
    traindatasize = traindata.shape[0]  # 获取训练集的行数
    dif = np.tile(testdata, (traindatasize, 1)) - traindata  # 测试集在行方向上扩展traindatasize次,在列方向上扩展1次
    sqdif = dif**2  # 特征相减后平方
    sumsqdif = sqdif.sum(axis=1)  # axis=0对每一列求和,axis=1对每一行求和
    distance = sumsqdif**0.5  # 开方,计算距离
    sortdistance = distance.argsort()  # 对元素排序,得到升序排序的索引
    count = {}  # 用于统计每个类别出现多少次
    for i in range(k):
        vote = lables[sortdistance[i]]  # 取出第i个的类别
        # get()方法,返回指定键vote的值,如果vote值不在字典中返回默认值,计算类别出现次数。vote每出现一次就统计一次
        count[vote] = count.get(vote, 0) + 1
    #  key=operator.itemgetter(1):按照count的值降序排序;key=operator.itemgetter(0):按照count的键降序排序
    sortcount = sorted(count.items(), key=operator.itemgetter(1), reverse=True)
    return sortcount[0][0]  # 返回测试数据的类别


"""
函数说明:创建训练集
Parameters:
    无
Returns:
    group - 数据集
    labels - 分类标签
"""


def createtestdata():
    group = np.array([[1, 101], [5, 89], [108, 5], [115, 8]])
    labels = ['love movie', 'love movie', 'action movie', 'action movie']
    # 若训练集在excel文档中,请注释上方两行代码并运行下方三行代码
    # 若excel与本代码在不同文件夹,需在文件名称前加路径,如:data = pd.read_excel('C:\\Users\\asus\\Desktop\\knn_predict_movie_type.xlsx')
    # data = pd.read_excel('knn_predict_movie_type.xlsx')
    # group = np.array(data[['dadou', 'jiewen']])
    # labels = list(data['leixing'])  # 或者:labels = data['leixing'].tolist()

    return group, labels


if __name__ == '__main__':
    group, labels = createtestdata()
    print group
    print labels
    test = [101, 20]
    test_class = knn(3, test, group, labels)
    print test_class
    print "run success"

你可能感兴趣的:(机器学习)