C#实现KNN算法

本文给出KNN算法的C#实现。上一篇博客中用C#创建KD树的程序,其算法模仿了MATLAB中KDTree程序的思路;
这次则按照李航老师《统计学习方法》中的思路编写C#程序:创建KD树时,分割维度不是按轮询方式依次选取,而是选取数据范围最大的那一维。
using System;
using System.Collections.Generic;
using System.Linq;


namespace KNNSearch
{
    /// <summary>
    /// Builds a KD-tree over 3D points and runs a nearest-neighbour query against it.
    /// </summary>
    public class Knn
    {
        /// <summary>
        /// Maximum number of points a node may hold and still be treated as a leaf.
        /// </summary>
        private int leafnum = 1;

        /// <summary>
        /// Candidate node names ("A" through "Z"), exposed through <see cref="NodeNames"/>.
        /// </summary>
        private List<string> _nodeNames = new List<string>
        {
            "A", "B", "C", "D", "E", "F", "G",
            "H", "I", "J", "K", "L", "M", "N",
            "O", "P", "Q", "R", "S", "T",
            "U", "V", "W", "X", "Y", "Z"
        };
        /// <summary>
        /// Generates pseudo-random 3D sample points.
        /// </summary>
        /// <remarks>
        /// The generator is created with the fixed seed 1, so the same sequence of points
        /// is produced on every call.
        /// </remarks>
        /// <param name="num">Number of points to generate; zero or a negative value yields an empty list.</param>
        /// <returns>The generated points, with <c>ID</c> set to the generation index (starting at 0).</returns>
        private List<Point> GeneralRawData(int num)
        {
            var rawData = new List<Point>();
            var r = new Random(1);
            for (var i = 0; i < num; i++)
            {
                // Initializer members are evaluated in order, so X, Y and Z take consecutive random values.
                rawData.Add(new Point { X = r.NextDouble(), Y = r.NextDouble(), Z = r.NextDouble(), ID = i });
            }
            return rawData;
        }

        /// <summary>
        /// Recursively builds a KD-tree from the given points.
        /// </summary>
        /// <remarks>
        /// Every node keeps the full list of points of its subtree in <c>NodeData</c>.
        /// A node with at most <c>leafnum</c> points becomes a leaf (<c>Splitaxis</c> = -1);
        /// otherwise the points are split around the median of the axis chosen by
        /// <c>GetSplitAxis</c> and the two halves are built recursively.
        /// </remarks>
        /// <param name="data">Points to store in this subtree.</param>
        /// <returns>The root node of the subtree, or <c>null</c> when <paramref name="data"/> is null or empty.</returns>
        private Node CreateKdTree(List<Point> data)
        {
            // Nothing to store: the parent simply gets no child on this side.
            if (data == null || data.Count == 0)
            {
                return null;
            }

            Node root = new Node { NodeData = data, Name = "AA" };

            // Few enough points: this node is a leaf.
            if (data.Count <= leafnum)
            {
                root.LeftNode = null;
                root.RightNode = null;
                root.Point = data[0];
                root.Splitaxis = -1; // negative axis marks a leaf
                return root;
            }

            // Internal node: pick the split axis, then split around the median point.
            int splitAxis = GetSplitAxis(data);
            Tuple<Point, List<Point>, List<Point>> dataSplit = GetSplitNum(data, splitAxis);
            root.Splitaxis = splitAxis;
            root.Point = dataSplit.Item1;
            root.LeftNode = CreateKdTree(dataSplit.Item2);
            root.RightNode = CreateKdTree(dataSplit.Item3);
            return root;
        }

