文本分类——NaiveBayes

前面文章已经介绍了朴素贝叶斯算法的原理,这里基于NaiveBayes算法对newsgroup文本进行分类测试。

文中代码参考:http://blog.csdn.net/jiangliqing1234/article/details/39642757

主要内容如下:

1、newsgroup数据集介绍

数据下载地址:http://download.csdn.net/detail/hjy321686/8057761。  文本中包含20个不同的新闻组,除其中少数文本属于多个新闻组以外,其余的文档都只属于一个新闻组。

2、newsgroup数据预处理

要对文本进行分类,首先要对其进行预处理,预处理主要过程如下:

step1:英文词法分析,去除数字、连字符、标点符号、特殊字符,并将所有大写字母转换成小写;切分单词可用正则表达式:String res[] = line.split("[^a-zA-Z]");

step2:去停用词,过滤对分类无价值的词

step3:词根还原stemmer,基于Porter算法

预处理类如下:

package com.datamine.NaiveBayes;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.util.ArrayList;

/**
 * Pre-processes the raw Newsgroup documents into files holding one token per line.
 * <ol>
 * <li>Tokenise: split each line on every non-letter character (which discards digits,
 *     hyphens, punctuation and other special characters) and lower-case the tokens.</li>
 * <li>Drop stop words, i.e. words that are of no value for classification.</li>
 * <li>Reduce every remaining token to its stem via {@code Stemmer.run} (Porter algorithm).</li>
 * </ol>
 *
 * @author Administrator
 */
public class DataPreProcess {

	/** Stop words filled by {@link #stopWordsToArray(String)}; compared against lower-cased tokens. */
	private static ArrayList<String> stopWordsArray = new ArrayList<String>();

	/**
	 * Recursively pre-processes every regular file found below {@code srcDir}.
	 * <p>
	 * Files that sit directly in {@code srcDir} are written to
	 * {@code srcDir/../../processedSample/<last path segment of srcDir>}; for a category
	 * directory {@code <root>/<category>} this is a {@code processedSample/<category>}
	 * directory next to {@code <root>}. The output directory is created on demand
	 * (including missing parents) the first time a file has to be written into it.
	 * If {@code srcDir} does not exist a message is printed and nothing is done.
	 *
	 * @param srcDir absolute path of the directory to process; both '/' and '\' separators are accepted
	 * @throws Exception if a directory cannot be listed or created, or a file cannot be read or written
	 */
	public void doProcess(String srcDir) throws Exception{

		// The recursive calls below always pass '/'-separated paths; normalise the
		// top-level argument as well so a Windows-style path does not break the
		// lastIndexOf('/') lookup.
		String sourcePath = srcDir.replace("\\", "/");

		File fileDir = new File(sourcePath);
		if(!fileDir.exists()){
			System.out.println("文件不存在!");
			return ;
		}

		File[] srcFiles = fileDir.listFiles();
		if(srcFiles == null){
			// listFiles() returns null for non-directories and on I/O errors.
			throw new java.io.IOException("Not a readable directory: " + sourcePath);
		}

		int lastSlash = sourcePath.lastIndexOf('/');
		String subStrDir = lastSlash >= 0 ? sourcePath.substring(lastSlash) : "/" + sourcePath;
		String dirTarget = sourcePath + "/../../processedSample" + subStrDir;
		File fileTarget = new File(dirTarget);

		for(File srcFile : srcFiles){

			// The canonical path is absolute and has "." / ".." segments resolved.
			String fileFullName = srcFile.getCanonicalPath();

			if(srcFile.isDirectory()){
				doProcess(fileFullName.replace("\\", "/"));
				continue;
			}

			// Create the output directory lazily so that directories which only
			// contain sub-directories do not leave empty folders behind.
			if(!fileTarget.isDirectory() && !fileTarget.mkdirs()){
				throw new java.io.IOException("Cannot create output directory: " + dirTarget);
			}

			System.out.println("开始预处理:"+fileFullName);
			createProcessFile(fileFullName, dirTarget + "/" + srcFile.getName());
		}
	}

	/**
	 * Pre-processes one source file and writes the surviving tokens, one per line,
	 * to the target file.
	 *
	 * @param srcDir absolute path of the source file
	 * @param targetDir absolute path of the file to generate
	 * @throws Exception if the source cannot be read or the target cannot be written
	 */
	private void createProcessFile(String srcDir, String targetDir) throws Exception {

		// try-with-resources closes both streams even when reading or writing fails.
		try(BufferedReader srcFileBR = new BufferedReader(new FileReader(srcDir));
				FileWriter targetFileWriter = new FileWriter(targetDir)){

			String line;
			while((line = srcFileBR.readLine()) != null){
				// lineProcess() separates tokens with blanks; empty pieces are skipped.
				for(String token : lineProcess(line).split(" ")){
					if(!token.isEmpty()){
						targetFileWriter.append(token).append("\n");
					}
				}
			}
		}
	}

