阅读本文前,建议查阅相关资料,了解 KNN 算法与 KD 树。
如图所示,假设一个点 a
目前的最近邻点为 b
,如果存在相对于 b
离 a
更近的点,那么这个点一定在以 a
为圆心,ab
为半径的圆内。
现右侧的区域是未知的,如果 a
到分界线的距离 l
大于目前的最近距离 L
(圆半径),则没有必要在右侧的未知区域继续寻找最近邻点(如图一),反之,则要继续寻找(如图二)。
相应的,投射到多维空间,假如切分边界为第 i
维,切分点的值为 v
(标量),当前最近邻点为 y
(向量),如果目标点 x
(向量) 到切分边界的距离 |x[i] - v| 满足以下关系
时,需要在另一侧继续搜索。
通常地,一个机器学习算法分为 fit
和 predict
两个阶段,基于线性搜索的 KNN
是一种惰性算法,它将全部的计算任务放到了 predict
阶段,predict
的时间复杂度为 O(n)
,KD 树之所以比线性搜索快,就是因为它将一部分任务放到了 fit
(建立 KD 树) 阶段,从而在搜索时可以略去大量不必搜索的结点(最优情况下时间复杂度为 O(1)
)。
上面说的比较简单,关于 KNN 算法和 KD 树的详细内容,请参考李航博士的《统计学习方法》。
我们给出部分关键性的代码。
double *data
表示,它的长度为 n_samples * n_features
,标签集也用一个一维数组 double *labels
表示,它的长度为 n_samples
。cpp
struct tree_node
{
size_t id; // 表示训练集中的第 i 个数据
size_t split; // 切分的维度
tree_node *left, *right; // 左、右子树
};
cpp
struct tree_model
{
tree_node *root; // 根结点
const double *datas; // X
const double *labels; // y
size_t n_samples; // 样例数
size_t n_features; // 每个样例的特征数
double p; // 距离度量
};
求 K-近邻时需要用到大顶堆,我们直接用 C++ 的优先队列来表示,堆内现有的 n(n <= k)
个近邻点中,距离测试点最远的在堆顶
struct neighbor_heap_cmp {
bool operator()(const std::tupledouble> &i,
const std::tupledouble> &j) {
return std::get<1>(i) < std::get<1>(j);
}
};
typedef std::tupledouble> neighbor;
typedef std::priority_queuestd::vector , neighbor_heap_cmp> neighbor_heap_;
neighbor_heap k_neighbor_heap_;
我们用类 KDTree
表示一个 KD 树类,它应该具有的功能有建树
和搜索
。
//(简化的代码,完整的代码详见最后)
class KDTree {
public:
// 建树
KDTree(const double *datas, const double *labels, size_t rows, size_t cols, double p)
// 返回树
tree_node *GetRoot() { return root; }
// 求一个测试点的 k 邻
std::vector<std::tupledouble >> FindKNearests(const double *coor, size_t k);
private:
tree_node *root_;
}
在建树之前,我们还要考虑如何选择切分维度和切分点。切分维度的选择有许多,一般的,可以取 dim = floor % n_features
,即当前树的层数对特征数取余,我们在这里使用 dim = argmax(nmax - nmin)
,即选取当前结点集合中极差最大的维度。
(这里是不完整的代码,有些工具函数的定义请详见完整源代码)
size_t KDTree::FindSplitDim(const std::vector &points) {
if (points.size() == 1)
return 0;
size_t cur_best_dim = 0;
double cur_largest_spread = -1;
double cur_min_val;
double cur_max_val;
for (size_t dim = 0; dim < n_features; ++dim) {
cur_min_val = GetDimVal(points[0], dim);
cur_max_val = GetDimVal(points[0], dim);
for (const auto &id : points) {
if (GetDimVal(id, dim) > cur_max_val)
cur_max_val = GetDimVal(id, dim);
else if (GetDimVal(id, dim) < cur_min_val)
cur_min_val = GetDimVal(id, dim);
}
if (cur_max_val - cur_min_val > cur_largest_spread) {
cur_largest_spread = cur_max_val - cur_min_val;
cur_best_dim = dim;
}
}
return cur_best_dim;
}
选择完切分维 k
之后,我们需选取当前结点集合中的结点在第 k
维的值的中位数 x
作为切分点的值,除去该点之外的点,第 k
维的值小于等于 x
的,放入左子树,反之放入右子树。
在求中位数时,不要全排序,然后取中间的点,可以采用类似快排的方法,找到中位数时就停止排序,这里我们就不写算法了,直接用 C++ 的函数。
std::tupledouble> KDTree::MidElement(const std::vector &points, size_t dim) {
size_t len = points.size();
for (size_t i = 0; i < points.size(); ++i)
get_mid_buf_[i] = std::make_tuple(points[i], GetDimVal(points[i], dim));
std::nth_element(get_mid_buf_,
get_mid_buf_ + len / 2,
get_mid_buf_ + len,
[](const std::tupledouble> &i, const std::tupledouble> &j) {
return std::get<1>(i) < std::get<1>(j);
});
return get_mid_buf_[len / 2];
}
建树直接按照建立二叉树的方法即可
tree_node *KDTree::BuildTree(const std::vector &points) {
size_t dim = FindSplitDim(points);
std::tupledouble> t = MidElement(points, dim);
size_t arg_mid_val = std::get<0>(t);
double mid_val = std::get<1>(t);
tree_node *node = Malloc(tree_node, 1);
node->left = nullptr;
node->right = nullptr;
node->id = arg_mid_val;
node->split = dim;
std::vector left, right;
for (auto &i : points) {
if (i == arg_mid_val)
continue;
if (GetDimVal(i, dim) <= mid_val)
left.emplace_back(i);
else
right.emplace_back(i);
}
if (!left.empty())
node->left = BuildTree(left);
if (!right.empty())
node->right = BuildTree(right);
return node;
}
一般书上所讲的都是搜索最近邻,但是我们这里是搜索 K-近邻,需要对书上的算法做少许的扩充。
搜索最近邻时,我们一般设置两个变量 cur_min_id
和 cur_min_dist
,如果当前搜索到的点到测试点的距离 l < cur_min_dist
时,我们将上述两个变量更新为新点的 id
和 dist
。
相应的,在搜索 K-近邻时,我们可以设置一个最多有 k
个元素的大顶堆,这样,在搜索时,当堆满时,只需比较当前搜索点的 dist
是否小于堆顶点的 dist
,如果小于,堆顶出堆,并将当前搜索点压入,反之,则不变;当堆未满时,直接将该搜索点压入。
我们直接使用二叉树深度优先遍历的非递归算法(具体的描述详见《统计学习方法》第 43 页算法 3.3)。
std::vector<std::tupledouble >> KDTree::FindKNearests(const double *coor, size_t k) {
std::memset(visited_buf_, 0, sizeof(bool) * n_samples);
std::stack paths;
tree_node *p = root;
while (p) {
HeapStackPush(paths, p, coor, k);
p = coor[p->split] <= GetDimVal(p->id, p->split) ? p = p->left : p = p->right;
}
while (!paths.empty()) {
p = paths.top();
paths.pop();
if (!p->left && !p->right)
continue;
if (k_neighbor_heap_.size() < k) {
if (p->left)
HeapStackPush(paths, p->left, coor, k);
if (p->right)
HeapStackPush(paths, p->right, coor, k);
} else {
double node_split_val = GetDimVal(p->id, p->split);
double coor_split_val = coor[p->split];
double heap_top_val = std::get<1>(k_neighbor_heap_.top());
if (coor_split_val > node_split_val) {
if (p->right)
HeapStackPush(paths, p->right, coor, k);
if ((coor_split_val - node_split_val) < heap_top_val && p->left)
HeapStackPush(paths, p->left, coor, k);
} else {
if (p->left)
HeapStackPush(paths, p->left, coor, k);
if ((node_split_val - coor_split_val) < heap_top_val && p->right)
HeapStackPush(paths, p->right, coor, k);
}
}
}
std::vector<std::tupledouble >> res;
while (!k_neighbor_heap_.empty()) {
res.emplace_back(k_neighbor_heap_.top());
k_neighbor_heap_.pop();
}
return res;
}
详见 https://github.com/WiseDoge/libkdtree
完整代码中除了 KD-Tree 的代码外,还给出了测试代码和 Python 接口代码,以及一些调用第三方库来加速的手段。