        /// <summary>
        /// Splits the points around their median along one axis.
        /// </summary>
        /// <param name="data">Points to split; must contain at least one point.</param>
        /// <param name="splitAxis">Axis index used as the key into <c>Dict</c> (0 = X, 1 = Y, 2 = Z).</param>
        /// <returns>
        /// A tuple of: the pivot (the element at index <c>Count / 2</c> after sorting by the axis),
        /// the points sorted before the pivot, and the points sorted after it.
        /// </returns>
        /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="data"/> is empty.</exception>
        private Tuple<Point, List<Point>, List<Point>> GetSplitNum(List<Point> data, int splitAxis)
        {
            // Look the accessor up once rather than once per sort key.
            Func<Point, double> coordinate = Dict[splitAxis];
            List<Point> sorted = data.OrderBy(coordinate).ToList();

            int half = sorted.Count / 2;
            Point pivot = sorted[half];

            // The pivot itself belongs to neither side.
            List<Point> leftdata = sorted.GetRange(0, half);
            List<Point> rightdata = sorted.GetRange(half + 1, sorted.Count - half - 1);

            return Tuple.Create(pivot, leftdata, rightdata);
        }
        /// <summary>
        /// Chooses the split axis as the one along which the points have the largest range (max - min).
        /// </summary>
        /// <remarks>
        /// Other strategies exist (largest variance, cycling through the axes); this one uses the range.
        /// When several axes share the largest range, the lowest axis index wins.
        /// </remarks>
        /// <param name="data">Points to inspect; must not be empty.</param>
        /// <returns>The chosen axis index: 0 (X), 1 (Y) or 2 (Z).</returns>
        /// <exception cref="InvalidOperationException">Thrown when <paramref name="data"/> is empty.</exception>
        private int GetSplitAxis(List<Point> data)
        {
            const int axisCount = 3; // Dict maps the axes 0, 1 and 2

            int bestAxis = 0;
            double bestRange = double.NegativeInfinity;
            for (int axis = 0; axis < axisCount; axis++)
            {
                Func<Point, double> coordinate = Dict[axis];
                double[] values = data.Select(coordinate).ToArray();
                double range = values.Max() - values.Min();

                // Strict comparison keeps the first (lowest-index) axis on ties.
                if (range > bestRange)
                {
                    bestRange = range;
                    bestAxis = axis;
                }
            }
            return bestAxis;
        }

        /// <summary>
        /// Finds the tree node that holds the point nearest to <paramref name="target"/>.
        /// </summary>
        /// <remarks>
        /// The search first descends on the side of each split that contains the target, then,
        /// while unwinding, compares the split point itself and visits the opposite side only when
        /// the splitting plane is closer to the target than the best distance found so far.
        /// Child nodes may be null (for example a two-point node has no right child) and are skipped.
        /// </remarks>
        /// <param name="tree">Root of the KD-tree to search; may be null.</param>
        /// <param name="target">Query point.</param>
        /// <returns>
        /// The internal node whose <c>Point</c> is nearest to the target, or the leaf whose
        /// <c>NodeData</c> contains the nearest point; in both cases the minimum distance over the
        /// returned node's <c>NodeData</c> is the nearest distance. Returns <c>null</c> when
        /// <paramref name="tree"/> is null.
        /// </returns>
        private Node KdTreeFindNearest(Node tree, Point target)
        {
            Node best = null;
            double bestDist = double.PositiveInfinity;
            SearchSubtree(tree, target, ref best, ref bestDist);
            return best;
        }

        /// <summary>
        /// Searches one subtree and updates the best node and distance found so far.
        /// </summary>
        /// <param name="node">Root of the subtree to search; null subtrees are ignored.</param>
        /// <param name="target">Query point.</param>
        /// <param name="best">Best node found so far; replaced when a closer one is found.</param>
        /// <param name="bestDist">Distance from the target to the best node found so far.</param>
        private void SearchSubtree(Node node, Point target, ref Node best, ref double bestDist)
        {
            if (node == null)
            {
                return;
            }

            int axis = node.Splitaxis;
            if (axis < 0)
            {
                // Leaf: it may hold several points, so compare against the closest of them.
                double leafDist = NearestDist(node.NodeData, target);
                if (best == null || leafDist < bestDist)
                {
                    bestDist = leafDist;
                    best = node;
                }
                return;
            }

            Func<Point, double> coordinate = Dict[axis];
            // Signed distance from the splitting plane to the target along the split axis.
            double offset = coordinate(target) - coordinate(node.Point);
            bool targetOnLeft = offset <= 0;
            Node nearChild = targetOnLeft ? node.LeftNode : node.RightNode;
            Node farChild = targetOnLeft ? node.RightNode : node.LeftNode;

            // Visit the side containing the target first so the bound tightens early.
            SearchSubtree(nearChild, target, ref best, ref bestDist);

            // The split point is stored in this node, not in either child.
            double pivotDist = Math.Sqrt(Math.Pow(node.Point.X - target.X, 2) +
                                         Math.Pow(node.Point.Y - target.Y, 2) +
                                         Math.Pow(node.Point.Z - target.Z, 2));
            if (best == null || pivotDist < bestDist)
            {
                bestDist = pivotDist;
                best = node;
            }

            // Every point on the far side is at least |offset| away, so it can only win
            // when the splitting plane lies inside the current best radius.
            if (Math.Abs(offset) < bestDist)
            {
                SearchSubtree(farChild, target, ref best, ref bestDist);
            }
        }

