kd树代码

目录

  • sklearn库实现kd树
  • sklearn库kd树API介绍

sklearn库实现kd树

import numpy as np
from sklearn.neighbors import KDTree
x_data = np.array([[2, 3],
                   [5, 4],
                   [9, 6],
                   [4, 7],
                   [8, 1],
                   [7, 2]])

tree = KDTree(x_data, leaf_size=2) # 初始化,leaf_size=2是叶子节点为2,二叉树
dist, index = tree.query([[5.1, 4.2]], k=2,)
print(ind)
print(dist)

# 使用画圆,确定范围
indices = tree.query_radius([[5.1, 4.2]], r=2)
print(indices)


sklearn库kd树API介绍

from sklearn.neighbors import KDTree

x_data = np.array([[2, 3],
                   [5, 4],
                   [9, 6],
                   [4, 7],
                   [8, 1],
                   [7, 2]])
tree = KDTree(x_data, leaf_size=2)# 初始化类

"""
类中参数介绍:
	X, 输入的数据二维,形状(n_samples,n_features)
	leaf_size=40, 叶子节点树,设置leaf_size=2,相当于二叉树
	metric='minkowski', 距离 闵可夫斯基,就是KNN距离给出的那个公式,下面是。
		tree.valid_metrics可以查看距离
		['euclidean', 'l2', 'minkowski', 'p', 'manhattan', 'cityblock','l1', 'chebyshev', 'infinity']

方法:
	1.查询k近邻的
		query(X,k=1,return_distance=True,dualtree=False,breadth_first=False)
		参数介绍:
			X:查询的数据,二维数组
			k=1:返回的k近邻的数量
			return_distance=True,是否返回距离
			dualtree=False,是否使用双树算法,双树算法可以针对较大的N具有更好的缩放比例
			breadth_first=False  广度优先还是深度优先,默认使用深度优先
	2.查询给定半径的
		query_radius( X, r, count_only=False)
		参数介绍:
			X:二维输入数据
			r:给定的半径
			count_only=False, 是否仅返回半径内点的数量
	3.高斯和密度估计,使用内核计算点X的密度估计
		kernel_density(X, h, kernel='gaussian', atol=0, rtol=1, )
		参数介绍:
			X:二维输入数据
			h:内核的带宽
			kernel='gaussian':使用的内核,可选的内核['gaussian','tophat','epanechnikov','exponential''linear', 'cosine']
			atol=0:
			rtol=1:
			breadth_first:深度还是广度优先
	4.计算两点自相关函数
		two_point_correlation(X,r,dualtree=False)
		参数介绍:
			X:输入二维数组
			r:一维距离数组
			dualtree=False:是否使用双树算法
	
	
"""
# query方法查询,K近邻查询
dist, ind = tree.query([[5.1, 4.2]], k=2,)
print(ind)
print(dist)

# 使用半径的查询方法,半径维0.3,只返回个数
tree.query_radius([[5.1, 4.2]], r=0.3, count_only=True))

# 使用高斯核的方法
tree.kernel_density([[5.1, 4.2]], h=0.1, kernel='gaussian')

# 使用相关系数
r = np.linspace(0, 1, 5)
tree.two_point_correlation([[5.1, 4.2]], r)

你可能感兴趣的:(机器学习)