	/**
	 * Processes one line: tokenisation, stop-word removal and stemming.
	 *
	 * @param line the line to process
	 * @return the stemmed tokens of the line, each surrounded by single blanks;
	 *         an empty string when no token survives
	 */
	private String lineProcess(String line) {

		StringBuilder processed = new StringBuilder();

		// Step 1: splitting on non-letters removes digits, hyphens, punctuation etc.
		for(String rawToken : line.split("[^a-zA-Z]")){
			if(rawToken.isEmpty()){
				continue;
			}
			String token = rawToken.toLowerCase();
			// Step 2: drop stop words. Step 3: stem what is left.
			if(!stopWordsArray.contains(token)){
				processed.append(' ').append(Stemmer.run(token)).append(' ');
			}
		}

		return processed.toString();
	}

	/**
	 * Loads the stop-word file (one word per line) into {@link #stopWordsArray}.
	 * Surrounding whitespace is stripped and the words are lower-cased because
	 * tokens are lower-cased before they are compared; blank lines are ignored.
	 *
	 * @param stopwordsPath path of the stop-word file
	 * @throws Exception if the file cannot be read
	 */
	private static void stopWordsToArray(String stopwordsPath) throws Exception {

		try(BufferedReader stopWordsBR = new BufferedReader(new FileReader(stopwordsPath))){
			String stopWordsLine;
			while((stopWordsLine = stopWordsBR.readLine()) != null){
				String stopWord = stopWordsLine.trim().toLowerCase();
				if(!stopWord.isEmpty()){
					stopWordsArray.add(stopWord);
				}
			}
		}
	}

	/**
	 * Loads the stop words and pre-processes the corpus found at the hard-coded location.
	 *
	 * @param args unused
	 * @throws Exception if pre-processing fails
	 */
	public static void main(String[] args) throws Exception{

		DataPreProcess dataPrePro = new DataPreProcess();
		String srcDir = "E:/DataMiningSample/orginSample";

		String stopwordsPath = "E:/DataMiningSample/stopwords.txt";

		stopWordsToArray(stopwordsPath);

		dataPrePro.doProcess(srcDir);
	}

}

对于step3中的Porter算法可以网上下载,这里我基于其之上添加了一个run()方法。

	/**
	 * Convenience entry point added to the Stemmer: stems a single word.
	 *
	 * @param word the word to stem
	 * @return the stemmer's string form after {@code stem()} has run on the word
	 */
	public static String run(String word){

		Stemmer stemmer = new Stemmer();

		// Feed the word to the stemmer one character at a time.
		for (char letter : word.toCharArray()) {
			stemmer.add(letter);
		}

		stemmer.stem();
		return stemmer.toString();
	}

3、特征项选择

方法一:保留所有词作为特征词

方法二:选取出现频率大于某一个数(3或者其他)的词作为特征词

方法三:计算每个词的权重tf*idf,根据权重来选取特征词

本文中选取方法二。

4、文本向量化

由于本文中,特征词选择采用的是方法二,可以不用对文本进行向量化,但是统计特征词出现的次数方法写在ComputeWordsVector类中,为了程序运行这里还是把文本向量化的代码贴出来。后面使用KNN算法的时候也是要用到此类的。

package com.datamine.NaiveBayes;

import java.io.*;
import java.util.*;

/**
 * Builds the attribute dictionary of the pre-processed corpus and turns the
 * documents into term-weight vectors.
 * <p>
 * The corpus layout expected by the methods that take a corpus directory is
 * {@code <corpus>/<category>/<document>}, every document holding one word per line.
 *
 * @author Administrator
 */
public class ComputeWordsVector {