        /// <summary>
        /// Maps an axis index (0 = X, 1 = Y, 2 = Z) to the accessor that reads that coordinate from a point.
        /// </summary>
        /// <remarks>
        /// Stored in a read-only static field so the table is built once; an expression-bodied
        /// property would allocate a new dictionary on every access.
        /// </remarks>
        private static readonly Dictionary<int, Func<Point, double>> Dict = new Dictionary<int, Func<Point, double>>
        {
            { 0, p => p.X },
            { 1, p => p.Y },
            { 2, p => p.Z },
        };

        /// <summary>
        /// Gets or sets the list of candidate node names.
        /// </summary>
        public List<string> NodeNames { get => _nodeNames; set => _nodeNames = value; }

        /// <summary>
        /// Computes the smallest Euclidean distance between the target and any of the given points.
        /// </summary>
        /// <param name="nodeData">Points to compare against the target; must not be empty.</param>
        /// <param name="target">Query point.</param>
        /// <returns>The minimum distance from <paramref name="target"/> to a point in <paramref name="nodeData"/>.</returns>
        /// <exception cref="InvalidOperationException">Thrown when <paramref name="nodeData"/> is empty.</exception>
        private double NearestDist(List<Point> nodeData, Point target)
        {
            // Single pass: each distance is computed once and only the minimum is kept.
            return nodeData.Min(item => Math.Sqrt(Math.Pow(item.X - target.X, 2) +
                                                  Math.Pow(item.Y - target.Y, 2) +
                                                  Math.Pow(item.Z - target.Z, 2)));
        }

        /// <summary>
        /// Writes a separator line followed by one line per point to the console (debug helper).
        /// </summary>
        /// <param name="data">Points to print, each formatted through its <c>ToString</c>.</param>
        private void PrintListData(List<Point> data)
        {
            Console.WriteLine("****************");
            foreach (Point point in data)
            {
                Console.WriteLine(point);
            }
        }
        /// <summary>
        /// Runs the demo: generates 180 sample points, builds the KD-tree, searches for the point
        /// nearest to (0.5, 0.5, 0.5) and prints the distance found through the tree next to the
        /// distance found by scanning all points, so the two can be compared.
        /// </summary>
        public Knn()
        {
            List<Point> rawData = GeneralRawData(180);
            Node root = CreateKdTree(rawData);
            Point target = new Point { X = 0.5, Y = 0.5, Z = 0.5 };

            // Distance obtained through the KD-tree search.
            Node nearest = KdTreeFindNearest(root, target);
            double nearestDistFromKnn = NearestDist(nearest.NodeData, target);
            Console.WriteLine("通过KNN搜索计算得到的最短距离为{0:F3}", nearestDistFromKnn);

            // Reference distance obtained by scanning every point.
            double nearestDistFromLoop = NearestDist(rawData, target);
            Console.WriteLine("通过KNN遍历计算得到的最短距离为{0:F3}", nearestDistFromLoop);
        }
    }

    /// <summary>
    /// A node of the KD-tree.
    /// </summary>
    public class Node
    {
        /// <summary>
        /// Node name.
        /// </summary>
        public string Name;

        /// <summary>
        /// The point used as the split threshold at this node.
        /// </summary>
        public Point Point;

        /// <summary>
        /// Left child.
        /// </summary>
        public Node LeftNode;

        /// <summary>
        /// Right child.
        /// </summary>
        public Node RightNode;

        /// <summary>
        /// The points contained in this node.
        /// </summary>
        public List<Point> NodeData;

        /// <summary>
        /// Index of the axis this node splits on.
        /// </summary>
        public int Splitaxis;
    }

    /// <summary>
    /// A point in 3D space tagged with an integer identifier.
    /// </summary>
    public class Point
    {
        /// <summary>X coordinate.</summary>
        public double X;

        /// <summary>Y coordinate.</summary>
        public double Y;

        /// <summary>Z coordinate.</summary>
        public double Z;

        /// <summary>Identifier, used for debugging.</summary>
        public int ID;

        /// <summary>
        /// Formats the point as <c>(X,Y,Z,ID)</c>.
        /// </summary>
        /// <returns>The coordinates and identifier, comma-separated inside parentheses.</returns>
        public override string ToString() => $"({X},{Y},{Z},{ID})";
    }
}

你可能感兴趣的:(C#,机器学习)