基于欧几里得距离的 KNN（k 最近邻）算法 Java 实现

本实现仅适合学习使用，可借此大致了解 KNN 算法的原理。

算法做了如下假定与简化处理：

1. 只处理小规模数据集；

2. 假设所有属性值及类别都是数值类型，训练数据每行的最后一列为类别；

3. 直接根据数据规模设定 k 值（测试程序中取 k = 3）；

4. 直接用训练集本身（去掉类别列）作为测试集。

KNN实现代码如下:

KNNNode.java（近邻结点类）：
  1. package KNN; 
  /**
   * Candidate-neighbour record for the KNN algorithm: the position of a
   * training tuple, its distance to the tuple being classified, and the
   * class label of that training tuple.
   *
   * @author Rowen
   * @qq 443773264
   * @mail [email protected]
   * @blog blog.csdn.net/luowen3405
   * @data 2011.03.25
   */
  public class KNNNode {
      /** Index of the training tuple within the training set. */
      private int index;
      /** Distance from the training tuple to the test tuple. */
      private double distance;
      /** Class label of the training tuple. */
      private String label;

      /**
       * Creates a neighbour record.
       *
       * @param index    index of the training tuple in the training set
       * @param distance distance from the training tuple to the test tuple
       * @param c        class label of the training tuple
       */
      public KNNNode(int index, double distance, String c) {
          this.index = index;
          this.distance = distance;
          this.label = c;
      }

      /** @return the index of the training tuple */
      public int getIndex() {
          return index;
      }

      /** @param index the new index of the training tuple */
      public void setIndex(int index) {
          this.index = index;
      }

      /** @return the distance to the test tuple */
      public double getDistance() {
          return distance;
      }

      /** @param distance the new distance to the test tuple */
      public void setDistance(double distance) {
          this.distance = distance;
      }

      /** @return the class label of the training tuple */
      public String getC() {
          return label;
      }

      /** @param c the new class label */
      public void setC(String c) {
          this.label = c;
      }
  }

