数据挖掘--贝叶斯Bayes算法实现

注:本文转载于http://blog.csdn.net/luowen3405/article/details/6258818#

本算法的实现仅仅适用于小规模数据集的实验与测试,不适合用于工程应用

 

算法假定训练数据各属性列的值均是离散类型的。若是非离散类型的数据,需要首先进行数据的预处理,将非离散型的数据离散化。

 

算法中使用到了DecimalCaculate类,该类是java中BigDecimal类的扩展,用于高精度浮点数的运算。该类的实现同本人转载的一篇博文:对BigDecimal常用方法的归类中的Arith类相同。

 

算法实现的代码如下

[java]  view plain copy
  1. package Bayes;  
  2. import java.util.ArrayList;  
  3. import java.util.HashMap;  
  4. import java.util.Map;  
  5. import util.DecimalCalculate;  
  6. /** 
  7.  * 贝叶斯主体类 
  8.  * @author Rowen 
  9.  * @qq 443773264 
  10.  * @mail [email protected] 
  11.  * @blog blog.csdn.net/luowen3405 
  12.  * @data 2011.03.15 
  13.  */  
  14. public class Bayes {  
  15.     /** 
  16.      * 将原训练元组按类别划分 
  17.      * @param datas 训练元组 
  18.      * @return Map<类别,属于该类别的训练元组> 
  19.      */  
  20.     Map<String, ArrayList<ArrayList<String>>> datasOfClass(ArrayList<ArrayList<String>> datas){  
  21.         Map<String, ArrayList<ArrayList<String>>> map = new HashMap<String, ArrayList<ArrayList<String>>>();  
  22.         ArrayList<String> t = null;  
  23.         String c = "";  
  24.         for (int i = 0; i < datas.size(); i++) {  
  25.             t = datas.get(i);  
  26.             c = t.get(t.size() - 1);  
  27.             if (map.containsKey(c)) {  
  28.                 map.get(c).add(t);  
  29.             } else {  
  30.                 ArrayList<ArrayList<String>> nt = new ArrayList<ArrayList<String>>();  
  31.                 nt.add(t);  
  32.                 map.put(c, nt);  
  33.             }  
  34.         }  
  35.         return map;  
  36.     }  
  37.       
  38.     /** 
  39.      * 在训练数据的基础上预测测试元组的类别 
  40.      * @param datas 训练元组 
  41.      * @param testT 测试元组 
  42.      * @return 测试元组的类别 
  43.      */  
  44.     public String predictClass(ArrayList<ArrayList<String>> datas, ArrayList<String> testT) {  
  45.         Map<String, ArrayList<ArrayList<String>>> doc = this.datasOfClass(datas);  
  46.         Object classes[] = doc.keySet().toArray();  
  47.         double maxP = 0.00;  
  48.         int maxPIndex = -1;  
  49.         for (int i = 0; i < doc.size(); i++) {  
  50.             String c = classes[i].toString();   
  51.             ArrayList<ArrayList<String>> d = doc.get(c);  
  52.             double pOfC = DecimalCalculate.div(d.size(), datas.size(), 3);  
  53.             for (int j = 0; j < testT.size(); j++) {  
  54.                 double pv = this.pOfV(d, testT.get(j), j);  
  55.                 pOfC = DecimalCalculate.mul(pOfC, pv);  
  56.             }  
  57.             if(pOfC > maxP){  
  58.                 maxP = pOfC;  
  59.                 maxPIndex = i;  
  60.             }  
  61.         }  
  62.         return classes[maxPIndex].toString();  
  63.     }  
  64.     /** 
  65.      * 计算指定属性列上指定值出现的概率 
  66.      * @param d 属于某一类的训练元组 
  67.      * @param value 列值 
  68.      * @param index 属性列索引 
  69.      * @return 概率 
  70.      */  
  71.     private double pOfV(ArrayList<ArrayList<String>> d, String value, int index) {  
  72.         double p = 0.00;  
  73.         int count = 0;  
  74.         int total = d.size();  
  75.         ArrayList<String> t = null;  
  76.         for (int i = 0; i < total; i++) {  
  77.             if(d.get(i).get(index).equals(value)){  
  78.                 count++;  
  79.             }  
  80.         }  
  81.         p = DecimalCalculate.div(count, total, 3);  
  82.         return p;  
  83.     }  
  84. }  
 

 

算法测试类:

