前面的文章已经介绍了朴素贝叶斯算法的原理,这里基于朴素贝叶斯(Naive Bayes)算法对newsgroup文本进行分类测试。
文中代码参考:http://blog.csdn.net/jiangliqing1234/article/details/39642757
主要内容如下:
数据下载地址:http://download.csdn.net/detail/hjy321686/8057761。 文本中包含20个不同的新闻组,除其中少数文本属于多个新闻组以外,其余的文档都只属于一个新闻组。
要对文本进行分类,首先要对其进行预处理,预处理主要过程如下:
step1:英文词法分析,去除数字、连字符、标点符号、特殊字符,并将所有大写字母转换成小写。切分可用正则表达式:String res[] = line.split("[^a-zA-Z]");,小写转换用toLowerCase()。
step2:去停用词,过滤对分类无价值的词
step3:词根还原stemmer,基于Porter算法
预处理类如下:
package com.datamine.NaiveBayes; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.FileWriter; import java.util.ArrayList; /** * Newsgroup文档预处理 * step1:英文词法分析,取出数字、连字符、标点符号、特殊字符,所有大写字母转换成小写,可用正则表达式:String res[] = line.split("[^a-zA-Z]"); * step2:去停用词,过滤对分类无价值的词 * step3:词根还原stemmer,基于Porter算法 * @author Administrator * */ public class DataPreProcess { private static ArrayList<String> stopWordsArray = new ArrayList<String>(); /** * 输入文件的路径,处理数据 * @param srcDir 文件目录的绝对路径 * @param desDir 清洗后的文件路径 * @throws Exception */ public void doProcess(String srcDir) throws Exception{ File fileDir = new File(srcDir); if(!fileDir.exists()){ System.out.println("文件不存在!"); return ; } String subStrDir = srcDir.substring(srcDir.lastIndexOf('/')); String dirTarget = srcDir+"/../../processedSample"+subStrDir; File fileTarget = new File(dirTarget); if(!fileTarget.exists()){ //注意processedSample需要先建立目录建出来,否则会报错,因为母目录不存在 boolean mkdir = fileTarget.mkdir(); } File[] srcFiles = fileDir.listFiles(); for(int i =0 ;i <srcFiles.length;i++){ String fileFullName = srcFiles[i].getCanonicalPath(); //CanonicalPath不但是全路径,而且把..或者.这样的符号解析出来。 String fileShortName = srcFiles[i].getName(); //文件名 if(!new File(fileFullName).isDirectory()){ //确认子文件名不是目录,如果是可以再次递归调用 System.out.println("开始预处理:"+fileFullName); StringBuilder stringBuilder = new StringBuilder(); stringBuilder.append(dirTarget+"/"+fileShortName); createProcessFile(fileFullName,stringBuilder.toString()); }else{ fileFullName = fileFullName.replace("\\", "/"); doProcess(fileFullName); } } } /** * 进行文本预处理生成目标文件 * @param srcDir 源文件文件目录的绝对路径 * @param targetDir 生成目标文件的绝对路径 * @throws Exception */ private void createProcessFile(String srcDir, String targetDir) throws Exception { FileReader srcFileReader = new FileReader(srcDir); FileWriter targetFileWriter = new FileWriter(targetDir); BufferedReader srcFileBR = new BufferedReader(srcFileReader); String line,resLine; while((line = srcFileBR.readLine()) != null){ 
resLine = lineProcess(line); if(!resLine.isEmpty()){ //按行写,一行写一个单词 String[] tempStr = resLine.split(" "); for(int i =0; i<tempStr.length ;i++){ if(!tempStr[i].isEmpty()) targetFileWriter.append(tempStr[i]+"\n"); } } } targetFileWriter.flush(); targetFileWriter.close(); srcFileReader.close(); srcFileBR.close(); } /** * 对每行字符串进行处理,主要是词法分析、去停用词和stemming(去除时态) * @param line 待处理的一行字符串 * @param stopWordsArray 停用词数组 * @return String 处理好的一行字符串,是由处理好的单词重新生成,以空格为分隔符 */ private String lineProcess(String line) { /* * step1 * 英文词法分析,去除数字、连字符、标点符号、特殊字符, * 所有大写字符转换成小写,可以考虑使用正则表达式 */ String res[] = line.split("[^a-zA-Z]"); //step2 去停用词,大写转换成小写 //step3 Stemmer.run() String resString = new String(); for(int i=0;i<res.length;i++){ if(!res[i].isEmpty() && !stopWordsArray.contains(res[i].toLowerCase())) resString += " " + Stemmer.run(res[i].toLowerCase()) + " "; } return resString; } /** * 用stopWordsArray构造停用词的ArrayList容器 * @param stopwordsPath * @throws Exception */ private static void stopWordsToArray(String stopwordsPath) throws Exception { FileReader stopWordsReader = new FileReader(stopwordsPath); BufferedReader stopWordsBR = new BufferedReader(stopWordsReader); String stopWordsLine = null; //用stopWordsArray构造停用词的ArrayList容器 while((stopWordsLine = stopWordsBR.readLine()) != null){ if(!stopWordsLine.isEmpty()) stopWordsArray.add(stopWordsLine); } stopWordsReader.close(); stopWordsBR.close(); } public static void main(String[] args) throws Exception{ DataPreProcess dataPrePro = new DataPreProcess(); String srcDir = "E:/DataMiningSample/orginSample"; String stopwordsPath = "E:/DataMiningSample/stopwords.txt"; stopWordsToArray(stopwordsPath); dataPrePro.doProcess(srcDir); } }
/** * Stemmer中接口,将传入的word进行词根还原 * @param word 传入单词 * @return result 处理后的单词 */ public static String run(String word){ Stemmer s = new Stemmer(); char[] ch = word.toCharArray(); for (int c = 0; c < ch.length; c++) s.add(ch[c]); s.stem(); { String u; u = s.toString(); //System.out.print(u); return u; } }
方法一:保留所有词作为特征词
方法二:选取在整个语料中出现次数达到某一阈值的词作为特征词(本文代码保留出现次数大于2、即至少出现3次的词)
方法三:计算每个词的权重tf*idf,根据权重来选取特征词
本文中选取方法二。
由于本文的特征词选择采用方法二,其实不必对文本进行向量化;但统计特征词出现次数的方法(countWords)写在ComputeWordsVector类中,为了让程序能够运行,这里仍贴出文本向量化的完整代码。后面使用KNN算法时也要用到此类。
package com.datamine.NaiveBayes; import java.io.*; import java.util.*; /** * 计算文档的属性向量,将所有文档向量化 * @author Administrator */ public class ComputeWordsVector { /** * 计算文档的TF属性向量,TFPerDocMap * 计算TF*IDF * @param strDir 处理好的newsgroup文件目录的绝对路径 * @param trainSamplePercent 训练样本集占每个类目的比例 * @param indexOfSample 测试样例集的起始的测试样例编号 注释:通过这个参数可以将文本分成训练和测试两部分 * @param iDFPerWordMap 每个词的IDF权值属性向量 * @param wordMap 属性词典map * @throws IOException */ public void computeTFMultiIDF(String strDir,double trainSamplePercent,int indexOfSample, Map<String, Double> iDFPerWordMap,Map<String,Double> wordMap) throws IOException{ File fileDir = new File(strDir); String word; SortedMap<String,Double> TFPerDocMap = new TreeMap<String, Double>(); //注意可以用两个写文件,一个专门写测试样例,一个专门写训练样例,用sampleType的值来表示 String trainFileDir = "E:/DataMiningSample/docVector/wordTFIDFMapTrainSample"+indexOfSample; String testFileDir = "E:/DataMiningSample/docVector/wordTFIDFMapTestSample"+indexOfSample; FileWriter tsTrainWriter = new FileWriter(new File(trainFileDir)); //往训练文件中写 FileWriter tsTestWriter = new FileWriter(new File(testFileDir)); //往测试文件中写 FileWriter tsWriter = null; File[] sampleDir = fileDir.listFiles(); for(int i = 0;i<sampleDir.length;i++){ String cateShortName = sampleDir[i].getName(); System.out.println("开始计算: " + cateShortName); File[] sample = sampleDir[i].listFiles(); //测试样例集起始文件序号 double testBeginIndex = indexOfSample*(sample.length*(1-trainSamplePercent)); //测试样例集的结束文件序号 double testEndIndex = (indexOfSample+1)*(sample.length*(1-trainSamplePercent)); System.out.println("文件名_文件数 :" + sampleDir[i].getCanonicalPath()+"_"+sample.length); System.out.println("训练数:"+sample.length*trainSamplePercent + " 测试文本开始下标:"+ testBeginIndex+" 测试文本结束下标:"+testEndIndex); for(int j =0;j<sample.length; j++){ //计算TF,即每个词在该文件中出现的频率 TFPerDocMap.clear(); FileReader samReader = new FileReader(sample[j]); BufferedReader samBR = new BufferedReader(samReader); String fileShortName = sample[j].getName(); Double wordSumPerDoc = 
0.0;//计算每篇文档的总字数 while((word = samBR.readLine()) != null){ if(!word.isEmpty() && wordMap.containsKey(word)){ wordSumPerDoc++; if(TFPerDocMap.containsKey(word)) TFPerDocMap.put(word, TFPerDocMap.get(word)+1); else TFPerDocMap.put(word, 1.0); } } samBR.close(); /* * 遍历 TFPerDocMap,除以文档的总词数wordSumPerDoc 则得到TF * TF*IDF得到最终的特征权值,并输出到文件 * 注意:测试样例和训练样例写入的文件不同 */ if(j >= testBeginIndex && j <= testEndIndex) tsWriter = tsTestWriter; else tsWriter = tsTrainWriter; Double wordWeight; Set<Map.Entry<String, Double>> tempTF = TFPerDocMap.entrySet(); for(Iterator<Map.Entry<String, Double>> mt = tempTF.iterator();mt.hasNext();){ Map.Entry<String, Double> me = mt.next(); //由于计算IDF非常耗时,3万多个词的属性词典初步估计需要25个小时,先尝试认为所有词的IDF都是1的情况 //wordWeight = (me.getValue() / wordSumPerDoc) * iDFPerWordMap.get(me.getKey()); wordWeight = (me.getValue() / wordSumPerDoc) * 1.0; TFPerDocMap.put(me.getKey(), wordWeight); } tsWriter.append(cateShortName + " "); tsWriter.append(fileShortName + " "); Set<Map.Entry<String, Double>> tempTF2 = TFPerDocMap.entrySet(); for(Iterator<Map.Entry<String, Double>> mt = tempTF2.iterator();mt.hasNext();){ Map.Entry<String, Double> me = mt.next(); tsWriter.append(me.getKey() + " " + me.getValue()+" "); } tsWriter.append("\n"); tsWriter.flush(); } } tsTrainWriter.close(); tsTestWriter.close(); tsWriter.close(); } /** * 统计每个词的总出现次数,返回出现次数大于3词的词汇构成最终的属性词典 * @param strDir 处理好的newsgroup文件目录的绝对路径 * @param wordMap 记录出现的每个词构成的属性词典 * @return newWordMap 返回出现次数大于3次的词汇构成最终的属性词典 * @throws IOException */ public SortedMap<String, Double> countWords(String strDir, Map<String, Double> wordMap) throws IOException { File sampleFile = new File(strDir); File[] sample = sampleFile.listFiles(); String word; for(int i =0 ;i < sample.length;i++){ if(!sample[i].isDirectory()){ FileReader samReader = new FileReader(sample[i]); BufferedReader samBR = new BufferedReader(samReader); while((word = samBR.readLine()) != null){ if(!word.isEmpty() && wordMap.containsKey(word)) wordMap.put(word, 
wordMap.get(word)+1); else wordMap.put(word, 1.0); } samBR.close(); }else{ countWords(sample[i].getCanonicalPath(),wordMap); } } /* * 只返回出现次数大于3的单词 * 这里为了简单,应该独立一个函数,避免多次运行 */ SortedMap<String,Double> newWordMap = new TreeMap<String, Double>(); Set<Map.Entry<String, Double>> allWords = wordMap.entrySet(); for(Iterator<Map.Entry<String, Double>> it = allWords.iterator();it.hasNext();){ Map.Entry<String, Double> me = it.next(); if(me.getValue() > 2) newWordMap.put(me.getKey(), me.getValue()); } System.out.println("newWordMap "+ newWordMap.size()); return newWordMap; } /** * 打印属性词典,到allDicWordCountMap.txt中 * @param wordMap 属性词典 * @throws IOException */ public void printWordMap(Map<String, Double> wordMap) throws IOException{ System.out.println("printWordMap:"); int countLine = 0; File outPutFile = new File("E:/DataMiningSample/docVector/allDicWordCountMap.txt"); FileWriter outPutFileWriter = new FileWriter(outPutFile); Set<Map.Entry<String, Double>> allWords = wordMap.entrySet(); for(Iterator<Map.Entry<String, Double>> it = allWords.iterator();it.hasNext();){ Map.Entry<String, Double> me = it.next(); outPutFileWriter.write(me.getKey()+" "+me.getValue()+"\n"); countLine++; } outPutFileWriter.close(); System.out.println("WordMap size : " + countLine); } /** * 词w在整个文档集合中的逆向文档频率idf (Inverse Document Frequency), * 即文档总数n与词w所出现文件数docs(w, D)比值的对数: idf = log(n / docs(w, D)) * 计算IDF,即属性词典中每个词在多少个文档中出现过 * @param strDir 处理好的newsgroup文件目录的绝对路径 * @param wordMap 属性词典 * @return 单词的IDFMap * @throws IOException */ public SortedMap<String,Double> computeIDF(String strDir,Map<String, Double> wordMap) throws IOException{ File fileDir = new File(strDir); String word; SortedMap<String,Double> IDFPerWordMap = new TreeMap<String, Double>(); Set<Map.Entry<String, Double>> wordMapSet = wordMap.entrySet(); for(Iterator<Map.Entry<String, Double>> it = wordMapSet.iterator();it.hasNext();){ Map.Entry<String, Double> pe = it.next(); Double countDoc = 0.0; //出现字典词的文本数 Double sumDoc = 0.0; //文本总数 
String dicWord = pe.getKey(); File[] sampleDir = fileDir.listFiles(); for(int i =0;i<sampleDir.length;i++){ File[] sample = sampleDir[i].listFiles(); for(int j = 0;j<sample.length;j++){ sumDoc++; //统计文本数 FileReader samReader = new FileReader(sample[j]); BufferedReader samBR = new BufferedReader(samReader); boolean isExist = false; while((word = samBR.readLine()) != null){ if(!word.isEmpty() && word.equals(dicWord)){ isExist = true; break; } } if(isExist) countDoc++; samBR.close(); } } //计算单词的IDF //double IDF = Math.log(sumDoc / countDoc) / Math.log(10); double IDF = Math.log(sumDoc / countDoc); IDFPerWordMap.put(dicWord, IDF); } return IDFPerWordMap; } public static void main(String[] args) throws IOException { ComputeWordsVector wordsVector = new ComputeWordsVector(); String strDir = "E:\\DataMiningSample\\processedSample"; Map<String, Double> wordMap = new TreeMap<String, Double>(); //属性词典 Map<String, Double> newWordMap = new TreeMap<String, Double>(); newWordMap = wordsVector.countWords(strDir,wordMap); //wordsVector.printWordMap(newWordMap); //wordsVector.computeIDF(strDir, newWordMap); double trainSamplePercent = 0.8; int indexOfSample = 1; Map<String, Double> iDFPerWordMap = null; wordsVector.computeTFMultiIDF(strDir, trainSamplePercent, indexOfSample, iDFPerWordMap, newWordMap); //test(); } public static void test(){ double sumDoc = 18828.0; double countDoc = 229.0; double IDF1 = Math.log(sumDoc / countDoc) / Math.log(10); double IDF2 = Math.log(sumDoc / countDoc) ; System.out.println(IDF1); System.out.println(IDF2); System.out.println(Math.log(10)); } }
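按指定的训练集比例(0.9或0.8)将整个文本集划分为训练集和测试集。第indexOfSample个测试集取每个类目中下标位于 indexOfSample×n×(1−比例) 到 (indexOfSample+1)×n×(1−比例) 之间的文档(n为该类目的文档数)。因此比例为0.9时,indexOfSample取0~9可得到十折交叉验证所需的10个互不重叠的测试集;比例为0.8时只能得到5个,indexOfSample≥5时测试集为空。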
package com.datamine.NaiveBayes; import java.io.*; import java.util.*; public class CreateTrainAndTestSample { void filterSpecialWords() throws IOException{ String word; ComputeWordsVector cwv = new ComputeWordsVector(); String fileDir = "E:\\DataMiningSample\\processedSample"; SortedMap<String, Double> wordMap = new TreeMap<String, Double>(); wordMap = cwv.countWords(fileDir, wordMap); cwv.printWordMap(wordMap); //把wordMap输出到文件 File[] sampleDir = new File(fileDir).listFiles(); for(int i = 0;i<sampleDir.length;i++){ File[] sample = sampleDir[i].listFiles(); String targetDir = "E:/DataMiningSample/processedSampleOnlySpecial/"+sampleDir[i].getName(); File targetDirFile = new File(targetDir); if(!targetDirFile.exists()){ targetDirFile.mkdir(); } for(int j = 0; j<sample.length;j++){ String fileShortName = sample[j].getName(); targetDir = "E:/DataMiningSample/processedSampleOnlySpecial/"+sampleDir[i].getName()+"/"+fileShortName; FileWriter tgWriter = new FileWriter(targetDir); FileReader samReader = new FileReader(sample[j]); BufferedReader samBR = new BufferedReader(samReader); while((word = samBR.readLine()) != null){ if(wordMap.containsKey(word)) tgWriter.append(word+"\n"); } tgWriter.flush(); tgWriter.close(); samBR.close(); } } } /** * 创建训练集和测试集 * @param fileDir 预处理好的文件路径 E:\DataMiningSample\processedSampleOnlySpecial\ * @param trainSamplePercent 训练集占的百分比0.8 * @param indexOfSample 一个测试集计算规则 1 * @param classifyResultFile 测试样例正确类目记录文件 * @throws IOException */ void createTestSample(String fileDir,double trainSamplePercent,int indexOfSample,String classifyResultFile) throws IOException{ String word,targetDir; FileWriter crWriter = new FileWriter(classifyResultFile);//测试样例正确类目记录文件 File[] sampleDir = new File(fileDir).listFiles(); for(int i =0;i<sampleDir.length;i++){ File[] sample = sampleDir[i].listFiles(); double testBeginIndex = indexOfSample*(sample.length*(1-trainSamplePercent)); double testEndIndex = (indexOfSample + 1)*(sample.length*(1-trainSamplePercent)); 
for(int j = 0;j<sample.length;j++){ FileReader samReader = new FileReader(sample[j]); BufferedReader samBR = new BufferedReader(samReader); String fileShortName = sample[j].getName(); String subFileName = fileShortName; if(j > testBeginIndex && j < testEndIndex){ targetDir = "E:/DataMiningSample/TestSample"+indexOfSample+"/"+sampleDir[i].getName(); crWriter.append(subFileName + " "+sampleDir[i].getName()+"\n"); }else{ targetDir = "E:/DataMiningSample/TrainSample"+indexOfSample+"/"+sampleDir[i].getName(); } targetDir = targetDir.replace("\\", "/"); File trainSamFile = new File(targetDir); if(!trainSamFile.exists()){ trainSamFile.mkdir(); } targetDir += "/" + subFileName; FileWriter tsWriter = new FileWriter(new File(targetDir)); while((word = samBR.readLine()) != null) tsWriter.append(word+"\n"); tsWriter.flush(); tsWriter.close(); samBR.close(); } } crWriter.close(); } public static void main(String[] args) throws IOException { CreateTrainAndTestSample test = new CreateTrainAndTestSample(); String fileDir = "E:/DataMiningSample/processedSampleOnlySpecial"; double trainSamplePercent=0.8; int indexOfSample=1; String classifyResultFile="E:/DataMiningSample/classifyResult"; test.createTestSample(fileDir, trainSamplePercent, indexOfSample, classifyResultFile); //test.filterSpecialWords(); } }
package com.datamine.NaiveBayes; import java.io.BufferedReader; import java.io.File; import java.io.FileNotFoundException; import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; import java.math.BigDecimal; import java.util.Iterator; import java.util.Map; import java.util.Set; import java.util.SortedSet; import java.util.TreeMap; import java.util.TreeSet; import java.util.Vector; /** * 利用朴素贝叶斯算法对newsgroup文档集做分类,采用十组交叉测试取平均值 * 采用多项式模型 * 类条件概率 P(tk|c)=(类c下 单词tk 在各个文档中出现过的次数之和 + 1)/(类c下单词的总数 + 训练集总单词数) * @author Administrator */ public class NaiveBayesianClassifier { /** * 用朴素贝叶斯算法对测试文档集分类 * @param trainDir 训练文档集目录 * @param testDir 测试文档集目录 * @param classifyResultFileNew 分类结果文件路径 * @throws Exception */ private void doProcess(String trainDir,String testDir, String classifyResultFileNew) throws Exception{ //保存训练集中每个类别的总词数 <目录名,单词总数> category Map<String,Double> cateWordsNum = new TreeMap<String, Double>(); //保存训练样本中每个类别中每个属性词的出现次数 <类目_单词,数目> Map<String,Double> cateWordsProb = new TreeMap<String, Double>(); cateWordsNum = getCateWordsNum(trainDir); cateWordsProb = getCateWordsProb(trainDir); double totalWordsNum = 0.0;//记录所有训练集的总词数 Set<Map.Entry<String, Double>> cateWordsNumSet = cateWordsNum.entrySet(); for(Iterator<Map.Entry<String, Double>> it = cateWordsNumSet.iterator();it.hasNext();){ Map.Entry<String, Double> me = it.next(); totalWordsNum += me.getValue(); } //下面开始读取测试样例做分类 Vector<String> testFileWords = new Vector<String>(); //测试样本所有词的容器 String word; File[] testDirFiles = new File(testDir).listFiles(); FileWriter crWriter = new FileWriter(classifyResultFileNew); for(int i =0;i<testDirFiles.length;i++){ File[] testSample = testDirFiles[i].listFiles(); for(int j =0;j<testSample.length;j++){ testFileWords.clear(); FileReader spReader = new FileReader(testSample[j]); BufferedReader spBR = new BufferedReader(spReader); while((word = spBR.readLine()) != null){ testFileWords.add(word); } spBR.close(); //下面分别计算该测试样例属于二十个类别的概率 File[] trainDirFiles 
= new File(trainDir).listFiles(); BigDecimal maxP = new BigDecimal(0); String bestCate = null; for(int k =0; k < trainDirFiles.length; k++){ BigDecimal p = computeCateProb(trainDirFiles[k],testFileWords,cateWordsNum,totalWordsNum,cateWordsProb); if( k == 0){ maxP = p; bestCate = trainDirFiles[k].getName(); continue; } if(p.compareTo(maxP) == 1){ maxP = p; bestCate = trainDirFiles[k].getName(); } } crWriter.append(testSample[j].getName() + " " + bestCate + "\n"); crWriter.flush(); } } crWriter.close(); } /** * 类条件概率 P(tk|c)=(类c下 单词tk 在各个文档中出现过的次数之和 + 1)/(类c下单词的总数 + 训练集中总单词数) * 计算某一个测试样本数据某个类别的概率 使用多项式模型 * @param trainFile 该类别所有的训练样本所在的目录 * @param testFileWords 该测试样本中的所有词构成的容器 * @param cateWordsNum 记录每个目录下单词的总数 * @param totalWordsNum 所有训练样本的单词的总数 * @param cateWordsProb 记录每个目录中出现单词和次数 * @return 返回该测试样本在该类别中的概率 */ private BigDecimal computeCateProb(File trainFile, Vector<String> testFileWords, Map<String, Double> cateWordsNum, double totalWordsNum, Map<String, Double> cateWordsProb) { BigDecimal probability = new BigDecimal(1); double wordNumInCate = cateWordsNum.get(trainFile.getName()); BigDecimal wordNumInCateBD = new BigDecimal(wordNumInCate); BigDecimal totalWordsNumBD = new BigDecimal(totalWordsNum); for(Iterator<String> it = testFileWords.iterator();it.hasNext();){ String me = it.next(); String key = trainFile.getName()+"_"+me; double testFileWordNumInCate; if(cateWordsProb.containsKey(key)) testFileWordNumInCate = cateWordsProb.get(key); else testFileWordNumInCate = 0.0; BigDecimal testFileWordNumInCateBD = new BigDecimal(testFileWordNumInCate); BigDecimal xcProb = (testFileWordNumInCateBD.add(new BigDecimal(0.0001))) .divide(wordNumInCateBD.add(totalWordsNumBD), 10, BigDecimal.ROUND_CEILING); probability = probability.multiply(xcProb); } // P = P(tk|c)*P(C) BigDecimal result = probability.multiply(wordNumInCateBD.divide(totalWordsNumBD,10, BigDecimal.ROUND_CEILING)); return result; } /** * 统计某个类训练样本中每个单词出现的次数 * @param trainDir 训练样本集目录 * @return cateWordsProb 
用"类目_单词"来索引map,value就是该类目下该单词出现的次数 * @throws Exception */ private Map<String, Double> getCateWordsProb(String trainDir) throws Exception { Map<String,Double> cateWordsProb = new TreeMap<String, Double>(); File sampleFile = new File(trainDir); File[] sampleDir = sampleFile.listFiles(); String word; for(int i =0;i < sampleDir.length;i++){ File[] sample = sampleDir[i].listFiles(); for(int j =0;j<sample.length;j++){ FileReader samReader = new FileReader(sample[j]); BufferedReader samBR = new BufferedReader(samReader); while((word = samBR.readLine()) != null){ String key = sampleDir[i].getName()+"_"+word; if(cateWordsProb.containsKey(key)) cateWordsProb.put(key, cateWordsProb.get(key)+1); else cateWordsProb.put(key, 1.0); } samBR.close(); } } return cateWordsProb; } /** * 获得每个类目下的单词总数 * @param trainDir 训练文档集目录 * @return cateWordsNum <目录名,单词总数>的map * @throws IOException */ private Map<String, Double> getCateWordsNum(String trainDir) throws IOException { Map<String, Double> cateWordsNum = new TreeMap<String, Double>(); File[] sampleDir = new File(trainDir).listFiles(); for(int i =0;i<sampleDir.length;i++){ double count = 0; File[] sample = sampleDir[i].listFiles(); for(int j =0;j<sample.length;j++){ FileReader spReader = new FileReader(sample[j]); BufferedReader spBR = new BufferedReader(spReader); while(spBR.readLine() != null){ count++; } spBR.close(); } cateWordsNum.put(sampleDir[i].getName(), count); } return cateWordsNum; } /** * 根据正确类目文件和分类结果文件统计出准确率 * @param classifyRightCate 正确类目文件 <文件名,类别目录名> * @param classifyResultFileNew 分类结果文件 <文件名,类别目录名> * @return 分类的准确率 * @throws Exception */ public double computeAccuracy(String classifyRightCate, String classifyResultFileNew) throws Exception { Map<String,String> rightCate = new TreeMap<String, String>(); Map<String,String> resultCate = new TreeMap<String,String>(); rightCate = getMapFromResultFile(classifyRightCate); resultCate = getMapFromResultFile(classifyResultFileNew); Set<Map.Entry<String, String>> resCateSet = 
resultCate.entrySet(); double rightCount = 0.0; for(Iterator<Map.Entry<String, String>> it = resCateSet.iterator();it.hasNext();){ Map.Entry<String, String> me = it.next(); if(me.getValue().equals(rightCate.get(me.getKey()))) rightCount++; } computerConfusionMatrix(rightCate,resultCate); return rightCount / resultCate.size(); } /** * 根据正确类目文件和分类结果文件计算混淆矩阵并输出 * @param rightCate 正确类目map * @param resultCate 分类结果对应map */ private void computerConfusionMatrix(Map<String, String> rightCate, Map<String, String> resultCate) { int[][] confusionMatrix = new int[20][20]; //首先求出类目对应的数组索引 SortedSet<String> cateNames = new TreeSet<String>(); Set<Map.Entry<String, String>> rightCateSet = rightCate.entrySet(); for(Iterator<Map.Entry<String, String>> it = rightCateSet.iterator();it.hasNext();){ Map.Entry<String, String> me = it.next(); cateNames.add(me.getValue()); } cateNames.add("rec.sport.baseball");//防止数少一个类目 String[] cateNamesArray = cateNames.toArray(new String[0]); Map<String,Integer> cateNamesToIndex = new TreeMap<String, Integer>(); for(int i =0;i<cateNamesArray.length;i++){ cateNamesToIndex.put(cateNamesArray[i], i); } for(Iterator<Map.Entry<String, String>> it = rightCateSet.iterator();it.hasNext();){ Map.Entry<String, String> me = it.next(); confusionMatrix[cateNamesToIndex.get(me.getValue())][cateNamesToIndex.get(resultCate.get(me.getKey()))]++; } //输出混淆矩阵 double[] hangSum = new double[20]; System.out.print(" "); for(int i=0;i<20;i++){ System.out.printf("%-6d",i); } System.out.println("准确率"); for(int i =0;i<20;i++){ System.out.printf("%-6d",i); for(int j = 0;j<20;j++){ System.out.printf("%-6d",confusionMatrix[i][j]); hangSum[i] += confusionMatrix[i][j]; } System.out.printf("%-6f\n",confusionMatrix[i][i]/hangSum[i]); } System.out.println(); } /** * 从结果文件中读取Map * @param file 类目文件 * @return Map<String,String> 由<文件名,类目名>保存的map * @throws Exception */ private Map<String, String> getMapFromResultFile(String file) throws Exception { File crFile = new File(file); FileReader 
crReader = new FileReader(crFile); BufferedReader crBR = new BufferedReader(crReader); Map<String,String> res = new TreeMap<String, String>(); String[] s; String line; while((line = crBR.readLine()) != null){ s = line.split(" "); res.put(s[0], s[1]); } return res; } public static void main(String[] args) throws Exception { CreateTrainAndTestSample ctt = new CreateTrainAndTestSample(); NaiveBayesianClassifier nbClassifier = new NaiveBayesianClassifier(); //根据包含非特征词的文档集生成只包含特征词的文档集到processedSampleOnlySpecial目录下 ctt.filterSpecialWords(); double[] accuracyOfEveryExp = new double[10]; double accuracyAvg,sum = 0; for(int i =0;i<10;i++){//用交叉验证法做十次分类实验,对准确率取平均值 String TrainDir = "E:/DataMiningSample/TrainSample"+i; String TestDir = "E:/DataMiningSample/TestSample"+i; String classifyRightCate = "E:/DataMiningSample/classifyRightCate"+i+".txt"; String classifyResultFileNew = "E:/DataMiningSample/classifyResultNew"+i+".txt"; ctt.createTestSample("E:/DataMiningSample/processedSampleOnlySpecial", 0.8, i, classifyRightCate); nbClassifier.doProcess(TrainDir, TestDir, classifyResultFileNew); accuracyOfEveryExp[i] = nbClassifier.computeAccuracy(classifyRightCate,classifyResultFileNew); System.out.println("The accuracy for Naive Bayesian Classifier in "+i+"th Exp is :" + accuracyOfEveryExp[i]); } for(int i =0;i<10;i++) sum += accuracyOfEveryExp[i]; accuracyAvg = sum/10; System.out.println("The average accuracy for Naive Bayesian Classifier in all Exps is :" + accuracyAvg); } }
结果(略)
这里使用的多项式模型对类条件概率的计算公式做了改进,代码中实际使用的是:类条件概率P(tk|c)=(类c下单词tk在各个文档中出现过的次数之和+0.0001)/(类c下单词总数+训练样本中的单词总数)。当tk在类c中从未出现时,分子只取0.0001(而不是拉普拉斯平滑中的1),这样能更精确地描述词的统计分布规律。
为了进一步提高朴素贝叶斯算法的分类准确率,可以进行如下改进:
1、优化特征词选取的方法,如方法三,或者更优方法
2、改进多项式模型的类条件概率的计算公式(上面已经实现)