A previous article covered the theory behind the Naive Bayes algorithm; here the algorithm is applied to classify the Newsgroup document collection.
The code in this article is adapted from: http://blog.csdn.net/jiangliqing1234/article/details/39642757
The main steps are: document preprocessing, feature-word selection, splitting the collection into training and test sets, and Naive Bayes classification.
Dataset download: http://download.csdn.net/detail/hjy321686/8057761. The collection contains 20 different newsgroups; apart from a small number of documents that belong to several newsgroups, every document belongs to exactly one newsgroup.
Before the documents can be classified they have to be preprocessed. The main preprocessing steps are:
Step 1: English lexical analysis — remove digits, hyphens, punctuation and special characters, and convert all upper-case letters to lower case. A regular expression can be used: String res[] = line.split("[^a-zA-Z]");
Step 2: Remove stop words, i.e. words that carry no value for classification.
Step 3: Stemming, based on the Porter algorithm.
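A small illustration of step 1 (a standalone sketch, not part of the preprocessing class below): splitting on non-letters also produces empty tokens, which have to be skipped before stop-word removal and stemming.
String[] res = "Hello, World-123".split("[^a-zA-Z]");
// res == {"Hello", "", "World"}: the digits are dropped and an empty token appears
for (String token : res) {
    if (!token.isEmpty())
        System.out.println(token.toLowerCase()); // hello, world
}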
The preprocessing class is as follows:
package com.datamine.NaiveBayes;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.util.ArrayList;
/**
* Newsgroup document preprocessing
* step1: English lexical analysis — remove digits, hyphens, punctuation and special characters, convert all upper-case letters to lower case; a regular expression can be used: String res[] = line.split("[^a-zA-Z]");
* step2: remove stop words, i.e. words that carry no value for classification
* step3: stemming, based on the Porter algorithm
* @author Administrator
*
*/
public class DataPreProcess {
private static ArrayList<String> stopWordsArray = new ArrayList<String>();
/**
* Preprocess all files in the given directory
* @param srcDir absolute path of the source directory (the cleaned files are written to a processedSample directory derived from it)
* @throws Exception
*/
public void doProcess(String srcDir) throws Exception{
File fileDir = new File(srcDir);
if(!fileDir.exists()){
System.out.println("The source directory does not exist!");
return ;
}
String subStrDir = srcDir.substring(srcDir.lastIndexOf('/'));
String dirTarget = srcDir+"/../../processedSample"+subStrDir;
File fileTarget = new File(dirTarget);
if(!fileTarget.exists()){
//note: mkdirs() also creates the missing parent (processedSample) directory; otherwise creating the target directory would fail
boolean mkdir = fileTarget.mkdirs();
}
File[] srcFiles = fileDir.listFiles();
for(int i = 0; i < srcFiles.length; i++){
String fileShortName = srcFiles[i].getName();
StringBuilder stringBuilder = new StringBuilder();
BufferedReader srcFileBR = new BufferedReader(new FileReader(srcFiles[i]));
String line;
while((line = srcFileBR.readLine()) != null){
//step1: keep letters only and convert everything to lower case
String[] res = line.split("[^a-zA-Z]");
for(int j = 0; j < res.length; j++){
String lowerWord = res[j].toLowerCase();
//step2: skip empty tokens and stop words
if(lowerWord.isEmpty() || stopWordsArray.contains(lowerWord))
continue;
//step3: Porter stemming, one word per output line
stringBuilder.append(run(lowerWord)).append("\n");
}
}
srcFileBR.close();
FileWriter resWriter = new FileWriter(dirTarget + "/" + fileShortName);
resWriter.write(stringBuilder.toString());
resWriter.close();
}
}
/**
* Stemming interface: reduces the given word to its stem using the Stemmer class
* @param word the input word
* @return the stemmed word
*/
public static String run(String word){
Stemmer s = new Stemmer();
char[] ch = word.toCharArray();
for (int c = 0; c < ch.length; c++)
s.add(ch[c]);
s.stem();
String result = s.toString();
//System.out.print(result);
return result;
}
}
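Assuming the Stemmer referenced above is the standard Porter stemmer reference implementation (its source is not listed in this article), the run helper can be checked quickly like this (expected outputs shown as comments):
System.out.println(DataPreProcess.run("running"));   // run
System.out.println(DataPreProcess.run("connected")); // connect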
There are several ways to select the feature words:
Method 1: keep every word as a feature word.
Method 2: keep only the words whose total frequency exceeds some threshold (3, or another value).
Method 3: compute each word's tf*idf weight and select the feature words by weight.
Method 2 is used in this article.
Since feature selection uses Method 2, the documents do not strictly need to be vectorized. However, the method that counts feature-word occurrences lives in the ComputeWordsVector class, so the document-vectorization code is listed here as well to keep the program runnable; the same class is also needed later for the KNN algorithm.
package com.datamine.NaiveBayes;
import java.io.*;
import java.util.*;
/**
* 计算文档的属性向量,将所有文档向量化
* @author Administrator
*/
public class ComputeWordsVector {
/**
* Compute the TF vector of every document (TFPerDocMap)
* and the TF*IDF weights
* @param strDir absolute path of the preprocessed newsgroup directory
* @param trainSamplePercent proportion of each category used as training samples
* @param indexOfSample index of the test window inside each category; this parameter splits the documents into a training part and a test part
* @param iDFPerWordMap IDF value of every word
* @param wordMap the feature dictionary
* @throws IOException
*/
public void computeTFMultiIDF(String strDir,double trainSamplePercent,int indexOfSample,
Map<String, Double> iDFPerWordMap, Map<String, Double> wordMap) throws IOException{
File fileDir = new File(strDir);
String word;
SortedMap<String, Double> TFPerDocMap = new TreeMap<String, Double>();
//two writers are used: one dedicated to the test samples and one to the training samples
String trainFileDir = "E:/DataMiningSample/docVector/wordTFIDFMapTrainSample"+indexOfSample;
String testFileDir = "E:/DataMiningSample/docVector/wordTFIDFMapTestSample"+indexOfSample;
FileWriter tsTrainWriter = new FileWriter(new File(trainFileDir)); //writes the training samples
FileWriter tsTestWriter = new FileWriter(new File(testFileDir)); //writes the test samples
FileWriter tsWriter = null;
File[] sampleDir = fileDir.listFiles();
for(int i = 0; i < sampleDir.length; i++){
String cateShortName = sampleDir[i].getName();
File[] sample = sampleDir[i].listFiles();
//files whose index falls in [testBeginIndex, testEndIndex] form the test set of this split, the rest form the training set
double testBeginIndex = indexOfSample * (sample.length * (1 - trainSamplePercent));
double testEndIndex = (indexOfSample + 1) * (sample.length * (1 - trainSamplePercent));
for(int j = 0; j < sample.length; j++){
String fileShortName = sample[j].getName();
TFPerDocMap.clear();
double wordSumPerDoc = 0; //number of feature words in this document
BufferedReader samBR = new BufferedReader(new FileReader(sample[j]));
while((word = samBR.readLine()) != null){
if(!word.isEmpty() && wordMap.containsKey(word)){
wordSumPerDoc++;
if(TFPerDocMap.containsKey(word))
TFPerDocMap.put(word, TFPerDocMap.get(word) + 1);
else
TFPerDocMap.put(word, 1.0);
}
}
samBR.close();
if(j >= testBeginIndex && j <= testEndIndex)
tsWriter = tsTestWriter;
else
tsWriter = tsTrainWriter;
Double wordWeight;
Set<Map.Entry<String, Double>> tempTF = TFPerDocMap.entrySet();
for(Iterator<Map.Entry<String, Double>> mt = tempTF.iterator();mt.hasNext();){
Map.Entry<String, Double> me = mt.next();
//computing IDF is very time-consuming (roughly 25 hours for a dictionary of 30,000+ words), so for now every word's IDF is taken to be 1
//wordWeight = (me.getValue() / wordSumPerDoc) * iDFPerWordMap.get(me.getKey());
wordWeight = (me.getValue() / wordSumPerDoc) * 1.0;
TFPerDocMap.put(me.getKey(), wordWeight);
}
tsWriter.append(cateShortName + " ");
tsWriter.append(fileShortName + " ");
Set<Map.Entry<String, Double>> tempTF2 = TFPerDocMap.entrySet();
for(Iterator<Map.Entry<String, Double>> mt = tempTF2.iterator();mt.hasNext();){
Map.Entry<String, Double> me = mt.next();
tsWriter.append(me.getKey() + " " + me.getValue()+" ");
}
tsWriter.append("\n");
tsWriter.flush();
}
}
tsTrainWriter.close();
tsTestWriter.close();
tsWriter.close();
}
/**
* Count the total number of occurrences of every word and return the words that occur at least 3 times (count > 2) as the final feature dictionary
* @param strDir absolute path of the preprocessed newsgroup directory
* @param wordMap map that accumulates the count of every word encountered
* @return newWordMap the final feature dictionary, containing only the words that occur at least 3 times
* @throws IOException
*/
public SortedMap<String, Double> countWords(String strDir,
Map<String, Double> wordMap) throws IOException {
File sampleFile = new File(strDir);
File[] sample = sampleFile.listFiles();
String word;
for(int i =0 ;i < sample.length;i++){
if(!sample[i].isDirectory()){
FileReader samReader = new FileReader(sample[i]);
BufferedReader samBR = new BufferedReader(samReader);
while((word = samBR.readLine()) != null){
if(!word.isEmpty() && wordMap.containsKey(word))
wordMap.put(word, wordMap.get(word)+1);
else
wordMap.put(word, 1.0);
}
samBR.close();
}else{
countWords(sample[i].getCanonicalPath(),wordMap);
}
}
/*
* Only return the words that occur at least 3 times (count > 2).
* For clarity this filtering should really be a separate function, so it is not re-run on every recursive call.
*/
SortedMap<String, Double> newWordMap = new TreeMap<String, Double>();
Set<Map.Entry<String, Double>> allWords = wordMap.entrySet();
for(Iterator<Map.Entry<String, Double>> it = allWords.iterator();it.hasNext();){
Map.Entry<String, Double> me = it.next();
if(me.getValue() > 2)
newWordMap.put(me.getKey(), me.getValue());
}
System.out.println("newWordMap "+ newWordMap.size());
return newWordMap;
}
/**
* Print the feature dictionary to allDicWordCountMap.txt
* @param wordMap the feature dictionary
* @throws IOException
*/
public void printWordMap(Map<String, Double> wordMap) throws IOException{
System.out.println("printWordMap:");
int countLine = 0;
File outPutFile = new File("E:/DataMiningSample/docVector/allDicWordCountMap.txt");
FileWriter outPutFileWriter = new FileWriter(outPutFile);
Set<Map.Entry<String, Double>> allWords = wordMap.entrySet();
for(Iterator<Map.Entry<String, Double>> it = allWords.iterator();it.hasNext();){
Map.Entry<String, Double> me = it.next();
outPutFileWriter.write(me.getKey()+" "+me.getValue()+"\n");
countLine++;
}
outPutFileWriter.close();
System.out.println("WordMap size : " + countLine);
}
/**
* The inverse document frequency (IDF) of a word w over the whole collection is the logarithm of
* the ratio between the total number of documents n and the number of documents containing w: idf = log(n / docs(w, D))
* Compute the IDF, i.e. determine in how many documents each dictionary word occurs
* @param strDir absolute path of the preprocessed newsgroup directory
* @param wordMap the feature dictionary
* @return the IDF map of the words
* @throws IOException
*/
public SortedMap<String, Double> computeIDF(String strDir, Map<String, Double> wordMap) throws IOException{
File fileDir = new File(strDir);
String word;
SortedMap<String, Double> IDFPerWordMap = new TreeMap<String, Double>();
Set<Map.Entry<String, Double>> wordMapSet = wordMap.entrySet();
for(Iterator<Map.Entry<String, Double>> it = wordMapSet.iterator();it.hasNext();){
Map.Entry<String, Double> pe = it.next();
Double countDoc = 0.0; //number of documents containing the dictionary word
Double sumDoc = 0.0; //total number of documents
String dicWord = pe.getKey();
File[] sampleDir = fileDir.listFiles();
for(int i = 0; i < sampleDir.length; i++){
File[] sample = sampleDir[i].listFiles();
for(int j = 0; j < sample.length; j++){
sumDoc++; //count every document
BufferedReader samBR = new BufferedReader(new FileReader(sample[j]));
while((word = samBR.readLine()) != null){
if(word.equals(dicWord)){
countDoc++; //the dictionary word occurs in this document
break;
}
}
samBR.close();
}
}
//idf = log(n / docs(w, D)); a base-10 logarithm is used here, see test() below
Double IDF = Math.log(sumDoc / countDoc) / Math.log(10);
IDFPerWordMap.put(dicWord, IDF);
}
return IDFPerWordMap;
}
public static void main(String[] args) throws IOException {
ComputeWordsVector wordsVector = new ComputeWordsVector();
//directory holding the preprocessed newsgroup documents
String strDir = "E:/DataMiningSample/processedSample";
SortedMap<String, Double> wordMap = new TreeMap<String, Double>();
//the feature dictionary
Map<String, Double> newWordMap = new TreeMap<String, Double>();
newWordMap = wordsVector.countWords(strDir,wordMap);
//wordsVector.printWordMap(newWordMap);
//wordsVector.computeIDF(strDir, newWordMap);
double trainSamplePercent = 0.8;
int indexOfSample = 1;
Map<String, Double> iDFPerWordMap = null; //IDF is taken to be 1 inside computeTFMultiIDF, so no real IDF map is needed here
wordsVector.computeTFMultiIDF(strDir, trainSamplePercent, indexOfSample, iDFPerWordMap, newWordMap);
//test();
}
public static void test(){
double sumDoc = 18828.0;
double countDoc = 229.0;
double IDF1 = Math.log(sumDoc / countDoc) / Math.log(10);
double IDF2 = Math.log(sumDoc / countDoc) ;
System.out.println(IDF1);
System.out.println(IDF2);
System.out.println(Math.log(10));
}
}
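For reference, every line that computeTFMultiIDF writes to the wordTFIDFMapTrainSample / wordTFIDFMapTestSample files describes one document in the form "category fileName word1 weight1 word2 weight2 ...", for example (illustrative values only):
alt.atheism 49960 atheism 0.0213 belief 0.0071 god 0.0142 ...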
The whole document collection is split, per category, into a training set and a test set according to a given proportion (0.9 or 0.8):
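As a worked illustration (assuming the test window of split n covers the file indices from n * fileNum * (1 - trainSamplePercent) to (n + 1) * fileNum * (1 - trainSamplePercent) within each category): with trainSamplePercent = 0.8, a category of 1000 documents and indexOfSample = 1, documents 200 to 400 are written to TestSample1 and the remaining roughly 800 documents to TrainSample1.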
package com.datamine.NaiveBayes;
import java.io.*;
import java.util.*;
public class CreateTrainAndTestSample {
void filterSpecialWords() throws IOException{
String word;
ComputeWordsVector cwv = new ComputeWordsVector();
String fileDir = "E:\\DataMiningSample\\processedSample";
SortedMap<String, Double> wordMap = new TreeMap<String, Double>();
wordMap = cwv.countWords(fileDir, wordMap);
cwv.printWordMap(wordMap); //write wordMap out to a file
File[] sampleDir = new File(fileDir).listFiles();
for(int i = 0; i < sampleDir.length; i++){
//for every document keep only the words that appear in the feature dictionary
String targetDir = "E:/DataMiningSample/processedSampleOnlySpecial/" + sampleDir[i].getName();
File targetDirFile = new File(targetDir);
if(!targetDirFile.exists()){
targetDirFile.mkdirs();
}
File[] sample = sampleDir[i].listFiles();
for(int j = 0; j < sample.length; j++){
BufferedReader samBR = new BufferedReader(new FileReader(sample[j]));
FileWriter tgWriter = new FileWriter(targetDir + "/" + sample[j].getName());
while((word = samBR.readLine()) != null){
if(wordMap.containsKey(word))
tgWriter.append(word + "\n");
}
tgWriter.flush();
tgWriter.close();
samBR.close();
}
}
}
/**
* Split the document collection into a training set and a test set by the given proportion
* @param fileDir directory of the documents that contain only feature words
* @param trainSamplePercent proportion of each category used for training
* @param indexOfSample index of this train/test split
* @param classifyResultFile file that records the correct category of every test document
* @throws IOException
*/
void createTestSample(String fileDir, double trainSamplePercent,
int indexOfSample, String classifyResultFile) throws IOException{
String word;
FileWriter crWriter = new FileWriter(classifyResultFile); //records the test file name and its correct category
File[] sampleDir = new File(fileDir).listFiles();
for(int i = 0; i < sampleDir.length; i++){
File[] sample = sampleDir[i].listFiles();
double testBeginIndex = indexOfSample * (sample.length * (1 - trainSamplePercent));
double testEndIndex = (indexOfSample + 1) * (sample.length * (1 - trainSamplePercent));
for(int j = 0; j < sample.length; j++){
String subFileName = sample[j].getName();
BufferedReader samBR = new BufferedReader(new FileReader(sample[j]));
String targetDir;
if(j > testBeginIndex && j < testEndIndex){
targetDir = "E:/DataMiningSample/TestSample"+indexOfSample+"/"+sampleDir[i].getName();
crWriter.append(subFileName + " "+sampleDir[i].getName()+"\n");
}else{
targetDir = "E:/DataMiningSample/TrainSample"+indexOfSample+"/"+sampleDir[i].getName();
}
targetDir = targetDir.replace("\\", "/");
File trainSamFile = new File(targetDir);
if(!trainSamFile.exists()){
trainSamFile.mkdirs(); //mkdirs() also creates the TrainSample/TestSample parent directory when it does not exist yet
}
targetDir += "/" + subFileName;
FileWriter tsWriter = new FileWriter(new File(targetDir));
while((word = samBR.readLine()) != null)
tsWriter.append(word+"\n");
tsWriter.flush();
tsWriter.close();
samBR.close();
}
}
crWriter.close();
}
public static void main(String[] args) throws IOException {
CreateTrainAndTestSample test = new CreateTrainAndTestSample();
String fileDir = "E:/DataMiningSample/processedSampleOnlySpecial";
double trainSamplePercent=0.8;
int indexOfSample=1;
String classifyResultFile="E:/DataMiningSample/classifyResult";
test.createTestSample(fileDir, trainSamplePercent, indexOfSample, classifyResultFile);
//test.filterSpecialWords();
}
}
package com.datamine.NaiveBayes;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.Vector;
/**
* Classify the newsgroup document collection with the Naive Bayes algorithm; ten train/test splits are evaluated and the accuracies averaged
* The multinomial model is used
* Class-conditional probability P(tk|c) = (total number of occurrences of word tk in documents of class c + 1) / (total number of words in class c + total number of words in the training set)
* @author Administrator
*/
public class NaiveBayesianClassifier {
/**
* Classify the test document set with the Naive Bayes algorithm
* @param trainDir training document directory
* @param testDir test document directory
* @param classifyResultFileNew path of the classification result file
* @throws Exception
*/
private void doProcess(String trainDir,String testDir,
String classifyResultFileNew) throws Exception{
//total word count of every category in the training set <category directory name, word count>
Map<String, Double> cateWordsNum = new TreeMap<String, Double>();
//occurrence count of every feature word in every category <category_word, count>
Map<String, Double> cateWordsProb = new TreeMap<String, Double>();
cateWordsNum = getCateWordsNum(trainDir);
cateWordsProb = getCateWordsProb(trainDir);
double totalWordsNum = 0.0; //total number of words in the whole training set
Set<Map.Entry<String, Double>> cateWordsNumSet = cateWordsNum.entrySet();
for(Iterator<Map.Entry<String, Double>> it = cateWordsNumSet.iterator();it.hasNext();){
Map.Entry<String, Double> me = it.next();
totalWordsNum += me.getValue();
}
//read the test samples and classify them
Vector<String> testFileWords = new Vector<String>(); //holds all words of one test document
String word;
File[] testDirFiles = new File(testDir).listFiles();
FileWriter crWriter = new FileWriter(classifyResultFileNew);
for(int i = 0; i < testDirFiles.length; i++){
File[] testSample = testDirFiles[i].listFiles();
for(int j = 0; j < testSample.length; j++){
testFileWords.clear();
BufferedReader testBR = new BufferedReader(new FileReader(testSample[j]));
while((word = testBR.readLine()) != null){
testFileWords.add(word);
}
testBR.close();
//compute the probability of the document under every training category and pick the category with the largest probability
File[] trainDirFiles = new File(trainDir).listFiles();
BigDecimal maxP = new BigDecimal(0);
String bestCate = null;
for(int k = 0; k < trainDirFiles.length; k++){
BigDecimal p = computeCateProb(trainDirFiles[k], testFileWords,
cateWordsNum, totalWordsNum, cateWordsProb);
if(bestCate == null || p.compareTo(maxP) > 0){
maxP = p;
bestCate = trainDirFiles[k].getName();
}
}
crWriter.append(testSample[j].getName() + " " + bestCate + "\n");
crWriter.flush();
}
}
crWriter.close();
}
/**
* Compute the probability of a test document under one category
* @param trainFile directory of one training category
* @param testFileWords all words of the test document
* @param cateWordsNum total word count of every category
* @param totalWordsNum total word count of the whole training set
* @param cateWordsProb occurrence count of every <category_word> pair
* @return the (unnormalized) probability of the document under this category
*/
private BigDecimal computeCateProb(File trainFile, Vector<String> testFileWords,
Map<String, Double> cateWordsNum, double totalWordsNum, Map<String, Double> cateWordsProb) {
BigDecimal probability = new BigDecimal(1);
double wordNumInCate = cateWordsNum.get(trainFile.getName());
BigDecimal wordNumInCateBD = new BigDecimal(wordNumInCate);
BigDecimal totalWordsNumBD = new BigDecimal(totalWordsNum);
for(Iterator<String> it = testFileWords.iterator();it.hasNext();){
String me = it.next();
String key = trainFile.getName()+"_"+me;
double testFileWordNumInCate;
if(cateWordsProb.containsKey(key))
testFileWordNumInCate = cateWordsProb.get(key);
else
testFileWordNumInCate = 0.0;
BigDecimal testFileWordNumInCateBD = new BigDecimal(testFileWordNumInCate);
BigDecimal xcProb = (testFileWordNumInCateBD.add(new BigDecimal(0.0001)))
.divide(wordNumInCateBD.add(totalWordsNumBD), 10, BigDecimal.ROUND_CEILING);
probability = probability.multiply(xcProb);
}
// result = P(d|c) * P(c), with P(c) estimated as (words in class c) / (total words in the training set)
BigDecimal result = probability.multiply(wordNumInCateBD.divide(totalWordsNumBD,10, BigDecimal.ROUND_CEILING));
return result;
}
/**
* Count how often every word occurs within each training category
* @param trainDir training set directory
* @return cateWordsProb map indexed by "category_word"; the value is the number of occurrences of that word in that category
* @throws Exception
*/
private Map<String, Double> getCateWordsProb(String trainDir) throws Exception {
Map<String, Double> cateWordsProb = new TreeMap<String, Double>();
File sampleFile = new File(trainDir);
File[] sampleDir = sampleFile.listFiles();
String word;
for(int i =0;i < sampleDir.length;i++){
File[] sample = sampleDir[i].listFiles();
for(int j = 0; j < sample.length; j++){
BufferedReader samBR = new BufferedReader(new FileReader(sample[j]));
while((word = samBR.readLine()) != null){
String key = sampleDir[i].getName() + "_" + word; //"category_word"
if(cateWordsProb.containsKey(key))
cateWordsProb.put(key, cateWordsProb.get(key) + 1);
else
cateWordsProb.put(key, 1.0);
}
samBR.close();
}
}
return cateWordsProb;
}
/**
* Count the total number of words in every category of the training set
* @param trainDir training set directory
* @return cateWordsNum map of <category name, total word count>
* @throws IOException
*/
private Map<String, Double> getCateWordsNum(String trainDir) throws IOException {
Map<String, Double> cateWordsNum = new TreeMap<String, Double>();
File[] sampleDir = new File(trainDir).listFiles();
for(int i = 0; i < sampleDir.length; i++){
double count = 0;
File[] sample = sampleDir[i].listFiles();
for(int j = 0; j < sample.length; j++){
BufferedReader samBR = new BufferedReader(new FileReader(sample[j]));
while(samBR.readLine() != null){
count++; //one word per line in the preprocessed files
}
samBR.close();
}
cateWordsNum.put(sampleDir[i].getName(), count);
}
return cateWordsNum;
}
/**
* Compute the classification accuracy
* @param classifyRightCate file with the correct category of every test document <file name, category>
* @param classifyResultFileNew classification result file <file name, predicted category>
* @return the classification accuracy
* @throws Exception
*/
public double computeAccuracy(String classifyRightCate,
String classifyResultFileNew) throws Exception {
Map<String, String> rightCate = new TreeMap<String, String>();
Map<String, String> resultCate = new TreeMap<String, String>();
rightCate = getMapFromResultFile(classifyRightCate);
resultCate = getMapFromResultFile(classifyResultFileNew);
Set<Map.Entry<String, String>> resCateSet = resultCate.entrySet();
double rightCount = 0.0;
for(Iterator<Map.Entry<String, String>> it = resCateSet.iterator();it.hasNext();){
Map.Entry<String, String> me = it.next();
if(me.getValue().equals(rightCate.get(me.getKey())))
rightCount++;
}
computerConfusionMatrix(rightCate,resultCate);
return rightCount / resultCate.size();
}
/**
* Compute and print the confusion matrix from the correct categories and the classification results
* @param rightCate map of correct categories
* @param resultCate map of predicted categories
*/
private void computerConfusionMatrix(Map<String, String> rightCate,
Map<String, String> resultCate) {
int[][] confusionMatrix = new int[20][20];
//first map every category name to an array index
SortedSet<String> cateNames = new TreeSet<String>();
Set<Map.Entry<String, String>> rightCateSet = rightCate.entrySet();
for(Iterator<Map.Entry<String, String>> it = rightCateSet.iterator();it.hasNext();){
Map.Entry<String, String> me = it.next();
cateNames.add(me.getValue());
}
cateNames.add("rec.sport.baseball"); //added explicitly so that no category is missing from the index
String[] cateNamesArray = cateNames.toArray(new String[0]);
Map<String, Integer> cateNamesToIndex = new TreeMap<String, Integer>();
for(int i = 0; i < cateNamesArray.length; i++){
cateNamesToIndex.put(cateNamesArray[i], i);
}
for(Iterator<Map.Entry<String, String>> it = rightCateSet.iterator();it.hasNext();){
Map.Entry<String, String> me = it.next();
confusionMatrix[cateNamesToIndex.get(me.getValue())][cateNamesToIndex.get(resultCate.get(me.getKey()))]++;
}
//print the confusion matrix
double[] hangSum = new double[20];
System.out.print(" ");
for(int i=0;i<20;i++){
System.out.printf("%-6d",i);
}
System.out.println("accuracy");
for(int i =0;i<20;i++){
System.out.printf("%-6d",i);
for(int j = 0;j<20;j++){
System.out.printf("%-6d",confusionMatrix[i][j]);
hangSum[i] += confusionMatrix[i][j];
}
System.out.printf("%-6f\n",confusionMatrix[i][i]/hangSum[i]);
}
System.out.println();
}
/**
* Read a <file name, category> map from a result file
* @param file the category file
* @return the map of <file name, category name>
* @throws Exception
*/
private Map<String, String> getMapFromResultFile(String file) throws Exception {
File crFile = new File(file);
FileReader crReader = new FileReader(crFile);
BufferedReader crBR = new BufferedReader(crReader);
Map<String, String> res = new TreeMap<String, String>();
String[] s;
String line;
while((line = crBR.readLine()) != null){
s = line.split(" ");
res.put(s[0], s[1]);
}
return res;
}
public static void main(String[] args) throws Exception {
CreateTrainAndTestSample ctt = new CreateTrainAndTestSample();
NaiveBayesianClassifier nbClassifier = new NaiveBayesianClassifier();
//generate, from the full document set, the document set containing only feature words, in the processedSampleOnlySpecial directory
ctt.filterSpecialWords();
double[] accuracyOfEveryExp = new double[10];
double accuracyAvg,sum = 0;
for(int i =0;i<10;i++){ //run ten classification experiments with different train/test splits and average the accuracy
String TrainDir = "E:/DataMiningSample/TrainSample"+i;
String TestDir = "E:/DataMiningSample/TestSample"+i;
String classifyRightCate = "E:/DataMiningSample/classifyRightCate"+i+".txt";
String classifyResultFileNew = "E:/DataMiningSample/classifyResultNew"+i+".txt";
ctt.createTestSample("E:/DataMiningSample/processedSampleOnlySpecial", 0.8, i, classifyRightCate);
nbClassifier.doProcess(TrainDir, TestDir, classifyResultFileNew);
accuracyOfEveryExp[i] = nbClassifier.computeAccuracy(classifyRightCate,classifyResultFileNew);
System.out.println("The accuracy for Naive Bayesian Classifier in "+i+"th Exp is :" + accuracyOfEveryExp[i]);
}
for(int i =0;i<10;i++)
sum += accuracyOfEveryExp[i];
accuracyAvg = sum/10;
System.out.println("The average accuracy for Naive Bayesian Classifier in all Exps is :" + accuracyAvg);
}
}
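A note on the arithmetic: the product of hundreds of per-word probabilities is far smaller than the smallest positive double, which is why computeCateProb accumulates the product with BigDecimal. An equivalent and cheaper alternative (a sketch only, not what the class above does; parameter names mirror those of computeCateProb) is to sum logarithms instead of multiplying probabilities:
// Log-space variant of the per-category score: log P(c) + sum over k of log P(tk|c).
// Comparing these log scores gives the same argmax as comparing the BigDecimal products,
// with no risk of floating-point underflow.
private static double computeCateLogProb(String cateName, java.util.Vector<String> testFileWords,
        java.util.Map<String, Double> cateWordsNum, double totalWordsNum,
        java.util.Map<String, Double> cateWordsProb) {
    double wordNumInCate = cateWordsNum.get(cateName);
    double logScore = Math.log(wordNumInCate / totalWordsNum); // log P(c)
    for (String w : testFileWords) {
        String key = cateName + "_" + w;
        double countInCate = cateWordsProb.containsKey(key) ? cateWordsProb.get(key) : 0.0;
        logScore += Math.log((countInCate + 0.0001) / (wordNumInCate + totalWordsNum));
    }
    return logScore;
}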
Only the result of the first run is listed here.
The multinomial model used here computes the class-conditional probability with an improved formula: P(tk|c) = (total number of occurrences of word tk in the documents of class c + a small constant, 0.0001 in the code above) / (total number of words in class c + number of distinct feature words in the training set). When tk does not occur in class c, only the small constant is added to the numerator, which describes the statistical distribution of the words more accurately than adding 1.
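Written out (a restatement of the formula above; lambda denotes the small additive constant, 1 for standard Laplace smoothing and 0.0001 in the code, |V| the number of distinct feature words, and N_c the total number of words in class c):
P(t_k \mid c) = \frac{N_{c,t_k} + \lambda}{N_c + |V|}, \qquad \hat{c} = \arg\max_c \; P(c) \prod_k P(t_k \mid c)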
To further improve the Naive Bayes classifier, the following refinements are possible:
1. Use a better feature-word selection method, such as Method 3 above, or an even better one (a sketch is given after this list).
2. Improve the formula for the class-conditional probability of the multinomial model (already done above).
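A possible sketch of improvement 1 (Method 3), assuming the total word counts come from countWords and the IDF values from computeIDF; the method name, parameter names and cut-off are illustrative only, and java.util imports are assumed:
/**
 * Illustrative tf*idf-based feature selection: keep the topN words with the largest
 * weight = (total occurrences of the word) * (IDF of the word).
 */
public static SortedMap<String, Double> selectByTfIdf(Map<String, Double> wordCountMap,
        final Map<String, Double> idfMap, int topN) {
    List<Map.Entry<String, Double>> entries =
            new ArrayList<Map.Entry<String, Double>>(wordCountMap.entrySet());
    Collections.sort(entries, new Comparator<Map.Entry<String, Double>>() {
        public int compare(Map.Entry<String, Double> a, Map.Entry<String, Double> b) {
            double wa = a.getValue() * idfMap.get(a.getKey());
            double wb = b.getValue() * idfMap.get(b.getKey());
            return Double.compare(wb, wa); // sort by weight, descending
        }
    });
    SortedMap<String, Double> selected = new TreeMap<String, Double>();
    for (int i = 0; i < Math.min(topN, entries.size()); i++)
        selected.put(entries.get(i).getKey(), entries.get(i).getValue());
    return selected;
}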