基于朴素贝叶斯的兴趣分类

寒假期间使用了朴素贝叶斯算法对用户画像中的用户兴趣进行了分类,然而最终预测的准确率以及召回率却不尽如人意,这里就谈谈朴素贝叶斯算法的使用以及我对此次失败的一个小小反思吧~

关于分类

分类是将一个未知样本分到几个预先已知类的过程。在众多的分类模型中,应用最为广泛的两种分类模型是决策树模型(Decision Tree Model)和朴素贝叶斯模型(Naive Bayesian Model,NBC)。

分类模型 优点 缺点
决策树 根据决策树可以很容易地构造出规则,而规则通常易于解释和理解;决策树可很好地扩展到大型数据库中,同时它的大小独立于数据库的大小;决策树模型的另外一大优点就是可以对有许多属性的数据集构造决策树。 处理缺失数据时的困难,过度拟合问题的出现,以及忽略数据集中属性之间的相关性等。
朴素贝叶斯 有稳定的分类效率。对小规模的数据表现很好,能个处理多分类任务,适合增量式训练,尤其是数据量超出内存时,我们可以一批批的去增量训练。对缺失数据不太敏感,算法也比较简单,常用于文本分类。 需要知道先验概率,且先验概率很多时候取决于假设,假设的模型可以有很多种,因此在某些时候会由于假设的先验模型的原因导致预测效果不佳。

关于朴素贝叶斯

朴素贝叶斯法是基于贝叶斯定理与特征条件独立假设的分类方法。朴素贝叶斯分类器模型会给问题实例分配用特征值表示的类标签,类标签取自有限集合。朴素贝叶斯分类器基于一个简单的假定:给定目标值时属性之间相互条件独立。举个例子,如果一种水果其具有红,圆,直径大概3英寸等特征,该水果可以被判定为是苹果。尽管这些特征相互依赖或者有些特征由其他特征决定,然而朴素贝叶斯分类器认为这些属性在判定该水果是否为苹果的概率分布上独立的。
关于朴素贝叶斯的原理网上资料很多,具体内容可以参考此博客:http://www.cnblogs.com/leoo2sk/archive/2010/09/17/naive-bayesian-classifier.html

项目思路

1.确定分类

我们是按照微博的兴趣分类将兴趣分为24个类别,并赋予特征词大约每个类别500个,将每个类别放置在一个txt文件中
基于朴素贝叶斯的兴趣分类_第1张图片

2.获取标签特征

//获取标签特征词(!!发现隐藏文件,由于使用过gedit打开过),形成一个vocabulary
        List<String > vocabulary = new ArrayList<String>();
        File dir = new File("/home/hadoop/项目内容/类别库");
        File[] files = dir.listFiles();   //获取不同类别的标签文件
        System.out.println(files.length);
        StringBuilder sb = new StringBuilder();
        for(File file : files){
            BufferedReader br = new BufferedReader(new FileReader(file));
            String line = null;
            while((line =  br.readLine()) != null){
                sb.append(line);  //按"`"分割不同类别的标签
            }
            sb.append(line + "`");
        }
        String[] tags = sb.toString().trim().split("`");
        List<String> newTags = new ArrayList<String>();
        for(String tag: tags){
            if(tag.length() > 4){
                newTags.add(tag); //去除空行标签
            }
        }
        Object[] newtags = newTags.toArray();
        ListString>> list = new ArrayListString>>();  //记录每类中的标签
        for(int i = 0; iString> classWithTags = new Tuple2String>(i,(String)newtags[i]);
            System.out.println(classWithTags);
            list.add(classWithTags);
            String[] tokens = ((String)newtags[i]).split("/");
            for(String tag: tokens){
                vocabulary.add(tag);
            }
        }

3.获取训练样本

此处的训练样本则为各个分类的特征词~

//获取训练样本
        JavaPairRDD<Integer,String> trainRDD = sc.parallelizePairs(list);   //将每类的标签词转化为RDD
        JavaPairRDD<Integer,String> trainSetRDD = trainRDD.mapValues(new ToTrainSet(vocabulary));  //将标签词转化为向量模型
        List<Tuple2<Integer,String>> trainSet = trainSetRDD.collect();
        writeTrainSet(trainSet);  //写成libsvm文件格式,以方便训练
        System.out.println("trainset is ok");

4.训练模型

//读取训练集并训练模型
        String path = "./trainset";
        JavaRDD trainData = MLUtils.loadLibSVMFile(sc.sc(),path).toJavaRDD();
        model = NaiveBayes.train(trainData.rdd(),1.0);
        System.out.println("model is ok");

