球树结构
sklearn.neighbors.BallTree(
X,
leaf_size=40,
metric='minkowski',
**kwargs)
X |
|
leaf_size |
|
metric | 度量距离 |
import numpy as np
from sklearn.neighbors import BallTree
X = np.random.randn(10,3)
tree = BallTree(X, leaf_size=2)
dist, ind = tree.query(X[:2], k=3)
print(ind)
# 最近的k个邻居的index
'''
[[0 4 5]
[1 2 8]]
'''
print(dist)
# 最近的k个邻居的距离
'''
[[0. 0.86677441 1.16406937]
[0. 0.95190704 1.32997164]]
'''
query_radius(
X,
r,
return_distance=False,
count_only=False,
sort_results=False)
X |
|
r |
float 或一维数组,表示查询半径 |
count_only |
bool,默认为 False。 如果为 True,则只返回每个查询点内邻居点的数量,而不返回邻居点的索引 |
return_distance |
bool,默认为 False。如果为 True,则返回每个查询点到其邻居点的距离列表 |
import numpy as np
from sklearn.neighbors import BallTree
X = np.random.randn(10,3)
tree = BallTree(X, leaf_size=2)
tree.query_radius(X[:2],r=[0.1,5])
'''
array([array([0], dtype=int64),
array([6, 2, 4, 7, 8, 1, 0, 3, 9, 5], dtype=int64)], dtype=object)
'''
import numpy as np
from sklearn.neighbors import BallTree
X = np.random.randn(10,3)
tree = BallTree(X, leaf_size=2)
tree.query_radius(X[:2],r=[0.1,5],return_distance=True)
'''
(array([array([0], dtype=int64),
array([8, 2, 3, 6, 1, 9, 7, 0, 5, 4], dtype=int64)], dtype=object),
array([array([0.]),
array([2.18948629, 1.05002031, 1.48036256, 1.54854719, 0. ,
2.37799982, 3.36371823, 2.63138373, 2.54630893, 3.57322436])],
dtype=object))
'''