	/**
	 * Writes one weight vector per document to the train or the test vector file.
	 * <p>
	 * For every document the term frequency of each dictionary word is computed as
	 * (occurrences of the word) / (number of dictionary words in the document). The
	 * weight written is TF * IDF when {@code iDFPerWordMap} is non-null and contains
	 * the word, and plain TF (IDF treated as 1) otherwise.
	 * <p>
	 * Within each category the documents whose index {@code j} satisfies
	 * {@code testBegin <= j < testEnd} go to the test file, all others to the train
	 * file, where {@code testBegin = indexOfSample * n * (1 - trainSamplePercent)} and
	 * {@code testEnd = (indexOfSample + 1) * n * (1 - trainSamplePercent)} for a
	 * category with {@code n} documents. The half-open range makes consecutive
	 * values of {@code indexOfSample} select disjoint test sets.
	 * <p>
	 * Output line format: {@code <category> <file> (<word> <weight> )*}. The output
	 * files are created below the hard-coded directory
	 * {@code E:/DataMiningSample/docVector}, which must already exist.
	 *
	 * @param strDir absolute path of the pre-processed newsgroup directory
	 * @param trainSamplePercent share of each category used for training
	 * @param indexOfSample number of the test fold; selects which slice of each category is the test set
	 * @param iDFPerWordMap IDF weight per word, or {@code null} to weight by TF only
	 * @param wordMap attribute dictionary; words that are not a key of it are ignored
	 * @throws IOException if a directory cannot be listed or a file cannot be read or written
	 */
	public void computeTFMultiIDF(String strDir,double trainSamplePercent,int indexOfSample,
			Map<String, Double> iDFPerWordMap,Map<String,Double> wordMap) throws IOException{

		String trainFileDir = "E:/DataMiningSample/docVector/wordTFIDFMapTrainSample"+indexOfSample;
		String testFileDir = "E:/DataMiningSample/docVector/wordTFIDFMapTestSample"+indexOfSample;

		File[] sampleDir = listChildren(new File(strDir));
		SortedMap<String,Double> TFPerDocMap = new TreeMap<String, Double>();

		// try-with-resources closes both writers exactly once, also on failure.
		try(FileWriter tsTrainWriter = new FileWriter(new File(trainFileDir));
				FileWriter tsTestWriter = new FileWriter(new File(testFileDir))){

			for(File categoryDir : sampleDir){

				String cateShortName = categoryDir.getName();
				System.out.println("开始计算: " + cateShortName);

				File[] sample = listChildren(categoryDir);
				// First index of the test slice (inclusive) ...
				double testBeginIndex = indexOfSample*(sample.length*(1-trainSamplePercent));
				// ... and its end (exclusive).
				double testEndIndex = (indexOfSample+1)*(sample.length*(1-trainSamplePercent));
				System.out.println("文件名_文件数 :" + categoryDir.getCanonicalPath()+"_"+sample.length);
				System.out.println("训练数:"+sample.length*trainSamplePercent
						+ " 测试文本开始下标:"+ testBeginIndex+" 测试文本结束下标:"+testEndIndex);

				for(int j = 0;j<sample.length; j++){

					TFPerDocMap.clear();
					double wordSumPerDoc = countDictionaryWords(sample[j], wordMap, TFPerDocMap);

					boolean isTestSample = j >= testBeginIndex && j < testEndIndex;
					FileWriter tsWriter = isTestSample ? tsTestWriter : tsTrainWriter;

					tsWriter.append(cateShortName + " ");
					tsWriter.append(sample[j].getName() + " ");
					for(Map.Entry<String, Double> me : TFPerDocMap.entrySet()){
						Double idf = iDFPerWordMap == null ? null : iDFPerWordMap.get(me.getKey());
						double wordWeight = (me.getValue() / wordSumPerDoc) * (idf == null ? 1.0 : idf);
						tsWriter.append(me.getKey() + " " + wordWeight + " ");
					}
					tsWriter.append("\n");
					tsWriter.flush();
				}
			}
		}
	}

	/**
	 * Counts the occurrences of the total number of occurrences of every word below
	 * {@code strDir} (sub-directories are visited recursively) and returns the
	 * frequent ones.
	 * <p>
	 * The counts are added to {@code wordMap}, which is modified in place. Blank
	 * lines are not counted as words.
	 *
	 * @param strDir absolute path of the pre-processed newsgroup directory
	 * @param wordMap receives the count of every word seen; existing counts are incremented
	 * @return a new sorted map with the entries of {@code wordMap} whose count is greater than 2
	 * @throws IOException if a directory cannot be listed or a file cannot be read
	 */
	public SortedMap<String, Double> countWords(String strDir,
			Map<String, Double> wordMap) throws IOException {

		accumulateWordCounts(new File(strDir), wordMap);

		// Filter once, after the whole tree has been counted.
		SortedMap<String,Double> newWordMap = new TreeMap<String, Double>();
		for(Map.Entry<String, Double> me : wordMap.entrySet()){
			if(me.getValue() > 2)
				newWordMap.put(me.getKey(), me.getValue());
		}

		System.out.println("newWordMap "+ newWordMap.size());

		return newWordMap;
	}

	/**
	 * Writes the dictionary as {@code <word> <count>} lines to the hard-coded file
	 * {@code E:/DataMiningSample/docVector/allDicWordCountMap.txt}.
	 *
	 * @param wordMap attribute dictionary to print
	 * @throws IOException if the file cannot be written
	 */
	public void printWordMap(Map<String, Double> wordMap) throws IOException{

		System.out.println("printWordMap:");
		int countLine = 0;
		File outPutFile = new File("E:/DataMiningSample/docVector/allDicWordCountMap.txt");

		try(FileWriter outPutFileWriter = new FileWriter(outPutFile)){
			for(Map.Entry<String, Double> me : wordMap.entrySet()){
				outPutFileWriter.write(me.getKey()+" "+me.getValue()+"\n");
				countLine++;
			}
		}
		System.out.println("WordMap size : " + countLine);
	}