package KNN; /** * KNN结点类,用来存储最近邻的k个元组相关的信息 * @author Rowen * @qq 443773264 * @mail [email protected] * @blog blog.csdn.net/luowen3405 * @data 2011.03.25 */ public class KNNNode { private int index; // 元组标号 private double distance; // 与测试元组的距离 private String c; // 所属类别 public KNNNode(int index, double distance, String c) { super(); this.index = index; this.distance = distance; this.c = c; } public int getIndex() { return index; } public void setIndex(int index) { this.index = index; } public double getDistance() { return distance; } public void setDistance(double distance) { this.distance = distance; } public String getC() { return c; } public void setC(String c) { this.c = c; } }

KNN.java（算法主体类）：
  1. package KNN; 
  2. import java.util.ArrayList; 
  3. import java.util.Comparator; 
  4. import java.util.HashMap; 
  5. import java.util.List; 
  6. import java.util.Map; 
  7. import java.util.PriorityQueue; 
  8.  
  9. /**
  10. * KNN算法主体类
  11. * @author Rowen
  12. * @qq 443773264
  13. * @mail [email protected]
  14. * @blog blog.csdn.net/luowen3405
  15. * @data 2011.03.25
  16. */ 
  17. public class KNN { 
  18.     /**
  19.      * 设置优先级队列的比较函数,距离越大,优先级越高
  20.      */ 
  21.     private Comparator<KNNNode> comparator = new Comparator<KNNNode>() { 
  22.         public int compare(KNNNode o1, KNNNode o2) { 
  23.             if (o1.getDistance() >= o2.getDistance()) { 
  24.                 return 1
  25.             } else
  26.                 return 0
  27.             } 
  28.         } 
  29.     }; 
  30.     /**
  31.      * 获取K个不同的随机数
  32.      * @param k 随机数的个数
  33.      * @param max 随机数最大的范围
  34.      * @return 生成的随机数数组
  35.      */ 
  36.     public List<Integer> getRandKNum(int k, int max) { 
  37.         List<Integer> rand = new ArrayList<Integer>(k); 
  38.         for (int i = 0; i < k; i++) { 
  39.             int temp = (int) (Math.random() * max); 
  40.             if (!rand.contains(temp)) { 
  41.                 rand.add(temp); 
  42.             } else
  43.                 i--; 
  44.             } 
  45.         } 
  46.         return rand; 
  47.     } 
  48.     /**
  49.      * 计算测试元组与训练元组之前的距离
  50.      * @param d1 测试元组
  51.      * @param d2 训练元组
  52.      * @return 距离值
  53.      */ 
  54.     public double calDistance(List<Double> d1, List<Double> d2) { 
  55.         double distance = 0.00
  56.         for (int i = 0; i < d1.size(); i++) { 
  57.             distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i)); 
  58.         } 
  59.         return distance; 
  60.     } 
  61.     /**
  62.      * 执行KNN算法,获取测试元组的类别
  63.      * @param datas 训练数据集
  64.      * @param testData 测试元组
  65.      * @param k 设定的K值
  66.      * @return 测试元组的类别
  67.      */ 
  68.     public String knn(List<List<Double>> datas, List<Double> testData, int k) { 
  69.         PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator); 
  70.         List<Integer> randNum = getRandKNum(k, datas.size()); 
  71.         for (int i = 0; i < k; i++) { 
  72.             int index = randNum.get(i); 
  73.             List<Double> currData = datas.get(index); 
  74.             String c = currData.get(currData.size() - 1).toString(); 
  75.             KNNNode node = new KNNNode(index, calDistance(testData, currData), c); 
  76.             pq.add(node); 
  77.         } 
  78.         for (int i = 0; i < datas.size(); i++) { 
  79.             List<Double> t = datas.get(i); 
  80.             double distance = calDistance(testData, t); 
  81.             KNNNode top = pq.peek(); 
  82.             if (top.getDistance() > distance) { 
  83.                 pq.remove(); 
  84.                 pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString())); 
  85.             } 
  86.         } 
  87.          
  88.         return getMostClass(pq); 
  89.     } 
  90.     /**
  91.      * 获取所得到的k个最近邻元组的多数类
  92.      * @param pq 存储k个最近近邻元组的优先级队列
  93.      * @return 多数类的名称
  94.      */ 
  95.     private String getMostClass(PriorityQueue<KNNNode> pq) { 
  96.         Map<String, Integer> classCount = new HashMap<String, Integer>(); 
  97.         for (int i = 0; i < pq.size(); i++) { 
  98.             KNNNode node = pq.remove(); 
  99.             String c = node.getC(); 
  100.             if (classCount.containsKey(c)) { 
  101.                 classCount.put(c, classCount.get(c) + 1); 
  102.             } else
  103.                 classCount.put(c, 1); 
  104.             } 
  105.         } 
  106.         int maxIndex = -1
  107.         int maxCount = 0
  108.         Object[] classes = classCount.keySet().toArray(); 
  109.         for (int i = 0; i < classes.length; i++) { 
  110.             if (classCount.get(classes[i]) > maxCount) { 
  111.                 maxIndex = i; 
  112.                 maxCount = classCount.get(classes[i]); 
  113.             } 
  114.         } 
  115.         return classes[maxIndex].toString(); 
  116.     } 

package KNN;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

/**
 * Main class of the KNN algorithm: classifies a test tuple by majority vote
 * among its k nearest training tuples under Euclidean distance.
 *
 * <p>Training tuples are lists of numbers whose last element is the class
 * label; test tuples contain only the attribute values.
 *
 * @author Rowen
 * @qq 443773264
 * @mail [email protected]
 * @blog blog.csdn.net/luowen3405
 * @data 2011.03.25
 */
public class KNN {
    /**
     * Priority-queue ordering: the larger the distance, the higher the
     * priority, so the head of the queue is always the farthest of the
     * neighbours collected so far and is the one evicted first.
     */
    private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {
        @Override
        public int compare(KNNNode o1, KNNNode o2) {
            // Reversed Double.compare yields a proper negative/zero/positive
            // result, as the Comparator contract (and PriorityQueue) requires.
            return Double.compare(o2.getDistance(), o1.getDistance());
        }
    };

    /**
     * Returns k distinct random integers drawn from [0, max).
     *
     * @param k   how many numbers to draw
     * @param max exclusive upper bound of the range
     * @return a list of k distinct integers in [0, max)
     * @throws IllegalArgumentException if k is negative or greater than max,
     *         since k distinct values could never be found
     */
    public List<Integer> getRandKNum(int k, int max) {
        if (k < 0 || k > max) {
            throw new IllegalArgumentException(
                    "cannot draw " + k + " distinct numbers from [0, " + max + ")");
        }
        List<Integer> rand = new ArrayList<Integer>(k);
        while (rand.size() < k) {
            int temp = (int) (Math.random() * max);
            if (!rand.contains(temp)) {
                rand.add(temp);
            }
        }
        return rand;
    }

