数据挖掘—朴素贝叶斯分类算法(Java实现)

算法描述

(1)扫描训练样本数据集,分别统计训练集中属于类别 Ci 的样本个数 Di,以及属于类别 Ci 的样本中属性 Ak 取值为 Xk 的样本个数 Dik,构成统计表;
(2)计算先验概率和条件概率,构成概率表;
(3)构建分类模型;
(4)扫描待分类的样本数据集,调用已得到的统计表、概率表以及构建好的分类准则,得出分类结果;

代码

/**
 * A minimal naive Bayes classifier over rows of string-encoded attributes.
 *
 * <p>Every training row is a list whose last element is the class label and whose
 * preceding elements are the attribute values. Attribute values are compared
 * numerically when both sides parse as doubles (so {@code "1"} matches {@code "1.0"}),
 * and textually otherwise, so categorical attributes are supported as well.
 */
public class Bayes {

    /**
     * Groups the training rows by class label (the last element of each row).
     *
     * <p>Rows that are {@code null}, empty, or whose label is {@code null} or the empty
     * string are skipped.
     *
     * @param datas training rows; may be {@code null}, which yields an empty map
     * @return a map from class label to the rows carrying that label
     */
    Map<String, ArrayList<ArrayList<String>>> datasOfClass(ArrayList<ArrayList<String>> datas) {
        Map<String, ArrayList<ArrayList<String>>> groups =
                new HashMap<String, ArrayList<ArrayList<String>>>();
        if (datas == null) {
            return groups;
        }
        for (ArrayList<String> row : datas) {
            if (row == null || row.isEmpty()) {
                continue;
            }
            String label = row.get(row.size() - 1);
            if (label == null || label.length() == 0) {
                continue;
            }
            ArrayList<ArrayList<String>> bucket = groups.get(label);
            if (bucket == null) {
                bucket = new ArrayList<ArrayList<String>>();
                groups.put(label, bucket);
            }
            bucket.add(row);
        }
        return groups;
    }

    /**
     * Predicts the class of a test tuple from the given training rows.
     *
     * <p>For every class the score is {@code log P(class) + sum_j log P(attr_j | class)}.
     * Working with logarithms instead of multiplying raw probabilities keeps the score
     * finite no matter how many attributes there are; a plain product underflows to zero
     * for long tuples, after which no class can be selected.
     *
     * <p>An attribute value that never occurs within a class gets the fallback probability
     * {@code 1 / (rows in that class)} rather than zero.
     *
     * @param datas training rows, each ending with its class label
     * @param testT attribute values of the tuple to classify (without a label)
     * @return the label of the highest-scoring class; on ties the class encountered first wins
     * @throws IllegalArgumentException if {@code datas} contains no labelled rows
     */
    public String predictClass(ArrayList<ArrayList<String>> datas, ArrayList<String> testT) {
        Map<String, ArrayList<ArrayList<String>>> groups = this.datasOfClass(datas);
        if (groups.isEmpty()) {
            throw new IllegalArgumentException("training data contains no labelled rows");
        }

        // Priors are taken over the rows that actually carry a label.
        int labelled = 0;
        for (ArrayList<ArrayList<String>> rows : groups.values()) {
            labelled += rows.size();
        }

        String bestLabel = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (Map.Entry<String, ArrayList<ArrayList<String>>> entry : groups.entrySet()) {
            ArrayList<ArrayList<String>> rows = entry.getValue();
            double score = Math.log((double) rows.size() / labelled);
            for (int j = 0; j < testT.size(); j++) {
                double pv = this.pOfV(rows, testT.get(j), j);
                if (pv == 0) {
                    // Unseen value: use a small non-zero probability so the class is not ruled out.
                    pv = 1.0 / rows.size();
                }
                score += Math.log(pv);
            }
            if (bestLabel == null || score > bestScore) {
                bestScore = score;
                bestLabel = entry.getKey();
            }
        }
        return bestLabel;
    }

    /**
     * Returns the relative frequency of {@code value} at column {@code index} within {@code d}.
     *
     * @param d     rows of a single class
     * @param value attribute value to look for
     * @param index column to inspect in every row
     * @return matching rows divided by the number of rows in {@code d}
     */
    private double pOfV(ArrayList<ArrayList<String>> d, String value, int index) {
        // Parse the probe once instead of once per row.
        Double target = toNumber(value);
        int count = 0;
        for (ArrayList<String> row : d) {
            if (sameValue(row.get(index), value, target)) {
                count++;
            }
        }
        return (double) count / d.size();
    }