	/**
	 * Computes the inverse document frequency of every dictionary word:
	 * {@code idf = ln(n / docs(w))}, where {@code n} is the number of documents in
	 * the corpus and {@code docs(w)} the number of documents containing the word.
	 * <p>
	 * The corpus is read once; for each document the set of dictionary words it
	 * contains is collected and the document counts are updated from that set.
	 * A dictionary word that occurs in no document gets the value
	 * {@code Math.log(n / 0.0)}, i.e. positive infinity for a non-empty corpus.
	 *
	 * @param strDir absolute path of the pre-processed newsgroup directory
	 * @param wordMap attribute dictionary; only its keys are used
	 * @return a sorted map from every dictionary word to its IDF
	 * @throws IOException if a directory cannot be listed or a file cannot be read
	 */
	public SortedMap<String,Double> computeIDF(String strDir,Map<String, Double> wordMap) throws IOException{

		SortedMap<String,Double> IDFPerWordMap = new TreeMap<String, Double>();
		if(wordMap.isEmpty()){
			return IDFPerWordMap;
		}

		// Number of documents each dictionary word occurs in.
		Map<String, Double> docCountPerWord = new HashMap<String, Double>();
		for(String dicWord : wordMap.keySet()){
			docCountPerWord.put(dicWord, 0.0);
		}

		double sumDoc = 0.0; // total number of documents
		Set<String> wordsInDoc = new HashSet<String>();

		for(File categoryDir : listChildren(new File(strDir))){
			for(File doc : listChildren(categoryDir)){

				sumDoc++;
				wordsInDoc.clear();

				try(BufferedReader samBR = new BufferedReader(new FileReader(doc))){
					String word;
					while((word = samBR.readLine()) != null){
						if(!word.isEmpty() && docCountPerWord.containsKey(word)){
							wordsInDoc.add(word);
						}
					}
				}

				// Each document counts at most once per word.
				for(String word : wordsInDoc){
					docCountPerWord.put(word, docCountPerWord.get(word) + 1);
				}
			}
		}

		for(Map.Entry<String, Double> me : docCountPerWord.entrySet()){
			IDFPerWordMap.put(me.getKey(), Math.log(sumDoc / me.getValue()));
		}
		return IDFPerWordMap;
	}

	/**
	 * Adds the word counts of every file below {@code dir} to {@code wordMap},
	 * descending into sub-directories. Blank lines are skipped.
	 */
	private void accumulateWordCounts(File dir, Map<String, Double> wordMap) throws IOException {

		for(File entry : listChildren(dir)){

			if(entry.isDirectory()){
				accumulateWordCounts(entry, wordMap);
				continue;
			}

			try(BufferedReader samBR = new BufferedReader(new FileReader(entry))){
				String word;
				while((word = samBR.readLine()) != null){
					if(word.isEmpty()){
						continue;
					}
					Double seen = wordMap.get(word);
					wordMap.put(word, seen == null ? 1.0 : seen + 1);
				}
			}
		}
	}

	/**
	 * Counts the dictionary words of one document into {@code termCounts} and
	 * returns how many dictionary words the document contains in total.
	 */
	private static double countDictionaryWords(File doc, Map<String, Double> wordMap,
			Map<String, Double> termCounts) throws IOException {

		double wordSumPerDoc = 0.0;
		try(BufferedReader samBR = new BufferedReader(new FileReader(doc))){
			String word;
			while((word = samBR.readLine()) != null){
				if(word.isEmpty() || !wordMap.containsKey(word)){
					continue;
				}
				wordSumPerDoc++;
				Double seen = termCounts.get(word);
				termCounts.put(word, seen == null ? 1.0 : seen + 1);
			}
		}
		return wordSumPerDoc;
	}

	/**
	 * Lists a directory, sorted by path so that index-based train/test splits are
	 * reproducible ({@code File.listFiles()} guarantees no order).
	 *
	 * @throws IOException if {@code dir} is not a directory or cannot be read
	 */
	private static File[] listChildren(File dir) throws IOException {

		File[] children = dir.listFiles();
		if(children == null){
			throw new IOException("Not a readable directory: " + dir);
		}
		Arrays.sort(children);
		return children;
	}

	/**
	 * Builds the dictionary from the hard-coded corpus directory and writes the
	 * TF vectors for test fold 1 with 80% training data.
	 *
	 * @param args unused
	 * @throws IOException if the corpus cannot be read or the vectors cannot be written
	 */
	public static void main(String[] args) throws IOException {

		ComputeWordsVector wordsVector = new ComputeWordsVector();

		String strDir = "E:\\DataMiningSample\\processedSample";
		Map<String, Double> wordMap = new TreeMap<String, Double>();

		// Attribute dictionary: the frequent words only.
		Map<String, Double> newWordMap = wordsVector.countWords(strDir,wordMap);

		double trainSamplePercent = 0.8;
		int indexOfSample = 1;
		// null: weight by term frequency only.
		Map<String, Double> iDFPerWordMap = null;

		wordsVector.computeTFMultiIDF(strDir, trainSamplePercent, indexOfSample, iDFPerWordMap, newWordMap);
	}

	/**
	 * Prints the base-10 and natural-log IDF for a sample word (18828 documents,
	 * 229 of them containing the word), followed by ln(10).
	 */
	public static void test(){

		double sumDoc  = 18828.0;
		double countDoc = 229.0;

		double IDF1 = Math.log(sumDoc / countDoc) / Math.log(10);
		double IDF2 = Math.log(sumDoc / countDoc) ;

		System.out.println(IDF1);
		System.out.println(IDF2);

		System.out.println(Math.log(10));
	}

}

5、对文本分为测试集和训练集

按指定的比例(训练集占0.9或者0.8)把整个文本集划分为训练集和测试集。

package com.datamine.NaiveBayes;

import java.io.*;
import java.util.*;


/**
 * Reduces the pre-processed corpus to its feature words and splits it into a
 * training set and a test set.
 * <p>
 * The corpus layout is {@code <corpus>/<category>/<document>} with one word per line.
 */
public class CreateTrainAndTestSample {