5.预测测试集

测试集是挑选每个用户关注的大V及其简介, 利用HanLP进行切词

//读取txt文件并利用hanlp切词形成一个list,作为测试样本
        String filepath = "/home/hadoop/result/19岁金鱼想当歌手.txt";
        File inputfile = new File(filepath);
        rewriteFile(filepath);
        String content = readTxtFile(inputfile); //获取txt的内容
        System.out.print(content);
        Segment segment = HanLP.newSegment();
        List termList = segment.seg(content); //用Hanlp进行切词形成List
        for(Term term : termList){
            testStr += term.word + " ";  //使用term.word去掉属性部分
        }
//预测测试集
        double[] testArray = sentenceToArrays(vocabulary,testStr);
        writeTestSet(testArray);
        String testPath = "./testset";
        JavaRDD testData = MLUtils.loadLibSVMFile(sc.sc(), testPath).toJavaRDD();

6.选取概率值前三的兴趣存档

即每个用户确定3个兴趣

预测效果

对每个用户的兴趣进行分类后,为了验证算法的准确率以及召回率,我是随机选择了500个用户进行人工标注其兴趣,兴趣的个数不限,后利用公式

代码实现如下:

Map<String,String> human = new HashMap<String,String>();
        Map<String,String> predict = new HashMap<String,String>();
        String humanpath = "/home/hadoop/Seeing项目内容/result/人工标注/人工标注用户";
        String predictpath = "/home/hadoop/Seeing项目内容/result/人工标注/新机器";
        //将人工标注的放于human中
        BufferedReader human_br = new BufferedReader(new FileReader(new File(humanpath)));
        String human_interest = "";
        while((human_interest = human_br.readLine()) != null){
            String[] content = human_interest.split("   ");
            human.put(content[0], content[1]);      
            }
        //将机器标注的放于predict中
        BufferedReader predict_br = new BufferedReader(new FileReader(new File(predictpath)));
        String predict_interest = "";
        while((predict_interest = predict_br.readLine()) != null){
            String[] content1 = predict_interest.split("    ");
            predict.put(content1[0], content1[1]);
        }
        //计算准确率和召回率
        //先转成set
        Set human_set = human.entrySet();
        Set predict_set = predict.entrySet();
        int i  = 0;
        double accuracy = 0;
        double recall = 0;
        for(Iterator iter1 = predict_set.iterator(); iter1.hasNext();){
            String[] predictinterests = {};
            String[] humaninterests = {};
            //交集个数
            double  result_insect =0;
            Map.Entry<String , String> entry1 = (Map.Entry<String, String>) iter1.next();
            String predictkey = entry1.getKey();
            String predictvalue = entry1.getValue();
            predictinterests = predictvalue.split(",");
            for(Iterator iter2 = human_set.iterator();iter2.hasNext();){
                Map.Entry<String, String> entry2 = (Map.Entry<String, String>) iter2.next();
                String humankey = entry2.getKey();
                String humanvalue = entry2.getValue();
//              System.out.println(humanvalue);
                humaninterests = humanvalue.split(",");
                if(humankey.equals(predictkey)){
                    //求交集
                    i++;
                    result_insect = intersect(predictinterests, humaninterests);
//                  System.out.println(result_insect);
                    double a = predictinterests.length;
                    double b = humaninterests.length;
                    accuracy = (result_insect/predictinterests.length + accuracy*(i-1))/i;
                    recall =( result_insect/humaninterests.length + recall *(i-1))/i;
//                  System.out.println(i+ " "+ result_insect + " "+ a + " "+ b);
                }

结果如图:

反思

准确率和召回率都惊人的低,个人感觉原因应该有以下几点:
1.人工标注会有些疏忽,另外就是微博兴趣的特征词所构建的词典会随时间的问题有一些特征词不再适用;
2.朴素贝叶斯本身的缺点导致,由于我们是通过先验和数据来决定后验的概率从而决定分类,所以分类决策存在一定的错误率。
3.朴素贝叶斯模型假设属性之间相互独立,这个假设在实际应用中往往是不成立的,在属性个数比较多或者属性之间相关性较大时,分类效果不好。而在属性相关性较小时,朴素贝叶斯性能最为良好。对于这一点,有半朴素贝叶斯之类的算法通过考虑部分关联性适度改进。

接下来会继续改进兴趣分类的实现,各位大佬走过路过也可以提提建议~谢谢O(∩_∩)O

你可能感兴趣的:(机器学习,算法)