    /**
     * Computes the Euclidean distance between a test tuple and a training tuple.
     *
     * <p>Only the first {@code d1.size()} positions are compared, so when
     * {@code d1} is the label-less test tuple the trailing class label of
     * {@code d2} is ignored.
     *
     * @param d1 the test tuple (attribute values only)
     * @param d2 the training tuple (attribute values followed by the label)
     * @return the Euclidean distance between the two tuples
     */
    public double calDistance(List<Double> d1, List<Double> d2) {
        double sum = 0.0;
        for (int i = 0; i < d1.size(); i++) {
            double diff = d1.get(i) - d2.get(i);
            sum += diff * diff;
        }
        // The square root does not change the neighbour ranking, but without it
        // the result would be the squared distance, not the Euclidean distance.
        return Math.sqrt(sum);
    }

    /**
     * Runs KNN and returns the predicted class of the test tuple.
     *
     * <p>Every training tuple is scanned exactly once while a bounded queue keeps
     * the k closest seen so far. If k exceeds the size of the training set,
     * all training tuples vote.
     *
     * @param datas    the training set; the last element of each tuple is its class label
     * @param testData the tuple to classify
     * @param k        number of nearest neighbours that vote
     * @return the majority class label among the nearest neighbours
     * @throws IllegalArgumentException if k is not positive or the training set is empty
     */
    public String knn(List<List<Double>> datas, List<Double> testData, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive: " + k);
        }
        if (datas.isEmpty()) {
            throw new IllegalArgumentException("training set is empty");
        }
        int neighbours = Math.min(k, datas.size());
        PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(neighbours, comparator);
        for (int i = 0; i < datas.size(); i++) {
            List<Double> t = datas.get(i);
            double distance = calDistance(testData, t);
            if (pq.size() < neighbours) {
                pq.add(new KNNNode(i, distance, labelOf(t)));
            } else if (pq.peek().getDistance() > distance) {
                // Evict the current farthest neighbour in favour of this closer tuple.
                pq.remove();
                pq.add(new KNNNode(i, distance, labelOf(t)));
            }
        }
        return getMostClass(pq);
    }

    /** Returns the class label of a training tuple (its last element) as a string. */
    private static String labelOf(List<Double> tuple) {
        return tuple.get(tuple.size() - 1).toString();
    }

    /**
     * Returns the majority class among the collected nearest neighbours.
     * Ties go to whichever tied class the HashMap iterates first.
     *
     * @param pq the queue holding the nearest neighbours
     * @return the name of the majority class
     * @throws IllegalArgumentException if the queue is empty
     */
    private String getMostClass(PriorityQueue<KNNNode> pq) {
        Map<String, Integer> classCount = new HashMap<String, Integer>();
        // Iterate rather than drain: removing inside a loop bounded by pq.size()
        // would stop halfway and silently skip neighbours.
        for (KNNNode node : pq) {
            String c = node.getC();
            Integer count = classCount.get(c);
            classCount.put(c, count == null ? 1 : count + 1);
        }
        String mostClass = null;
        int maxCount = 0;
        for (Map.Entry<String, Integer> entry : classCount.entrySet()) {
            if (entry.getValue() > maxCount) {
                mostClass = entry.getKey();
                maxCount = entry.getValue();
            }
        }
        if (mostClass == null) {
            throw new IllegalArgumentException("no neighbours to vote on");
        }
        return mostClass;
    }
}