	/**
	 * Copies the corpus from {@code E:\DataMiningSample\processedSample} to
	 * {@code E:/DataMiningSample/processedSampleOnlySpecial}, keeping only the words
	 * returned by {@code ComputeWordsVector.countWords}. The dictionary is also
	 * written out through {@code ComputeWordsVector.printWordMap}.
	 *
	 * @throws IOException if a directory cannot be listed or created, or a file cannot be read or written
	 */
	void filterSpecialWords() throws IOException{

		String fileDir = "E:\\DataMiningSample\\processedSample";
		ComputeWordsVector cwv = new ComputeWordsVector();

		SortedMap<String, Double> wordMap = cwv.countWords(fileDir, new TreeMap<String, Double>());
		cwv.printWordMap(wordMap); // dump the dictionary to a file

		for(File categoryDir : listChildren(new File(fileDir))){

			File targetDirFile = new File("E:/DataMiningSample/processedSampleOnlySpecial/"+categoryDir.getName());
			ensureDirectory(targetDirFile);

			for(File doc : listChildren(categoryDir)){

				File target = new File(targetDirFile, doc.getName());
				try(BufferedReader samBR = new BufferedReader(new FileReader(doc));
						FileWriter tgWriter = new FileWriter(target)){
					String word;
					while((word = samBR.readLine()) != null){
						if(wordMap.containsKey(word))
							tgWriter.append(word+"\n");
					}
				}
			}
		}
	}

	/**
	 * Copies every document of the corpus into either the test or the training tree
	 * and records the true category of each test document.
	 * <p>
	 * Within each category with {@code n} documents, document {@code j} is a test
	 * sample when {@code testBegin <= j < testEnd}, with
	 * {@code testBegin = indexOfSample * n * (1 - trainSamplePercent)} and
	 * {@code testEnd = (indexOfSample + 1) * n * (1 - trainSamplePercent)}. The range
	 * is half-open so that consecutive values of {@code indexOfSample} yield
	 * disjoint test sets that together cover the category.
	 * <p>
	 * Test documents go to {@code E:/DataMiningSample/TestSample<indexOfSample>/<category>},
	 * all others to {@code E:/DataMiningSample/TrainSample<indexOfSample>/<category>};
	 * missing directories are created. For every test document a line
	 * {@code <file name> <category>} is written to {@code classifyResultFile}.
	 *
	 * @param fileDir directory of the corpus to split, e.g. E:\DataMiningSample\processedSampleOnlySpecial\
	 * @param trainSamplePercent share of each category used for training, e.g. 0.8
	 * @param indexOfSample number of the test fold, e.g. 1
	 * @param classifyResultFile file that receives the true category of every test document
	 * @throws IOException if a directory cannot be listed or created, or a file cannot be read or written
	 */
	void createTestSample(String fileDir,double trainSamplePercent,int indexOfSample,String classifyResultFile) throws IOException{

		File[] sampleDir = listChildren(new File(fileDir));

		try(FileWriter crWriter = new FileWriter(classifyResultFile)){

			for(File categoryDir : sampleDir){

				String category = categoryDir.getName();
				File[] sample = listChildren(categoryDir);
				double testBeginIndex = indexOfSample*(sample.length*(1-trainSamplePercent));
				double testEndIndex = (indexOfSample + 1)*(sample.length*(1-trainSamplePercent));

				File testDir = new File("E:/DataMiningSample/TestSample"+indexOfSample+"/"+category);
				File trainDir = new File("E:/DataMiningSample/TrainSample"+indexOfSample+"/"+category);

				for(int j = 0;j<sample.length;j++){

					String fileShortName = sample[j].getName();
					File targetDir;

					if(j >= testBeginIndex && j < testEndIndex){
						targetDir = testDir;
						crWriter.append(fileShortName + " "+category+"\n");
					}else{
						targetDir = trainDir;
					}

					// mkdirs(): the TestSample<n>/TrainSample<n> parent does not exist on the first run.
					ensureDirectory(targetDir);
					copyLines(sample[j], new File(targetDir, fileShortName));
				}
			}
		}
	}

	/** Copies {@code source} to {@code target} line by line, terminating every line with '\n'. */
	private static void copyLines(File source, File target) throws IOException {

		try(BufferedReader samBR = new BufferedReader(new FileReader(source));
				FileWriter tsWriter = new FileWriter(target)){
			String word;
			while((word = samBR.readLine()) != null)
				tsWriter.append(word+"\n");
		}
	}

	/** Creates {@code dir} and any missing parents; fails loudly when that is not possible. */
	private static void ensureDirectory(File dir) throws IOException {

		if(!dir.isDirectory() && !dir.mkdirs()){
			throw new IOException("Cannot create directory: " + dir);
		}
	}

	/**
	 * Lists a directory, sorted by path so that the index-based split is
	 * reproducible ({@code File.listFiles()} guarantees no order).
	 *
	 * @throws IOException if {@code dir} is not a directory or cannot be read
	 */
	private static File[] listChildren(File dir) throws IOException {

		File[] children = dir.listFiles();
		if(children == null){
			throw new IOException("Not a readable directory: " + dir);
		}
		Arrays.sort(children);
		return children;
	}

