Note: this implementation is intended only for experiments and tests on small datasets. It is not suitable for production use.
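The algorithm assumes that every attribute column of the training data takes discrete values. Non-discrete (continuous) data must first be preprocessed, that is, discretized.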
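The preprocessing step could look like the minimal equal-width binning sketch below. Equal-width binning is only one possible discretization method, and the post does not say which one to use. The class name EqualWidthBinning, the bin labels bin0, bin1, ... and the sample ages are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: equal-width discretization of one numeric attribute column. */
public class EqualWidthBinning {

    /**
     * Maps each value to a label "bin0", "bin1", ... by splitting [min, max]
     * into `bins` intervals of equal width. The maximum value goes into the last bin.
     */
    public static List<String> discretize(double[] values, int bins) {
        if (bins <= 0) {
            throw new IllegalArgumentException("bins must be positive");
        }
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double v : values) {
            min = Math.min(min, v);
            max = Math.max(max, v);
        }
        double width = (max - min) / bins;
        List<String> labels = new ArrayList<>();
        for (double v : values) {
            // width == 0 means every value is equal: put them all in bin 0
            int b = (width == 0) ? 0 : (int) ((v - min) / width);
            if (b >= bins) {
                b = bins - 1; // v == max lies exactly on the upper edge
            }
            labels.add("bin" + b);
        }
        return labels;
    }

    public static void main(String[] args) {
        double[] ages = {22, 25, 35, 41, 58, 63};
        System.out.println(discretize(ages, 3));
    }
}
```

Running main prints [bin0, bin0, bin0, bin1, bin2, bin2]. The resulting labels can then be used as the discrete attribute values that the classifier expects.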
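The algorithm uses the DecimalCalculate class, an extension of Java's BigDecimal class for high-precision floating-point arithmetic. Its implementation is identical to the Arith class in a post I reposted earlier, "A Classification of Common BigDecimal Methods".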
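DecimalCalculate itself is not listed in this post. For readers who want to compile the code, below is a minimal sketch of the two methods Bayes calls. It assumes they follow the usual BigDecimal-based Arith pattern. The signatures are inferred from the calls div(a, b, scale) and mul(a, b), and the HALF_UP rounding mode is an assumption. In the project the class would live in package util, so a `package util;` line would be added there.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;

/**
 * Hypothetical stand-in for util.DecimalCalculate (the real class is not shown).
 * Signatures are inferred from how Bayes calls it: div(a, b, scale) and mul(a, b).
 */
public class DecimalCalculate {

    private DecimalCalculate() {}

    /** Product computed with BigDecimal on the decimal string forms, converted back to double. */
    public static double mul(double v1, double v2) {
        BigDecimal b1 = new BigDecimal(Double.toString(v1));
        BigDecimal b2 = new BigDecimal(Double.toString(v2));
        return b1.multiply(b2).doubleValue();
    }

    /** v1 / v2 rounded half-up to `scale` digits after the decimal point. */
    public static double div(double v1, double v2, int scale) {
        if (scale < 0) {
            throw new IllegalArgumentException("scale must be non-negative");
        }
        BigDecimal b1 = new BigDecimal(Double.toString(v1));
        BigDecimal b2 = new BigDecimal(Double.toString(v2));
        return b1.divide(b2, scale, RoundingMode.HALF_UP).doubleValue();
    }
}
```

With this version, mul(0.1, 0.2) returns 0.02 instead of the 0.020000000000000004 that plain double multiplication gives. div(9, 14, 3) returns 0.643, which is the rounded prior that Bayes uses for a class with 9 of 14 tuples.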
The implementation is as follows:
```java
package Bayes;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

import util.DecimalCalculate;

/**
 * Main class of the naive Bayes classifier.
 * @author Rowen
 * @qq 443773264
 * @blog blog.csdn.net/luowen3405
 * @date 2011.03.15
 */
public class Bayes {

    /**
     * Partition the training tuples by class.
     * @param datas training tuples (the last column is the class label)
     * @return Map<class, training tuples belonging to that class>
     */
    Map<String, ArrayList<ArrayList<String>>> datasOfClass(ArrayList<ArrayList<String>> datas) {
        Map<String, ArrayList<ArrayList<String>>> map = new HashMap<String, ArrayList<ArrayList<String>>>();
        ArrayList<String> t = null;
        String c = "";
        for (int i = 0; i < datas.size(); i++) {
            t = datas.get(i);
            c = t.get(t.size() - 1);
            if (map.containsKey(c)) {
                map.get(c).add(t);
            } else {
                ArrayList<ArrayList<String>> nt = new ArrayList<ArrayList<String>>();
                nt.add(t);
                map.put(c, nt);
            }
        }
        return map;
    }

    /**
     * Predict the class of a test tuple from the training data.
     * @param datas training tuples
     * @param testT test tuple (attribute values only, no class label)
     * @return the predicted class of the test tuple
     */
    public String predictClass(ArrayList<ArrayList<String>> datas, ArrayList<String> testT) {
        Map<String, ArrayList<ArrayList<String>>> doc = this.datasOfClass(datas);
        Object classes[] = doc.keySet().toArray();
        // Start below 0 so that some class is still chosen even if every product is 0
        // (e.g. an attribute value never seen in training); otherwise maxPIndex could stay -1.
        double maxP = -1.0;
        int maxPIndex = -1;
        for (int i = 0; i < doc.size(); i++) {
            String c = classes[i].toString();
            ArrayList<ArrayList<String>> d = doc.get(c);
            // Prior P(Ci) = |Ci| / |D|, rounded to 3 decimal places
            double pOfC = DecimalCalculate.div(d.size(), datas.size(), 3);
            // Multiply by P(xk|Ci) for each attribute value of the test tuple
            for (int j = 0; j < testT.size(); j++) {
                double pv = this.pOfV(d, testT.get(j), j);
                pOfC = DecimalCalculate.mul(pOfC, pv);
            }
            if (pOfC > maxP) {
                maxP = pOfC;
                maxPIndex = i;
            }
        }
        return classes[maxPIndex].toString();
    }

    /**
     * Compute the probability that a given value occurs in a given attribute column.
     * @param d training tuples belonging to one class
     * @param value attribute value
     * @param index attribute column index
     * @return probability |Xk| / |Ci|, rounded to 3 decimal places
     */
    private double pOfV(ArrayList<ArrayList<String>> d, String value, int index) {
        int count = 0;
        int total = d.size();
        for (int i = 0; i < total; i++) {
            if (d.get(i).get(index).equals(value)) {
                count++;
            }
        }
        return DecimalCalculate.div(count, total, 3);
    }
}
```
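A note on precision: div rounds every estimate to 3 decimal places, so a ratio below 0.0005 even becomes 0, and a long product of small factors can underflow. A common alternative, which is not the author's code, is to compare sums of logarithms using plain doubles. The sketch below uses the same counting estimates |Ci| / |D| and |Xk| / |Ci| as the Bayes class. The class and method names are hypothetical, and sampleTrainingData holds the 14 training tuples listed in this post.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical log-space variant of Bayes.predictClass using plain doubles. */
public class LogSpaceBayes {

    /** Returns the class maximizing log P(Ci) + sum over k of log P(xk|Ci). */
    public static String predict(List<List<String>> train, List<String> test) {
        // Group training tuples by class label (last column), like datasOfClass.
        Map<String, List<List<String>>> byClass = new HashMap<>();
        for (List<String> t : train) {
            byClass.computeIfAbsent(t.get(t.size() - 1), k -> new ArrayList<>()).add(t);
        }
        String best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (Map.Entry<String, List<List<String>>> e : byClass.entrySet()) {
            List<List<String>> d = e.getValue();
            double score = Math.log((double) d.size() / train.size());
            for (int j = 0; j < test.size(); j++) {
                int count = 0;
                for (List<String> t : d) {
                    if (t.get(j).equals(test.get(j))) count++;
                }
                // A zero count gives log(0) = -Infinity, ruling the class out as before.
                score += Math.log((double) count / d.size());
            }
            // best == null picks the first class even if every score is -Infinity.
            if (best == null || score > bestScore) {
                best = e.getKey();
                bestScore = score;
            }
        }
        return best;
    }

    /** The 14 training tuples from this post; the last column is the class label. */
    public static List<List<String>> sampleTrainingData() {
        String[] rows = {
            "youth high no fair no", "youth high no excellent no",
            "middle_aged high no fair yes", "senior medium no fair yes",
            "senior low yes fair yes", "senior low yes excellent no",
            "middle_aged low yes excellent yes", "youth medium no fair no",
            "youth low yes fair yes", "senior medium yes fair yes",
            "youth medium yes excellent yes", "middle_aged medium no excellent yes",
            "middle_aged high yes fair yes", "senior medium no excellent no"
        };
        List<List<String>> data = new ArrayList<>();
        for (String r : rows) {
            data.add(Arrays.asList(r.split("\\s+")));
        }
        return data;
    }

    public static void main(String[] args) {
        List<List<String>> train = sampleTrainingData();
        int correct = 0;
        for (List<String> t : train) {
            String predicted = predict(train, t.subList(0, t.size() - 1));
            if (predicted.equals(t.get(t.size() - 1))) correct++;
        }
        System.out.println(correct + " of " + train.size() + " training tuples classified correctly");
    }
}
```

The log is monotonic, so this variant ranks the classes in the same order as the exact product. On this data it prints "13 of 14 training tuples classified correctly" and also misclassifies senior low yes excellent. The change improves numerical robustness without changing the decision rule.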
Test class for the algorithm:
```java
package Bayes;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.StringTokenizer;

/**
 * Test class for the Bayes algorithm.
 * @author Rowen
 * @qq 443773264
 * @blog blog.csdn.net/luowen3405
 * @date 2011.03.15
 */
public class TestBayes {

    // One shared reader: creating a new BufferedReader on System.in for every call
    // can lose input that an earlier reader already buffered (e.g. piped input).
    private final BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));

    /**
     * Read one test tuple (attribute values only), ending at a blank line or end of input.
     * @return one test tuple
     * @throws IOException
     */
    public ArrayList<String> readTestData() throws IOException {
        ArrayList<String> candAttr = new ArrayList<String>();
        String str;
        // readLine() returns null at end of input; stop there instead of throwing
        while ((str = reader.readLine()) != null && !str.trim().isEmpty()) {
            StringTokenizer tokenizer = new StringTokenizer(str);
            while (tokenizer.hasMoreTokens()) {
                candAttr.add(tokenizer.nextToken());
            }
        }
        return candAttr;
    }

    /**
     * Read the training tuples, one per line, until a blank line or end of input.
     * @return the set of training tuples
     * @throws IOException
     */
    public ArrayList<ArrayList<String>> readData() throws IOException {
        ArrayList<ArrayList<String>> datas = new ArrayList<ArrayList<String>>();
        String str;
        while ((str = reader.readLine()) != null && !str.trim().isEmpty()) {
            StringTokenizer tokenizer = new StringTokenizer(str);
            ArrayList<String> s = new ArrayList<String>();
            while (tokenizer.hasMoreTokens()) {
                s.add(tokenizer.nextToken());
            }
            datas.add(s);
        }
        return datas;
    }

    public static void main(String[] args) {
        TestBayes tb = new TestBayes();
        Bayes bayes = new Bayes();
        try {
            System.out.println("Please enter the training data");
            ArrayList<ArrayList<String>> datas = tb.readData();
            while (true) {
                System.out.println("Please enter a test tuple");
                ArrayList<String> testT = tb.readTestData();
                if (testT.isEmpty()) {
                    break; // no more test tuples
                }
                String c = bayes.predictClass(datas, testT);
                System.out.println("The class is: " + c);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
```
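Training data (one tuple per line, with values separated by whitespace and the class label in the last column):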
```
youth high no fair no
youth high no excellent no
middle_aged high no fair yes
senior medium no fair yes
senior low yes fair yes
senior low yes excellent no
middle_aged low yes excellent yes
youth medium no fair no
youth low yes fair yes
senior medium yes fair yes
youth medium yes excellent yes
middle_aged medium no excellent yes
middle_aged high yes fair yes
senior medium no excellent no
```
Testing on the original training data gives the following results:
```
Please enter a test tuple
youth high no fair
The class is: no
Please enter a test tuple
youth high no excellent
The class is: no
Please enter a test tuple
middle_aged high no fair
The class is: yes
Please enter a test tuple
senior medium no fair
The class is: yes
Please enter a test tuple
senior low yes fair
The class is: yes
Please enter a test tuple
senior low yes excellent
The class is: yes
Please enter a test tuple
middle_aged low yes excellent
The class is: yes
Please enter a test tuple
youth medium no fair
The class is: no
Please enter a test tuple
youth low yes fair
The class is: yes
Please enter a test tuple
senior medium yes fair
The class is: yes
Please enter a test tuple
youth medium yes excellent
The class is: yes
Please enter a test tuple
middle_aged medium no excellent
The class is: yes
Please enter a test tuple
middle_aged high yes fair
The class is: yes
Please enter a test tuple
senior medium no excellent
The class is: no
```
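The results show that 13 of the 14 test instances are classified correctly, an accuracy of about 93% (13/14 ≈ 92.9%). The only error is "senior low yes excellent", which is labeled no in the training data but is predicted as yes. The test tuples are the training tuples themselves, so this is accuracy on the training set. The algorithm predicts reasonably well here, but it still needs improvement to raise its accuracy.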
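Working through the misclassified tuple by hand shows where the error comes from. The counts below are taken directly from the training data. 9 of the 14 tuples are yes and 5 are no. Among the yes tuples, senior, low, yes (third attribute) and excellent occur 3, 3, 6 and 3 times. Among the no tuples they occur 2, 1, 1 and 3 times. Using exact fractions, with no 3-decimal rounding:

$$
\begin{aligned}
P(\text{yes})\prod_{k} P(x_k \mid \text{yes}) &= \frac{9}{14}\cdot\frac{3}{9}\cdot\frac{3}{9}\cdot\frac{6}{9}\cdot\frac{3}{9} = \frac{1}{63} \approx 0.0159\\[4pt]
P(\text{no})\prod_{k} P(x_k \mid \text{no}) &= \frac{5}{14}\cdot\frac{2}{5}\cdot\frac{1}{5}\cdot\frac{1}{5}\cdot\frac{3}{5} = \frac{3}{875} \approx 0.0034
\end{aligned}
$$

The yes score is about 4.6 times the no score, so the tuple is labeled yes even though its training label is no. Within class no, the values low and yes (third attribute) each appear in only 1 of the 5 tuples, and this pulls the no score down.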
One possible improvement:
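A single attribute value should not carry too much weight in the classification. For example, if an attribute value occurs 0 times in some class, then P(Xk|Ci) = 0 and the whole product for that class becomes 0. That one value alone then rules out the test instance belonging to the class, which can cause errors. The probability estimate can therefore be modified as follows: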
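Replace the original P(Xk|Ci) = |Xk| / |Ci| with P(Xk|Ci) = (|Xk| + m·p) / (|Ci| + m). Here |Ci| is the number of training tuples in class Ci, and |Xk| is the number of those tuples whose k-th attribute has the test tuple's value. m can be set to the number of training tuples, and p is the prior probability under an equal-likelihood assumption, i.e. 1 divided by the number of distinct values of attribute k.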
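A minimal sketch of the modified estimate follows, written as a variant of pOfV that uses plain doubles. The names MEstimate, pOfVSmoothed and uniformPrior are hypothetical. Taking p as 1 / (number of distinct values of attribute k) is one reading of the equal-likelihood prior.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/** Hypothetical smoothed replacement for pOfV: (|Xk| + m*p) / (|Ci| + m). */
public class MEstimate {

    /** Uniform prior p for attribute `index`: 1 / (number of distinct values seen). */
    public static double uniformPrior(List<List<String>> datas, int index) {
        Set<String> distinct = new HashSet<>();
        for (List<String> t : datas) {
            distinct.add(t.get(index));
        }
        return 1.0 / distinct.size();
    }

    /** Smoothed P(Xk|Ci); d holds the training tuples of class Ci. */
    public static double pOfVSmoothed(List<List<String>> d, String value, int index,
                                      double m, double p) {
        int count = 0;
        for (List<String> t : d) {
            if (t.get(index).equals(value)) {
                count++;
            }
        }
        return (count + m * p) / (d.size() + m);
    }

    public static void main(String[] args) {
        // The five "no" tuples from the training data in this post.
        List<List<String>> no = Arrays.asList(
            Arrays.asList("youth", "high", "no", "fair", "no"),
            Arrays.asList("youth", "high", "no", "excellent", "no"),
            Arrays.asList("senior", "low", "yes", "excellent", "no"),
            Arrays.asList("youth", "medium", "no", "fair", "no"),
            Arrays.asList("senior", "medium", "no", "excellent", "no"));
        double m = 14;      // m = number of training tuples, as suggested above
        double p = 1.0 / 3; // age and income each take 3 distinct values in the full data
        System.out.println(pOfVSmoothed(no, "middle_aged", 0, m, p)); // raw estimate: 0/5
        System.out.println(pOfVSmoothed(no, "low", 1, m, p));         // raw estimate: 1/5
    }
}
```

With m = 14 and p = 1/3, the zero count for middle_aged in class no becomes 14/57 ≈ 0.246 instead of 0, and low in class no becomes 17/57 ≈ 0.298 instead of 1/5. p is computed from the full training data here because a single class may not contain every value (the no tuples have no middle_aged).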