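Implementing the KD-tree algorithm for KNN from scratch in pure Python

Algorithms 3.2 and 3.3 from Chapter 3 of 《统计学习方法》 (Statistical Learning Methods)

Implementation of KD-tree construction and nearest-neighbor search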

Algorithm 3.2:

Constructing a KD tree

Input: a dataset T of points in k-dimensional space

Output: a KD tree

Python code:
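At each level, Algorithm 3.2 repeats a single step. It chooses the splitting axis (depth mod k), sorts the points on that axis, and stores the median point at the node. Points before the median go to the left subtree, and points after it go to the right subtree. Below is a minimal sketch of that one step. It assumes points are tuples and takes the median at index len // 2, which is the upper median when the count is even. `split_once` is a hypothetical helper name. The six-point sample is the dataset from the book's Example 3.2.

```python
def split_once(points, axis):
    """One step of Algorithm 3.2: sort on `axis`, keep the median, split the rest."""
    ordered = sorted(points, key=lambda p: p[axis])
    mid = len(ordered) // 2          # upper median when the count is even
    return ordered[mid], ordered[:mid], ordered[mid + 1:]


T = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
median, left, right = split_once(T, axis=0)
print(median)  # (7, 2)
print(left)    # [(2, 3), (4, 7), (5, 4)]
print(right)   # [(8, 1), (9, 6)]
```

Applying the same step recursively to `left` and `right` builds the whole tree. The axis advances by one at each level.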

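def creat_ketree(data, depth=0):
    """
    Build a KD tree recursively (Algorithm 3.2).
    Each node is a dict:
        split:  the coordinate axis this node splits on (depth mod k)
        median: the median point on that axis, stored at this node
        left:   subtree built from the points before the median
        right:  subtree built from the points after the median
    :param data: list of k-dimensional points (lists or tuples)
    :param depth: depth of the current node; the root is at depth 0
    :return: the root node dict, or None for an empty dataset
    """
    if not data:
        return None

    k = len(data[0])
    axis = depth % k                     # cycle through the axes level by level
    data = sorted(data, key=lambda point: point[axis])
    num = len(data) // 2                 # index of the median point

    tree_node = {}
    tree_node['split'] = axis
    tree_node['median'] = data[num]
    tree_node['left'] = creat_ketree(data[:num], depth + 1)
    tree_node['right'] = creat_ketree(data[num + 1:], depth + 1)
    return tree_node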
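The following self-contained sketch checks the construction. It restates the builder under the hypothetical name `build_kdtree`. It uses the same dict layout (`split`, `median`, `left`, `right`) and takes the median at index len(data) // 2. It then lists the nodes in pre-order with a hypothetical `preorder` helper. On the six-point sample from the book's Example 3.2, the root is (7, 2), with children (5, 4) and (9, 6).

```python
def build_kdtree(data, depth=0):
    """Dict-based KD tree node: split / median / left / right."""
    if not data:
        return None
    axis = depth % len(data[0])
    data = sorted(data, key=lambda point: point[axis])
    mid = len(data) // 2
    return {
        'split': axis,
        'median': data[mid],
        'left': build_kdtree(data[:mid], depth + 1),
        'right': build_kdtree(data[mid + 1:], depth + 1),
    }


def preorder(node):
    """Return the stored points in pre-order: node, left subtree, right subtree."""
    if node is None:
        return []
    return [node['median']] + preorder(node['left']) + preorder(node['right'])


T = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
tree = build_kdtree(T)
print(preorder(tree))
# [(7, 2), (5, 4), (2, 3), (4, 7), (9, 6), (8, 1)]
```

Returning None for an empty list is the base case of the recursion. The leaves' children are therefore None rather than empty dicts.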

 

 

Algorithm 3.3:

Searching the KD tree

Input: a constructed KD tree and a target point x

Output: the nearest neighbor of x

Python code:
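The key step of Algorithm 3.3 is the backtracking test. The search first descends to the leaf region containing x, then walks back up the tree. At each node, the subtree on the far side of the splitting hyperplane holds only points whose coordinate on that axis lies beyond the plane. Every such point is therefore at least |x[axis] - node[axis]| away from x. That subtree only needs visiting if this lower bound is smaller than the best distance found so far. In other words, the ball around x with the current best radius must cross the plane. Below is a minimal sketch of just that test. `other_side_may_be_closer` is a hypothetical name.

```python
def other_side_may_be_closer(target, node_point, axis, best_distance):
    """Backtracking test of Algorithm 3.3.

    Points in the far subtree lie across the hyperplane
    x[axis] == node_point[axis], so each is at least `plane_distance`
    from the target; search it only if that bound beats the current best.
    """
    plane_distance = abs(target[axis] - node_point[axis])
    return plane_distance < best_distance


# Node (7, 2) splits on x: the plane is 4 away, more than the best 1.8, so prune.
print(other_side_may_be_closer((3, 4.5), (7, 2), 0, 1.8))   # False
# Node (5, 4) splits on y: the plane is only 0.5 away, so search the other side.
print(other_side_may_be_closer((3, 4.5), (5, 4), 1, 1.8))   # True
```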


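import math


def Euclidean_distance(A, B):
    """
    Compute the Euclidean distance between points A and B
    :param A: k-dimensional point
    :param B: k-dimensional point
    :return: sqrt of the sum over i of (A[i] - B[i]) ** 2
    """
    sum_distance = 0
    for i in range(len(A)):
        sum_distance += (A[i] - B[i]) ** 2
    return math.sqrt(sum_distance)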


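def search_tree(tree, data):
    """
    Find the nearest neighbor of the target point data (Algorithm 3.3).
    First descend to the leaf region that contains data, then backtrack
    up the KD tree, searching the other side of a split only when it
    could still hold a closer point.
    :param tree: KD tree built by creat_ketree
    :param data: target point
    :return: (nearest_point, nearest_distance); (None, inf) for an empty tree
    """
    if tree is None:
        return None, float('inf')

    median_point = tree['median']
    node_axis = tree['split']
    # Descend into the side of the splitting plane that contains data
    if data[node_axis] > median_point[node_axis]:
        nearest_point, nearest_distance = search_tree(tree['right'], data)
    else:
        nearest_point, nearest_distance = search_tree(tree['left'], data)

    # Distance between the current node and data
    now_distance = Euclidean_distance(data, median_point)
    # If the current node is closer, update the nearest point and distance
    if now_distance < nearest_distance:
        nearest_point = list(median_point)
        nearest_distance = now_distance

    # Compare the current best with the distance from data to this node's
    # splitting hyperplane; if the best is smaller, prune the other side
    if nearest_distance < abs(median_point[node_axis] - data[node_axis]):
        return nearest_point, nearest_distance

    # Otherwise search the subtree that was skipped on the way down
    if data[node_axis] > median_point[node_axis]:
        nearer_point, nearer_distance = search_tree(tree['left'], data)
    else:
        nearer_point, nearer_distance = search_tree(tree['right'], data)
    if nearer_distance < nearest_distance:
        nearest_distance = nearer_distance
        nearest_point = list(nearer_point)

    return nearest_point, nearest_distance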
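The following end-to-end sketch puts both algorithms together and cross-checks the search against brute force. It uses the hypothetical names `build_kdtree` and `nearest` and makes three assumptions:

- the median sits at index len // 2;
- the far side is searched only when the distance to the splitting hyperplane is smaller than the current best;
- distances use `math.dist`, which requires Python 3.8 or later, in place of a hand-written Euclidean function.

```python
import math
import random


def build_kdtree(data, depth=0):
    """Algorithm 3.2 with the dict layout split / median / left / right."""
    if not data:
        return None
    axis = depth % len(data[0])
    data = sorted(data, key=lambda point: point[axis])
    mid = len(data) // 2
    return {'split': axis, 'median': data[mid],
            'left': build_kdtree(data[:mid], depth + 1),
            'right': build_kdtree(data[mid + 1:], depth + 1)}


def nearest(tree, target):
    """Algorithm 3.3: return (nearest_point, distance); (None, inf) if empty."""
    if tree is None:
        return None, float('inf')
    point, axis = tree['median'], tree['split']
    # Descend first into the side of the splitting plane that holds the target.
    if target[axis] > point[axis]:
        near_side, far_side = tree['right'], tree['left']
    else:
        near_side, far_side = tree['left'], tree['right']
    best_point, best_dist = nearest(near_side, target)
    # While backtracking, the node itself may beat everything found below it.
    d = math.dist(target, point)
    if d < best_dist:
        best_point, best_dist = point, d
    # Visit the far side only if the ball around the target crosses the plane.
    if abs(target[axis] - point[axis]) < best_dist:
        far_point, far_dist = nearest(far_side, target)
        if far_dist < best_dist:
            best_point, best_dist = far_point, far_dist
    return best_point, best_dist


T = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
print(nearest(build_kdtree(T), (3, 4.5)))   # nearest point (2, 3), distance sqrt(3.25)

# Cross-check against brute force on random 3-D points.
random.seed(0)
points = [(random.random(), random.random(), random.random()) for _ in range(500)]
tree = build_kdtree(points)
for _ in range(100):
    q = (random.random(), random.random(), random.random())
    _, kd_dist = nearest(tree, q)
    assert math.isclose(kd_dist, min(math.dist(q, p) for p in points))
```

For the query (3, 4.5), the root (7, 2) is never expanded on its right side. The plane x = 7 is 4 away from the query, while the best distance is already about 1.80. An empty subtree returns (None, inf) rather than a dummy point such as [0]*k, so a missing branch can never be mistaken for a real neighbor.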

 