	/**
	 * Splits the feature-word corpus at the hard-coded location into training and
	 * test set for fold 1 with 80% training data.
	 *
	 * @param args unused
	 * @throws IOException if the split fails
	 */
	public static void main(String[] args) throws IOException {

		CreateTrainAndTestSample test = new CreateTrainAndTestSample();

		String fileDir = "E:/DataMiningSample/processedSampleOnlySpecial";
		double trainSamplePercent=0.8;
		int indexOfSample=1;
		String classifyResultFile="E:/DataMiningSample/classifyResult";

		test.createTestSample(fileDir, trainSamplePercent, indexOfSample, classifyResultFile);
	}

}

6、朴素贝叶斯算法描述和实现

根据朴素贝叶斯公式,某个测试样例属于某个类别c的概率 = 该测试样例所包含的各特征词的类条件概率P(tk|c)之积 * 先验概率P(c)
在具体计算类条件概率和先验概率时,朴素贝叶斯分类器有两种模型:
(1)多元分布模型( multinomial model )  –以单词为粒度,也就是说,考虑每个文件里面重复出现多次的单词。注意多项分布其实是从二项分布拓展出来的,如果采用多项分布模型,那么每个单词表示变量就不再是二值变量(出现/不出现),而是每个单词在文件中出现的次数
类条件概率P(tk|c)=(类c下单词tk在各个文档中出现过的次数之和+1)/(类c下单词总数+训练样本中不重复特征词总数)
先验概率P(c)=类c下的单词总数/整个训练样本的单词总数
(2)伯努利模型(Bernoulli model) –以文件为粒度,或者说是采用二项分布模型,伯努利实验即N次独立重复随机实验,只考虑事件发生/不发生,所以每个单词的表示变量是布尔型的
类条件概率P(tk|c)=(类c下包含单词tk的文件数+1)/(类c下文件总数+2)
先验概率P(c)=类c下文件总数/整个训练样本的文件总数
本分类器选用多元分布模型计算,根据《Introduction to Information Retrieval 》,多元分布模型计算准确率更高
贝叶斯算法的实现有以下注意点:
       (1) 计算概率用到了BigDecimal类实现任意精度计算
       (2) 用交叉验证法做十次分类实验,对准确率取平均值
       (3) 根据正确类目文件和分类结果文件计算混淆矩阵并且输出
       (4) Map<String,Double> cateWordsProb key为“类目_单词”, value为该类目下该单词的出现次数,避免重复计算


package com.datamine.NaiveBayes;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.Vector;


/**
 * Classifies the newsgroup documents with a multinomial Naive Bayes model and
 * averages the accuracy over ten cross-validation experiments.
 * <p>
 * Class-conditional probability as implemented here:
 * {@code P(tk|c) = (occurrences of tk in class c + 0.0001) /
 * (number of words in class c + number of words in the whole training set)}.
 * Prior: {@code P(c) = words in class c / words in the whole training set}.
 *
 * @author Administrator
 */
public class NaiveBayesianClassifier {

	/** Smoothing term added to every word count; built from a string so the value is exact. */
	private static final BigDecimal SMOOTHING = new BigDecimal("0.0001");

	/** Number of decimal places kept by every division. */
	private static final int PROBABILITY_SCALE = 10;

	/**
	 * Classifies every document below {@code testDir} and writes one line
	 * {@code <file name> <predicted category>} per document to the result file.
	 *
	 * @param trainDir training set directory ({@code <trainDir>/<category>/<document>})
	 * @param testDir test set directory, same layout
	 * @param classifyResultFileNew path of the classification result file to write
	 * @throws Exception if a directory cannot be listed or a file cannot be read or written
	 */
	private void doProcess(String trainDir,String testDir,
			String classifyResultFileNew) throws Exception{

		// <category, number of words in the category>
		Map<String,Double> cateWordsNum = getCateWordsNum(trainDir);
		// <category_word, occurrences of the word in the category>
		Map<String,Double> cateWordsProb = getCateWordsProb(trainDir);

		double totalWordsNum = 0.0; // number of words in the whole training set
		for(Double wordsInCate : cateWordsNum.values()){
			totalWordsNum += wordsInCate;
		}

		// The category list does not change between test documents: list it once.
		File[] trainDirFiles = listChildren(new File(trainDir));
		File[] testDirFiles = listChildren(new File(testDir));
		List<String> testFileWords = new ArrayList<String>(); // all words of the current test document

		try(FileWriter crWriter = new FileWriter(classifyResultFileNew)){

			for(File testCategoryDir : testDirFiles){
				for(File testSample : listChildren(testCategoryDir)){

					testFileWords.clear();
					try(BufferedReader spBR = new BufferedReader(new FileReader(testSample))){
						String word;
						while((word = spBR.readLine()) != null){
							testFileWords.add(word);
						}
					}

					// Score the document against every category and keep the best one;
					// on a tie the category listed first wins.
					BigDecimal maxP = null;
					String bestCate = null;
					for(File trainCategoryDir : trainDirFiles){
						BigDecimal p = computeCateProb(trainCategoryDir,testFileWords,cateWordsNum,totalWordsNum,cateWordsProb);
						if(maxP == null || p.compareTo(maxP) > 0){
							maxP = p;
							bestCate = trainCategoryDir.getName();
						}
					}

					crWriter.append(testSample.getName() + " " + bestCate + "\n");
					crWriter.flush();
				}
			}
		}
	}

