我的朋友海伦一直使用在线约会网站寻找适合自己的约会对象。尽管约会网站会推荐不同的人选,但她没有从中找到喜欢的人。经过一番总结,她发现曾交往过三种类型的人:不喜欢的人、魅力一般的人和极具魅力的人。
尽管发现了上述规律,但海伦依然无法将约会网站推荐的匹配对象归入恰当的分类。她觉得可以在周一到周五约会那些魅力一般的人,而周末则更喜欢与那些极具魅力的人为伴。海伦希望我们的分类软件可以更好地帮助她将匹配对象划分到确切的分类中。此外,海伦还收集了一些约会网站未曾记录的数据信息,她认为这些数据更有助于匹配对象的归类。
海伦收集约会数据已经有一段时间,她把这些数据存放在文本文件 datingTestSet.txt 中,每个样本占据一行,总共有 1000 行。海伦的样本主要包含以下 3 种特征:每年获得的飞行常客里程数、玩视频游戏所耗时间百分比、每周消费的冰淇淋公升数。
以上是《机器学习实战》中介绍 K 最邻近算法给出的示例,通过该示例我们可以了解到 K 最邻近算法应用的另一个场景:改善约会网站的配对效果。本次介绍 K-D 树是为了加速寻找给定节点的邻居节点,是提升 K 最邻近算法执行效率的一种改进。
我们首先介绍一下 K-D 树的基本原理,由于篇幅的限制,详细的理论部分可以参见对应的维基百科。
K-D 树
K-D 树(K 维树的简称)是一种用于在 K 维空间中组织点的空间划分数据结构。K-D 树是非常有用的一种数据结构,常用于涉及多维搜索键的搜索(比如,范围搜索和邻近搜索)。
K-D 树是一个二叉树,其中每个节点都是一个 K 维点。每个非叶节点都可以被认为是隐式生成一个分裂超平面,将空间分成两部分,称为半空间。该超平面左侧的点表示该节点的左子树,而该超平面右侧的点则表示为右子树。超平面方向的选择方法如下:树中的每个节点都与一个 K 维相关联,超平面垂直于该维的轴。因此,例如,如果对特定的拆分选择了 X 轴,则子树中所有 X 值小于节点的点都将显示在左子树中,而所有 X 值较大的点都将显示在右子树中。在这种情况下,超平面将由点的 X 值设置,其法向量为单位 X 轴。
更详细的介绍,见维基百科:
https://en.wikipedia.org/wiki/K-d_tree
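曼哈顿距离
曼哈顿距离(本文也写作“麦哈顿距离”,又称出租车距离)是两点在各坐标轴上差值的绝对值之和:对于 K 维空间中的两点 x 和 y,$d(x, y) = \sum_{i=1}^{K} |x_i - y_i|$。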
更详细的介绍,见维基百科:
https://en.wikipedia.org/wiki/Taxicab_geometry
有了以上的基础,我们首先来定义 K-D 树节点的结构,然后定义邻居节点的结构,以及麦哈顿距离,最后给出 K-D 树的结构以及具体应用。
1. 定义K-D 树节点的结构
/// <summary>
/// A single node of a K-D tree: a K-dimensional point, the axis it splits on,
/// an optional associated value and links to its two subtrees.
/// </summary>
/// <typeparam name="T">Type of the value (label) stored with each point.</typeparam>
public class KDTreeNode<T>
{
    /// <summary>Index of the coordinate axis this node splits the space on.</summary>
    public int Axis { get; set; }

    /// <summary>Coordinates of the point stored at this node.</summary>
    public double[] Position { get; set; }

    /// <summary>Value (label) associated with the point.</summary>
    public T Value { get; set; }

    /// <summary>Left subtree, or <c>null</c> when there is none.</summary>
    public KDTreeNode<T> Left { get; set; }

    /// <summary>Right subtree, or <c>null</c> when there is none.</summary>
    public KDTreeNode<T> Right { get; set; }

    /// <summary>True when the node has neither a left nor a right child.</summary>
    public bool IsLeaf
    {
        get { return Left == null && Right == null; }
    }
}
2. 定义邻居节点的结构
/// <summary>
/// Pairs a tree node with its distance to a query point, as returned by
/// nearest-neighbour searches. Instances are ordered by distance only.
/// </summary>
/// <typeparam name="T">Type of the value stored in the tree nodes.</typeparam>
public struct KDTreeNodeDistance<T> : IComparable, IComparable<KDTreeNodeDistance<T>>, IEquatable<KDTreeNodeDistance<T>>
{
    /// <summary>Distance between <see cref="Node"/> and the query point.</summary>
    public double Distance { get; }

    /// <summary>The tree node this distance refers to.</summary>
    public KDTreeNode<T> Node { get; }

    /// <summary>Creates a node/distance pair.</summary>
    /// <param name="node">The tree node.</param>
    /// <param name="distance">Distance from the node to the query point.</param>
    public KDTreeNodeDistance(KDTreeNode<T> node, double distance)
    {
        this.Node = node;
        this.Distance = distance;
    }

    /// <summary>Non-generic comparison; orders by <see cref="Distance"/>.</summary>
    /// <param name="obj">Another <see cref="KDTreeNodeDistance{T}"/>, or <c>null</c>.</param>
    /// <returns>Negative, zero or positive, as for <see cref="IComparable.CompareTo"/>.</returns>
    /// <exception cref="ArgumentException"><paramref name="obj"/> is of a different type.</exception>
    public int CompareTo(object obj)
    {
        // IComparable convention: every instance sorts after null.
        if (obj == null)
        {
            return 1;
        }

        if (!(obj is KDTreeNodeDistance<T>))
        {
            throw new ArgumentException("Object must be of type KDTreeNodeDistance<T>.", nameof(obj));
        }

        // Compare the two distances. Passing the unboxed struct straight to
        // double.CompareTo(object) would throw, because the struct is not a double.
        return CompareTo((KDTreeNodeDistance<T>)obj);
    }

    /// <summary>Orders two pairs by their distance.</summary>
    public int CompareTo(KDTreeNodeDistance<T> other)
    {
        return Distance.CompareTo(other.Distance);
    }

    /// <summary>Two pairs are equal when they have the same distance and refer to the same node instance.</summary>
    public bool Equals(KDTreeNodeDistance<T> other)
    {
        return Distance == other.Distance && Node == other.Node;
    }

    /// <summary>Object equality, consistent with <see cref="Equals(KDTreeNodeDistance{T})"/>.</summary>
    public override bool Equals(object obj)
    {
        return obj is KDTreeNodeDistance<T> && Equals((KDTreeNodeDistance<T>)obj);
    }

    /// <summary>Hash code combining the distance and the node.</summary>
    public override int GetHashCode()
    {
        unchecked
        {
            int nodeHash = Node != null ? Node.GetHashCode() : 0;
            return (Distance.GetHashCode() * 397) ^ nodeHash;
        }
    }
}
3. 定义麦哈顿距离
/// <summary>
/// Manhattan (taxicab, L1) distance: the sum of the absolute differences
/// of the coordinates of two points.
/// </summary>
// NOTE(review): the generic argument of IMetric was lost in the original listing;
// IMetric<double[]> is assumed from how the tree calls Distance(double[], double[]).
public sealed class Manhattan : IMetric<double[]>
{
    /// <summary>Computes the Manhattan distance between two points.</summary>
    /// <param name="x">First point.</param>
    /// <param name="y">Second point; must have the same length as <paramref name="x"/>.</param>
    /// <returns>Sum over all coordinates of <c>|x[i] - y[i]|</c>.</returns>
    /// <exception cref="ArgumentNullException">Either argument is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">The two points have different lengths.</exception>
    public double Distance(double[] x, double[] y)
    {
        if (x == null)
        {
            throw new ArgumentNullException(nameof(x));
        }

        if (y == null)
        {
            throw new ArgumentNullException(nameof(y));
        }

        // Without this check a longer y would be silently truncated and a
        // shorter y would fail with an IndexOutOfRangeException.
        if (x.Length != y.Length)
        {
            throw new ArgumentException("两个向量的维数必须相同。", nameof(y));
        }

        double sum = 0.0;
        for (int i = 0; i < x.Length; i++)
        {
            sum += Math.Abs(x[i] - y[i]);
        }

        return sum;
    }
}
4. 定义 K-D 树的结构
/// <summary>
/// K-dimensional tree: a binary space-partitioning tree over K-dimensional
/// points that supports radius queries and k-nearest-neighbour queries.
/// </summary>
/// <typeparam name="T">Type of the value (label) stored with each point.</typeparam>
public class KDTree<T> : IEnumerable<KDTreeNode<T>>
{
    /// <summary>Metric used by the <c>Nearest</c> queries; Euclidean unless replaced.</summary>
    public IMetric<double[]> Distance { get; set; } = new Euclidean();

    /// <summary>Root node of the tree.</summary>
    public KDTreeNode<T> Root { get; }

    /// <summary>Number of nodes (points) in the tree.</summary>
    public int Count { get; }

    /// <summary>Number of leaf nodes in the tree.</summary>
    public int Leaves { get; }

    /// <summary>Number of dimensions of the stored points (length of the first point).</summary>
    public int Dimensions { get; }

    /// <summary>Builds a tree from points, their associated values and a distance metric.</summary>
    /// <param name="points">Points to index; must contain at least one point.</param>
    /// <param name="values">Value for each point, in the same order as <paramref name="points"/>.</param>
    /// <param name="distance">Metric used by the nearest-neighbour queries.</param>
    /// <param name="inPlace">
    /// When true the given arrays are reordered directly while building the tree;
    /// when false shallow copies of the arrays are reordered instead.
    /// </param>
    /// <exception cref="ArgumentNullException">A required argument is <c>null</c>.</exception>
    /// <exception cref="ArgumentException"><paramref name="points"/> is empty.</exception>
    public KDTree(double[][] points, T[] values, IMetric<double[]> distance, bool inPlace = false)
    {
        if (points == null)
        {
            throw new ArgumentNullException(nameof(points));
        }

        if (points.Length == 0)
        {
            throw new ArgumentException("创建树的点数不足。");
        }

        if (values == null)
        {
            throw new ArgumentNullException(nameof(values));
        }

        if (distance == null)
        {
            throw new ArgumentNullException(nameof(distance));
        }

        int leaves;
        Root = CreateRoot(points, values, inPlace, out leaves);
        Leaves = leaves;
        Distance = distance;
        Dimensions = points[0].Length;
        Count = points.Length;
    }

    /// <summary>Builds a tree from points only; node values are <c>default(T)</c> and the default metric is kept.</summary>
    /// <param name="points">Points to index; must contain at least one point.</param>
    /// <param name="inPlace">
    /// When true the given array is reordered directly while building the tree;
    /// when false a shallow copy is reordered instead.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="points"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentException"><paramref name="points"/> is empty.</exception>
    public KDTree(double[][] points, bool inPlace = false)
    {
        if (points == null)
        {
            throw new ArgumentNullException(nameof(points));
        }

        if (points.Length == 0)
        {
            throw new ArgumentException("创建树的点数不足。");
        }

        int leaves;
        Root = CreateRoot(points, null, inPlace, out leaves);
        Leaves = leaves;
        Dimensions = points[0].Length;
        Count = points.Length;
    }

    /// <summary>Validates the input and builds the whole tree, returning its root.</summary>
    /// <param name="points">Points to index.</param>
    /// <param name="values">Optional values, one per point; may be <c>null</c>.</param>
    /// <param name="inPlace">Whether the given arrays may be reordered directly.</param>
    /// <param name="leaves">Receives the number of leaf nodes created.</param>
    /// <returns>The root node of the new tree.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="points"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">There are no points, or the first point is null or has no coordinates.</exception>
    /// <exception cref="DimensionMismatchException">
    /// <paramref name="values"/> has a different length than <paramref name="points"/>,
    /// or a point is null or differs in length from the first point.
    /// </exception>
    protected KDTreeNode<T> CreateRoot(double[][] points, T[] values, bool inPlace, out int leaves)
    {
        if (points == null)
        {
            throw new ArgumentNullException(nameof(points));
        }

        if (values != null && points.Length != values.Length)
        {
            throw new DimensionMismatchException("values");
        }

        if (points.Length == 0)
        {
            throw new ArgumentException("创建树的点数不足。");
        }

        // A zero dimension would make the axis computation (depth % k) divide by zero.
        if (points[0] == null || points[0].Length == 0)
        {
            throw new ArgumentException("点的维数必须大于零。", nameof(points));
        }

        int dimensions = points[0].Length;

        // Every point is indexed along each of the K axes, so all must share the same length.
        for (int i = 1; i < points.Length; i++)
        {
            if (points[i] == null || points[i].Length != dimensions)
            {
                throw new DimensionMismatchException("points");
            }
        }

        if (!inPlace)
        {
            // Shallow copies: building the tree reorders the arrays.
            points = (double[][])points.Clone();
            if (values != null)
            {
                values = (T[])values.Clone();
            }
        }

        leaves = 0;
        ElementComparer comparer = new ElementComparer();
        return Create(points, values, 0, dimensions, 0, points.Length, comparer, ref leaves);
    }

    /// <summary>
    /// Recursively builds the subtree for <c>points[start .. start + length)</c>:
    /// sorts the range on the current axis, takes the middle element as the node
    /// and recurses into the two halves.
    /// </summary>
    private KDTreeNode<T> Create(double[][] points, T[] values, int depth, int k, int start, int length,
        ElementComparer comparer, ref int leaves)
    {
        if (length <= 0)
        {
            return null;
        }

        // The splitting axis cycles through the dimensions with the depth.
        int axis = depth % k;
        comparer.Index = axis;

        // Sort the range on that axis, keeping values aligned with their points.
        Array.Sort(points, values, start, length, comparer);

        // The middle element becomes this node; the rest is split around it.
        int half = start + length / 2;
        int leftStart = start;
        int leftLength = half - start;
        int rightStart = half + 1;
        int rightLength = length - length / 2 - 1;

        double[] median = points[half];
        T value = values != null ? values[half] : default(T);

        depth++;
        KDTreeNode<T> left = Create(points, values, depth, k, leftStart, leftLength, comparer, ref leaves);
        KDTreeNode<T> right = Create(points, values, depth, k, rightStart, rightLength, comparer, ref leaves);

        if (left == null && right == null)
        {
            leaves++;
        }

        return new KDTreeNode<T>()
        {
            Axis = axis,
            Position = median,
            Value = value,
            Left = left,
            Right = right,
        };
    }

    /// <summary>Recursive k-nearest-neighbour search; candidates are accumulated in <paramref name="list"/>.</summary>
    private void Nearest(KDTreeNode<T> current, double[] position, KDTreeNodeCollection<T> list)
    {
        double d = Distance.Distance(position, current.Position);
        list.Add(current, d);

        // Signed offset of the query from this node's splitting plane.
        double u = position[current.Axis] - current.Position[current.Axis];

        if (u <= 0)
        {
            // Query lies on the left side: search it first, then visit the other
            // side only if the plane is no farther than list.Maximum.
            if (current.Left != null)
            {
                Nearest(current.Left, position, list);
            }

            if (current.Right != null && Math.Abs(u) <= list.Maximum)
            {
                Nearest(current.Right, position, list);
            }
        }
        else
        {
            if (current.Right != null)
            {
                Nearest(current.Right, position, list);
            }

            if (current.Left != null && Math.Abs(u) <= list.Maximum)
            {
                Nearest(current.Left, position, list);
            }
        }
    }

    /// <summary>Recursive radius search; every node within <paramref name="radius"/> is added to <paramref name="list"/>.</summary>
    private void Nearest(KDTreeNode<T> current, double[] position, double radius, ICollection<KDTreeNodeDistance<T>> list)
    {
        double d = Distance.Distance(position, current.Position);
        if (d <= radius)
        {
            list.Add(new KDTreeNodeDistance<T>(current, d));
        }

        // Signed offset of the query from this node's splitting plane.
        double u = position[current.Axis] - current.Position[current.Axis];

        if (u <= 0)
        {
            // The far side is visited only when the plane is within the radius.
            if (current.Left != null)
            {
                Nearest(current.Left, position, radius, list);
            }

            if (current.Right != null && Math.Abs(u) <= radius)
            {
                Nearest(current.Right, position, radius, list);
            }
        }
        else
        {
            if (current.Right != null)
            {
                Nearest(current.Right, position, radius, list);
            }

            if (current.Left != null && Math.Abs(u) <= radius)
            {
                Nearest(current.Left, position, radius, list);
            }
        }
    }

    /// <summary>Finds the nodes nearest to a query point.</summary>
    /// <param name="position">Query point; must have <see cref="Dimensions"/> coordinates.</param>
    /// <param name="neighbors">Size passed to the result collection (the number of neighbours requested).</param>
    /// <returns>A collection of nodes with their distances to the query point.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="position"/> is <c>null</c>.</exception>
    /// <exception cref="DimensionMismatchException"><paramref name="position"/> has the wrong number of coordinates.</exception>
    public KDTreeNodeCollection<T> Nearest(double[] position, int neighbors)
    {
        ValidateQuery(position);

        KDTreeNodeCollection<T> list = new KDTreeNodeCollection<T>(neighbors);
        if (Root != null)
        {
            Nearest(Root, position, list);
        }

        return list;
    }

    /// <summary>Finds every node whose distance to a query point is at most <paramref name="radius"/>.</summary>
    /// <param name="position">Query point; must have <see cref="Dimensions"/> coordinates.</param>
    /// <param name="radius">Maximum distance (inclusive), measured with <see cref="Distance"/>.</param>
    /// <returns>The matching nodes with their distances, in the order they were visited.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="position"/> is <c>null</c>.</exception>
    /// <exception cref="DimensionMismatchException"><paramref name="position"/> has the wrong number of coordinates.</exception>
    public KDTreeNodeList<T> Nearest(double[] position, double radius)
    {
        ValidateQuery(position);

        KDTreeNodeList<T> list = new KDTreeNodeList<T>();
        if (Root != null)
        {
            Nearest(Root, position, radius, list);
        }

        return list;
    }

    /// <summary>Rejects a null query point or one whose length differs from <see cref="Dimensions"/>.</summary>
    private void ValidateQuery(double[] position)
    {
        if (position == null)
        {
            throw new ArgumentNullException(nameof(position));
        }

        if (position.Length != Dimensions)
        {
            throw new DimensionMismatchException("position");
        }
    }

    /// <summary>
    /// Enumerates all nodes depth-first using an explicit stack. The left child is
    /// pushed before the right one, so a node's right subtree is visited before its left.
    /// </summary>
    public IEnumerator<KDTreeNode<T>> GetEnumerator()
    {
        if (Root == null)
        {
            yield break;
        }

        Stack<KDTreeNode<T>> stack = new Stack<KDTreeNode<T>>(new[] { Root });
        while (stack.Count != 0)
        {
            KDTreeNode<T> current = stack.Pop();
            yield return current;

            if (current.Left != null)
            {
                stack.Push(current.Left);
            }

            if (current.Right != null)
            {
                stack.Push(current.Right);
            }
        }
    }

    /// <summary>Non-generic enumerator; forwards to <see cref="GetEnumerator"/>.</summary>
    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }
}
5. K-D 树的应用
假设我们拥有以下点的集合:
// Six sample 2-D points used by the examples that follow.
double[][] points = new double[][]
{
    new[] { 2.0, 3.0 },
    new[] { 5.0, 4.0 },
    new[] { 9.0, 6.0 },
    new[] { 4.0, 7.0 },
    new[] { 8.0, 1.0 },
    new[] { 7.0, 2.0 },
};
从这些数据点中创建K-D 树:
// Build the tree from the points only (no per-point values).
// NOTE(review): the generic argument was lost in the original listing; <int> is assumed.
KDTree<int> tree = new KDTree<int>(points);
我们可以手动导航这棵树:
// Walk the tree by hand: root -> left child -> that child's right child.
// NOTE(review): the generic argument was lost in the original listing; <int> is assumed.
KDTreeNode<int> node = tree.Root.Left.Right;
或者自动遍历这棵树。由于 KDTree 实现了枚举器(即迭代器模式),所以可以用 foreach 来遍历这棵树的数据;而对于普通的二叉树,通常使用前序遍历、中序遍历、后序遍历或者层次遍历的方式进行。
// Enumerate every node of the tree and print its coordinates.
// NOTE(review): the generic argument was lost in the original listing; <int> is assumed.
foreach (KDTreeNode<int> n in tree)
{
    double[] location = n.Position;
    Console.WriteLine(@"({0},{1})", location[0], location[1]);
}
可以得到如下的结果:
(7,2)
(9,6)
(8,1)
(5,4)
(4,7)
(2,3)
给定一个查询点(例如:(5,3)),我们还可以查询以该点为中心、给定半径(欧氏距离 4.0)内的其它点。
// Radius query: every node whose distance to (5,3) under tree.Distance is at most 4.0.
// NOTE(review): the generic arguments were lost in the original listing; <int> is assumed.
double[] query = new double[] {5, 3};
KDTreeNodeList<int> result = tree.Nearest(query, 4.0);
for (int i = 0, len = result.Count; i < len; i++)
{
    KDTreeNode<int> node = result[i].Node;
    Console.WriteLine(@"({0},{1})", node.Position[0], node.Position[1]);
}
可以得到如下的结果:
(7,2)
(5,4)
(2,3)
(8,1)
我们也可以使用其它的距离度量,比如麦哈顿距离:
// Same radius query, but with the tree's metric switched to the Manhattan distance.
// NOTE(review): the generic arguments were lost in the original listing; <int> is assumed.
double[] query = new double[] {5, 3};
tree.Distance = new Manhattan();
KDTreeNodeList<int> result = tree.Nearest(query, 4.0);
for (int i = 0, len = result.Count; i < len; i++)
{
    KDTreeNode<int> node = result[i].Node;
    Console.WriteLine(@"({0},{1})", node.Position[0], node.Position[1]);
}
可以得到如下的结果:
(7,2)
(5,4)
(2,3)
以及查询固定数量的相邻点,比如
// k-nearest-neighbour query: ask for the 3 nodes closest to the query point.
// NOTE(review): the generic arguments were lost in the original listing; <int> is assumed.
KDTreeNodeCollection<int> neighbors = tree.Nearest(query, 3);
for (int i = 0, len = neighbors.Count; i < len; i++)
{
    KDTreeNode<int> node = neighbors[i].Node;
    Console.WriteLine(@"({0},{1})", node.Position[0], node.Position[1]);
}
可以得到如下的结果:
(5,4)
(2,3)
(7,2)
到此为止,用 C# 实现 K-D 树就介绍完了。在后台回复 20190312 可以得到本篇开头提到的约会网站数据集。大家把上面的代码看懂后,可以尝试着自己写一下,然后用这个数据集来测试自己的代码。
今天就到这里吧!See You!