    /** Tells whether a training cell equals the probe, textually or (if both are numeric) numerically. */
    private static boolean sameValue(String cell, String value, Double target) {
        if (value == null ? cell == null : value.equals(cell)) {
            return true;
        }
        if (target == null) {
            return false;
        }
        Double parsed = toNumber(cell);
        return parsed != null && parsed.doubleValue() == target.doubleValue();
    }

    /** Parses {@code s} as a double, returning {@code null} when it is null or not numeric. */
    private static Double toNumber(String s) {
        if (s == null) {
            return null;
        }
        try {
            return Double.valueOf(s);
        } catch (NumberFormatException notNumeric) {
            return null;
        }
    }
}


/**
 * Command-line driver that trains {@link Bayes} on {@code src/bp/trainBayes.txt},
 * classifies every line of {@code src/bp/testBayes.txt} and prints the accuracy.
 *
 * <p>Both files hold one comma-separated row per line; the last field of a row is
 * its class label.
 */
public class TestBayes {

    /**
     * Expected class label of the line most recently passed to
     * {@link #readTestData(String)}. It is overwritten on every call.
     */
    static String res;

    /**
     * Splits one comma-separated test line into its attribute values and stores the
     * trailing field (the expected label) in {@link #res}.
     *
     * @param string a comma-separated line whose last field is the class label
     * @return every field except the last one, in order
     * @throws IOException never thrown by this implementation; kept so existing callers
     *                     that catch it continue to compile
     */
    public ArrayList<String> readTestData(String string) throws IOException {
        String[] fields = string.split(",");
        ArrayList<String> candAttr = new ArrayList<String>();
        for (int i = 0; i < fields.length - 1; i++) {
            candAttr.add(fields[i]);
        }
        res = fields[fields.length - 1];
        return candAttr;
    }

    /**
     * Reads the training set from {@code src/bp/trainBayes.txt}.
     *
     * <p>Every line, including blank ones, becomes one row obtained by splitting on commas.
     *
     * @return the training rows in file order
     * @throws IOException if the file cannot be opened or read; the failure is reported to
     *                     the caller instead of being hidden behind an empty training set
     */
    public ArrayList<ArrayList<String>> readData() throws IOException {
        ArrayList<ArrayList<String>> list = new ArrayList<>();
        // try-with-resources closes the reader even when reading fails.
        try (BufferedReader br = open("src/bp/trainBayes.txt")) {
            String line;
            while ((line = br.readLine()) != null) {
                ArrayList<String> row = new ArrayList<>(Arrays.asList(line.split(",")));
                list.add(row);
            }
        }
        return list;
    }

    /**
     * Reads the non-empty lines of a text file.
     *
     * <p>This is best-effort: if the file cannot be opened or a read fails, the stack trace
     * is printed and the lines collected so far are returned.
     *
     * @param fileName path of the file to read
     * @return the non-empty lines in file order
     */
    public static List<String> readTxt(String fileName) {
        List<String> list = new ArrayList<>();
        try (BufferedReader br = open(fileName)) {
            String line;
            while ((line = br.readLine()) != null) {
                if (line.length() > 0) {
                    list.add(line);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return list;
    }

    /** Opens {@code path} for buffered line-by-line reading. */
    private static BufferedReader open(String path) throws IOException {
        return new BufferedReader(new InputStreamReader(new FileInputStream(new File(path))));
    }

    /**
     * Loads the training and test sets, classifies every test line and prints the number
     * of test lines, the number classified correctly, and the resulting accuracy.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        TestBayes tb = new TestBayes();
        Bayes bayes = new Bayes();
        int right = 0;
        try {
            ArrayList<ArrayList<String>> datas = tb.readData();
            List<String> lines = readTxt("src/bp/testBayes.txt");
            for (String line : lines) {
                ArrayList<String> testT = tb.readTestData(line);
                String predicted = bayes.predictClass(datas, testT);
                // readTestData has just stored this line's expected label in res.
                if (predicted.equals(res)) {
                    right++;
                }
            }
            double re = (double) right / (double) lines.size();
            System.out.println("测试集的数量:" + lines.size());
            System.out.println("分类正确的数量:" + right);
            System.out.println("算法的分类正确率为:" + re);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

你可能感兴趣的:(数据挖掘,算法,机器学习,java,大数据)