	/**
	 * Computes the unnormalised probability that a test document belongs to one
	 * category: the product of {@code P(tk|c)} over all words of the document,
	 * multiplied by the prior {@code P(c)}.
	 * <p>
	 * {@code P(tk|c) = (count of tk in c + 0.0001) / (words in c + words in training set)},
	 * rounded up to 10 decimal places.
	 *
	 * @param trainFile directory of the category; only its name is used
	 * @param testFileWords all words of the test document
	 * @param cateWordsNum number of words per category
	 * @param totalWordsNum number of words in the whole training set
	 * @param cateWordsProb occurrences per "category_word" key
	 * @return the score of the test document for this category
	 */
	private BigDecimal computeCateProb(File trainFile, List<String> testFileWords,
			Map<String, Double> cateWordsNum, double totalWordsNum, Map<String, Double> cateWordsProb) {

		String category = trainFile.getName();
		// The counts are integer-valued doubles, so BigDecimal(double) is exact here.
		BigDecimal wordNumInCateBD = new BigDecimal(cateWordsNum.get(category));
		BigDecimal totalWordsNumBD = new BigDecimal(totalWordsNum);
		// Same denominator for every word of the document.
		BigDecimal denominator = wordNumInCateBD.add(totalWordsNumBD);

		BigDecimal probability = BigDecimal.ONE;
		for(String word : testFileWords){

			Double seen = cateWordsProb.get(category+"_"+word);
			BigDecimal testFileWordNumInCateBD = new BigDecimal(seen == null ? 0.0 : seen);

			BigDecimal xcProb = testFileWordNumInCateBD.add(SMOOTHING)
					.divide(denominator, PROBABILITY_SCALE, RoundingMode.CEILING);
			probability = probability.multiply(xcProb);
		}

		// P = product of P(tk|c) * P(c)
		BigDecimal prior = wordNumInCateBD.divide(totalWordsNumBD, PROBABILITY_SCALE, RoundingMode.CEILING);
		return probability.multiply(prior);
	}

	/**
	 * Counts how often every word occurs in every category of the training set.
	 *
	 * @param trainDir training set directory
	 * @return map from "category_word" to the number of occurrences of the word in that category
	 * @throws Exception if a directory cannot be listed or a file cannot be read
	 */
	private Map<String, Double> getCateWordsProb(String trainDir) throws Exception {

		Map<String,Double> cateWordsProb = new TreeMap<String, Double>();

		for(File categoryDir : listChildren(new File(trainDir))){
			for(File sample : listChildren(categoryDir)){

				try(BufferedReader samBR = new BufferedReader(new FileReader(sample))){
					String word;
					while((word = samBR.readLine()) != null){
						String key = categoryDir.getName()+"_"+word;
						Double seen = cateWordsProb.get(key);
						cateWordsProb.put(key, seen == null ? 1.0 : seen + 1);
					}
				}
			}
		}

		return cateWordsProb;
	}

	/**
	 * Counts the words (lines) of every category of the training set.
	 *
	 * @param trainDir training set directory
	 * @return map from category name to the number of words in that category
	 * @throws IOException if a directory cannot be listed or a file cannot be read
	 */
	private Map<String, Double> getCateWordsNum(String trainDir) throws IOException {

		Map<String, Double> cateWordsNum = new TreeMap<String, Double>();

		for(File categoryDir : listChildren(new File(trainDir))){

			double count = 0;
			for(File sample : listChildren(categoryDir)){
				try(BufferedReader spBR = new BufferedReader(new FileReader(sample))){
					while(spBR.readLine() != null){
						count++;
					}
				}
			}
			cateWordsNum.put(categoryDir.getName(), count);
		}

		return cateWordsNum;
	}

	/**
	 * Computes the accuracy from the file of true categories and the file of
	 * predicted categories, and prints the confusion matrix.
	 *
	 * @param classifyRightCate file with lines {@code <file name> <true category>}
	 * @param classifyResultFileNew file with lines {@code <file name> <predicted category>}
	 * @return correctly classified documents divided by the number of predictions
	 *         (NaN when there are no predictions)
	 * @throws Exception if one of the files cannot be read
	 */
	public double computeAccuracy(String classifyRightCate,
			String classifyResultFileNew) throws Exception {

		Map<String,String> rightCate = getMapFromResultFile(classifyRightCate);
		Map<String,String> resultCate = getMapFromResultFile(classifyResultFileNew);

		double rightCount = 0.0;
		for(Map.Entry<String, String> me : resultCate.entrySet()){
			if(me.getValue().equals(rightCate.get(me.getKey())))
				rightCount++;
		}

		computerConfusionMatrix(rightCate,resultCate);

		return rightCount / resultCate.size();
	}