TestKNN.java（测试类）：
  1. package KNN; 
  2. import java.io.BufferedReader; 
  3. import java.io.File; 
  4. import java.io.FileReader; 
  5. import java.util.ArrayList; 
  6. import java.util.List; 
  7. /**
  8. * KNN算法测试类
  9. * @author Rowen
  10. * @qq 443773264
  11. * @mail [email protected]
  12. * @blog blog.csdn.net/luowen3405
  13. * @data 2011.03.25
  14. */ 
  15. public class TestKNN { 
  16.      
  17.     /**
  18.      * 从数据文件中读取数据
  19.      * @param datas 存储数据的集合对象
  20.      * @param path 数据文件的路径
  21.      */ 
  22.     public void read(List<List<Double>> datas, String path){ 
  23.         try
  24.             BufferedReader br = new BufferedReader(new FileReader(new File(path))); 
  25.             String data = br.readLine(); 
  26.             List<Double> l = null
  27.             while (data != null) { 
  28.                 String t[] = data.split(" "); 
  29.                 l = new ArrayList<Double>(); 
  30.                 for (int i = 0; i < t.length; i++) { 
  31.                     l.add(Double.parseDouble(t[i])); 
  32.                 } 
  33.                 datas.add(l); 
  34.                 data = br.readLine(); 
  35.             } 
  36.         } catch (Exception e) { 
  37.             e.printStackTrace(); 
  38.         } 
  39.     } 
  40.      
  41.     /**
  42.      * 程序执行入口
  43.      * @param args
  44.      */ 
  45.     public static void main(String[] args) { 
  46.         TestKNN t = new TestKNN(); 
  47.         String datafile = new File("").getAbsolutePath() + File.separator + "datafile"
  48.         String testfile = new File("").getAbsolutePath() + File.separator + "testfile"
  49.         try
  50.             List<List<Double>> datas = new ArrayList<List<Double>>(); 
  51.             List<List<Double>> testDatas = new ArrayList<List<Double>>(); 
  52.             t.read(datas, datafile); 
  53.             t.read(testDatas, testfile); 
  54.             KNN knn = new KNN(); 
  55.             for (int i = 0; i < testDatas.size(); i++) { 
  56.                 List<Double> test = testDatas.get(i); 
  57.                 System.out.print("测试元组: "); 
  58.                 for (int j = 0; j < test.size(); j++) { 
  59.                     System.out.print(test.get(j) + " "); 
  60.                 } 
  61.                 System.out.print("类别为: "); 
  62.                 System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3))))); 
  63.             } 
  64.         } catch (Exception e) { 
  65.             e.printStackTrace(); 
  66.         } 
  67.     } 

package KNN;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Test driver for the KNN algorithm: loads a training file and a test file
 * from the working directory and prints the predicted class of every test tuple.
 *
 * @author Rowen
 * @qq 443773264
 * @mail [email protected]
 * @blog blog.csdn.net/luowen3405
 * @data 2011.03.25
 */
public class TestKNN {

    /**
     * Reads a data file and appends one tuple per non-blank line to {@code datas}.
     * Values on a line are separated by whitespace and parsed as doubles.
     * Any error is printed and stops reading; tuples parsed before the error
     * are kept. The file is always closed.
     *
     * @param datas the list that receives the parsed tuples
     * @param path  path of the data file
     */
    public void read(List<List<Double>> datas, String path) {
        BufferedReader br = null;
        try {
            br = new BufferedReader(new FileReader(new File(path)));
            String data;
            while ((data = br.readLine()) != null) {
                String line = data.trim();
                if (line.length() == 0) {
                    continue; // tolerate blank lines, e.g. an empty last line
                }
                // Split on any whitespace run so repeated or leading spaces do not
                // produce empty tokens that Double.parseDouble rejects.
                String[] t = line.split("\\s+");
                List<Double> l = new ArrayList<Double>(t.length);
                for (String value : t) {
                    l.add(Double.parseDouble(value));
                }
                datas.add(l);
            }
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (br != null) {
                try {
                    br.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * Program entry point: classifies every tuple in {@code testfile} against
     * {@code datafile} (both in the working directory) with k = 3.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        TestKNN t = new TestKNN();
        String datafile = new File("").getAbsolutePath() + File.separator + "datafile";
        String testfile = new File("").getAbsolutePath() + File.separator + "testfile";
        try {
            List<List<Double>> datas = new ArrayList<List<Double>>();
            List<List<Double>> testDatas = new ArrayList<List<Double>>();
            t.read(datas, datafile);
            t.read(testDatas, testfile);
            KNN knn = new KNN();
            for (int i = 0; i < testDatas.size(); i++) {
                List<Double> test = testDatas.get(i);
                System.out.print("测试元组: ");
                for (int j = 0; j < test.size(); j++) {
                    System.out.print(test.get(j) + " ");
                }
                System.out.print("类别为: ");
                // Labels are stored as doubles (e.g. "1.0"); print them as integers.
                System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3)))));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

训练数据文件 datafile（每行前 8 列为属性值，最后一列为类别）：

1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0

1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0

测试数据文件 testfile（即训练数据去掉最后的类别列）：
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5

1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5

程序运行结果：

测试元组: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1
测试元组: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1
测试元组: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1
测试元组: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0
测试元组: 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 类别为: 1
测试元组: 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 类别为: 0

 

你可能感兴趣的:(java,exception,String,测试,import,distance)