[java]  view plain copy
  1. package Bayes;  
  2. import java.io.BufferedReader;  
  3. import java.io.IOException;  
  4. import java.io.InputStreamReader;  
  5. import java.util.ArrayList;  
  6. import java.util.StringTokenizer;  
  7. /** 
  8.  * 贝叶斯算法测试类 
  9.  * @author Rowen 
  10.  * @qq 443773264 
  11.  * @mail [email protected] 
  12.  * @blog blog.csdn.net/luowen3405 
  13.  * @data 2011.03.15 
  14.  */  
  15. public class TestBayes {  
  16.     /** 
  17.      * 读取测试元组 
  18.      * @return 一条测试元组 
  19.      * @throws IOException 
  20.      */  
  21.     public ArrayList<String> readTestData() throws IOException{  
  22.         ArrayList<String> candAttr = new ArrayList<String>();  
  23.         BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));  
  24.         String str = "";  
  25.         while (!(str = reader.readLine()).equals("")) {  
  26.             StringTokenizer tokenizer = new StringTokenizer(str);  
  27.             while (tokenizer.hasMoreTokens()) {  
  28.                 candAttr.add(tokenizer.nextToken());  
  29.             }  
  30.         }  
  31.         return candAttr;  
  32.     }  
  33.       
  34.     /** 
  35.      * 读取训练元组 
  36.      * @return 训练元组集合 
  37.      * @throws IOException 
  38.      */  
  39.     public ArrayList<ArrayList<String>> readData() throws IOException {  
  40.         ArrayList<ArrayList<String>> datas = new ArrayList<ArrayList<String>>();  
  41.         BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));  
  42.         String str = "";  
  43.         while (!(str = reader.readLine()).equals("")) {  
  44.             StringTokenizer tokenizer = new StringTokenizer(str);  
  45.             ArrayList<String> s = new ArrayList<String>();  
  46.             while (tokenizer.hasMoreTokens()) {  
  47.                 s.add(tokenizer.nextToken());  
  48.             }  
  49.             datas.add(s);  
  50.         }  
  51.         return datas;  
  52.     }  
  53.       
  54.     public static void main(String[] args) {  
  55.         TestBayes tb = new TestBayes();  
  56.         ArrayList<ArrayList<String>> datas = null;  
  57.         ArrayList<String> testT = null;  
  58.         Bayes bayes = new Bayes();  
  59.         try {  
  60.             System.out.println("请输入训练数据");  
  61.             datas = tb.readData();  
  62.             while (true) {  
  63.                 System.out.println("请输入测试元组");  
  64.                 testT = tb.readTestData();  
  65.                 String c = bayes.predictClass(datas, testT);  
  66.                 System.out.println("The class is: " + c);  
  67.             }  
  68.         } catch (IOException e) {  
  69.             e.printStackTrace();  
  70.         }  
  71.     }  
  72. }  
 

 

训练数据:

[java]  view plain copy
  1. youth high no fair no  
  2. youth high no excellent no  
  3. middle_aged high no fair yes  
  4. senior medium no fair yes  
  5. senior low yes fair yes  
  6. senior low yes excellent no  
  7. middle_aged low yes excellent yes  
  8. youth medium no fair no  
  9. youth low yes fair yes  
  10. senior medium yes fair yes  
  11. youth medium yes excellent yes  
  12. middle_aged medium no excellent yes  
  13. middle_aged high yes fair yes  
  14. senior medium no excellent no  
 

 

对原训练数据进行测试,测试如果如下:

[c-sharp]  view plain copy
  1. 请输入测试元组  
  2. youth high no fair  
  3. The class is: no  
  4. 请输入测试元组  
  5. youth high no excellent  
  6. The class is: no  
  7. 请输入测试元组  
  8. middle_aged high no fair  
  9. The class is: yes  
  10. 请输入测试元组  
  11. senior medium no fair  
  12. The class is: yes  
  13. 请输入测试元组  
  14. senior low yes fair  
  15. The class is: yes  
  16. 请输入测试元组  
  17. senior low yes excellent  
  18. The class is: yes  
  19. 请输入测试元组  
  20. middle_aged low yes excellent  
  21. The class is: yes  
  22. 请输入测试元组  
  23. youth medium no fair  
  24. The class is: no  
  25. 请输入测试元组  
  26. youth low yes fair  
  27. The class is: yes  
  28. 请输入测试元组  
  29. senior medium yes fair  
  30. The class is: yes  
  31. 请输入测试元组  
  32. youth medium yes excellent  
  33. The class is: yes  
  34. 请输入测试元组  
  35. middle_aged medium no excellent  
  36. The class is: yes  
  37. 请输入测试元组  
  38. middle_aged high yes fair  
  39. The class is: yes  
  40. 请输入测试元组  
  41. senior medium no excellent  
  42. The class is: no  
 

 

测试结果显示14个测试实例中有13个分类是正确的,正确率为93%,说明算法能够给出一个准确的预测与分类,但是算法还需改进以提高正确率。

 

改进的可选方法之一:

为避免单个属性值对分类结果的权重过大,例如当某属性值在某一类中出现0次时,该属性值就决定了测试实例已经不可能属于该类了,这就可能会造成误差,因此在计算概率时可能进行如下改进:

 

将原先的P(Xk|Ci)=|Xk| / |Ci| 改为 P(Xk|Ci)=(|Xk|+mp) / (|Ci|+m),其中m可设定为训练元组的个数,p为等可能假设的先验概率。

你可能感兴趣的:(java,算法,数据挖掘,贝叶斯,Bayes)