	/**
	 * Prints the confusion matrix (rows: true category, columns: predicted
	 * category, both indexed in sorted category-name order) followed by the
	 * per-row accuracy.
	 * <p>
	 * The matrix is sized from the categories that actually occur in either map.
	 * Documents without a recorded prediction are left out of the matrix.
	 *
	 * @param rightCate map from file name to true category
	 * @param resultCate map from file name to predicted category
	 */
	private void computerConfusionMatrix(Map<String, String> rightCate,
			Map<String, String> resultCate) {

		// Index the categories in sorted order; include predicted categories
		// that never occur as a true category.
		SortedSet<String> cateNames = new TreeSet<String>(rightCate.values());
		cateNames.addAll(resultCate.values());
		Map<String,Integer> cateNamesToIndex = new TreeMap<String, Integer>();
		for(String cateName : cateNames){
			cateNamesToIndex.put(cateName, cateNamesToIndex.size());
		}

		int cateCount = cateNamesToIndex.size();
		int[][] confusionMatrix = new int[cateCount][cateCount];

		for(Map.Entry<String, String> me : rightCate.entrySet()){
			String predicted = resultCate.get(me.getKey());
			if(predicted == null){
				continue;
			}
			confusionMatrix[cateNamesToIndex.get(me.getValue())][cateNamesToIndex.get(predicted)]++;
		}

		// Print the matrix.
		double[] hangSum = new double[cateCount];
		System.out.print("      ");

		for(int i=0;i<cateCount;i++){
			System.out.printf("%-6d",i);
		}

		System.out.println("准确率");

		for(int i =0;i<cateCount;i++){
			System.out.printf("%-6d",i);
			for(int j = 0;j<cateCount;j++){
				System.out.printf("%-6d",confusionMatrix[i][j]);
				hangSum[i] += confusionMatrix[i][j];
			}
			System.out.printf("%-6f\n",confusionMatrix[i][i]/hangSum[i]);
		}
		System.out.println();

	}

	/**
	 * Reads a result file with lines {@code <file name> <category>} into a map.
	 * Lines with fewer than two blank-separated fields are skipped. If a file name
	 * occurs more than once, the last line wins.
	 *
	 * @param file path of the result file
	 * @return map from file name to category
	 * @throws Exception if the file cannot be read
	 */
	private Map<String, String> getMapFromResultFile(String file) throws Exception {

		Map<String,String> res = new TreeMap<String, String>();

		try(BufferedReader crBR = new BufferedReader(new FileReader(new File(file)))){
			String line;
			while((line = crBR.readLine()) != null){
				String[] s = line.split(" ");
				if(s.length < 2){
					continue;
				}
				res.put(s[0], s[1]);
			}
		}
		return res;
	}

	/**
	 * Lists a directory.
	 *
	 * @throws IOException if {@code dir} is not a directory or cannot be read
	 */
	private static File[] listChildren(File dir) throws IOException {

		File[] children = dir.listFiles();
		if(children == null){
			throw new IOException("Not a readable directory: " + dir);
		}
		return children;
	}

	/**
	 * Runs ten cross-validation experiments and prints the accuracy of each one
	 * and their average.
	 *
	 * @param args unused
	 * @throws Exception if any step of an experiment fails
	 */
	public static void main(String[] args) throws Exception {

		CreateTrainAndTestSample ctt = new CreateTrainAndTestSample();
		NaiveBayesianClassifier nbClassifier = new NaiveBayesianClassifier();

		// Produce the corpus that contains feature words only (processedSampleOnlySpecial).
		ctt.filterSpecialWords();

		final int experimentCount = 10;
		// Ten folds need a 10% test slice each; with a 0.8 training share only
		// folds 0-4 would receive any test documents.
		final double trainSamplePercent = 0.9;

		double[] accuracyOfEveryExp  = new double[experimentCount];
		double sum = 0;

		for(int i =0;i<experimentCount;i++){

			String TrainDir = "E:/DataMiningSample/TrainSample"+i;
			String TestDir = "E:/DataMiningSample/TestSample"+i;
			String classifyRightCate = "E:/DataMiningSample/classifyRightCate"+i+".txt";
			String classifyResultFileNew = "E:/DataMiningSample/classifyResultNew"+i+".txt";

			ctt.createTestSample("E:/DataMiningSample/processedSampleOnlySpecial", trainSamplePercent, i, classifyRightCate);

			nbClassifier.doProcess(TrainDir, TestDir, classifyResultFileNew);

			accuracyOfEveryExp[i] = nbClassifier.computeAccuracy(classifyRightCate,classifyResultFileNew);
			sum += accuracyOfEveryExp[i];

			System.out.println("The accuracy for Naive Bayesian Classifier in "+i+"th Exp is :" + accuracyOfEveryExp[i]);

		}

		double accuracyAvg = sum/experimentCount;

		System.out.println("The average accuracy for Naive Bayesian Classifier in all Exps is :" + accuracyAvg);

	}

}

7、实验结果与说明

结果(略)

这里使用的多项式模型是经过改进的计算方法:把类条件概率的计算公式改为 P(tk|c)=(类c下单词tk在各个文档中出现过的次数之和+0.0001)/(类c下单词总数+训练样本的单词总数),与上面代码中的实现一致。也就是说平滑项由1改为0.0001:当tk在类c中没有出现时,分子只有0.0001,这样能更精确地描述词的统计分布规律。

8、算法改进

为了进一步提高朴素贝叶斯算法的分类准确率,可以进行如下改进:

1、优化特征词选取的方法,如方法三,或者更优方法

2、改进多项式模型的类条件概率的计算公式(上面已经实现)


你可能感兴趣的:(文本分类——NaiveBayes)