1、用MapReduce算法实现贝叶斯分类器的训练过程,并输出训练模型;
2、用输出的模型对测试集文档进行分类测试。测试过程可基于单机Java程序,也可以是MapReduce程序。输出每个测试文档的分类结果;
3、利用测试文档的真实类别,计算分类模型的Precision,Recall和F1值。
2.实验环境
实验平台:VMware Workstation10
虚拟机系统:Suse11
集群环境:主机名master ip:192.168.226.129
从机名slave1 ip:192.168.226.130
贝叶斯分类器的分类原理是通过某对象的先验概率,利用贝叶斯公式计算出其后验概率,即该对象属于某一类的概率,选择具有最大后验概率的类作为该对象所属的类。
应用贝叶斯分类器进行分类主要分成两阶段。第一阶段是贝叶斯统计分类器的学习阶段,即根据训练数据集训练得出训练模型;第二阶段是贝叶斯分类器的推理阶段,即根据训练模型计算属于各个分类的概率,进行分类。
贝叶斯公式如下:
其中A、B分别为两个不同的事件,P(A)是A的先验概率,P(A|B)是已知B发生后A的条件概率,也由于得自B的取值而被称作A的后验概率。而上式就是用事件B的先验概率来求它的后验概率。
3.1贝叶斯文本分类流程图
3.2贝叶斯文本分类详细步骤
整个文档归类过程可以分为以下步骤:
3.3具体算法设计
一个Country有多个news.txt, 一个news 有多个word
我们所设计的算法最后是要得到随机抽取一个txt文档,它最有可能属于哪个国家类别,也就是我们要得到它属于哪个国家的概率最大,把它转化为数学公式也就是:
3-1
为了便于比较,我们将上式取对数得到:
3-2
其中Num(Wi)表示该txt文档中单词Wi的个数;P(C|Wi) 表示拿出一个单词Wi,其属于国家C的后验概率。根据贝叶斯公式有:P(C|W) = P(W|C)*P(C)/P(W),其中:
P(W|C):国家C的news中单词W出现的概率,根据式3-2,不能使该概率为0,所以我们约定每个国家都至少包含每个单词一次,也就是在统计单词数量时,都自动的加1,就有:
3-3
P(C):国家C出现的概率(正比于其所含txt文件数);
P(W):单词W在整个测试集中出现的概率。
根据上面的贝叶斯公式我们设计的MapReduce算法如下:
合并<
由得到国家C中含有的单词总数,记为N(C);
由得到测试集中单词W的总数,记为N(W);
再由得到测试集的单词总数,记为N。
则可求得P(W|C) = N(C,W)/N(C);P(C) = N(C)/N;P(W) = N(W)/N。
3.4MapReduce的Data Flow示意图
本实验中的主要代码如下所示
4.1 SmallFilesToSequenceFileConverter.java 小文件集合打包工具类MapReduce程序
4.2 WholeFileInputFormat.java 支持类:递归读取指定目录下的所有文件
4.3 WholeFileRecordReader.java 支持类:读取单个文件的全部内容
4.4 DocCount.java 文档统计MapReduce程序
4.5 WordCount.java 单词统计MapReduce程序
4.6 DocClassification.java 测试文档分类MapReduce程序
详细代码如下:
4.1 SmallFilesToSequenceFileConverter.java 其中Map,Reduce关键代码如下:
publicclass SmallFilesToSequenceFileConverter extends Configured implements Tool {
staticclass SequenceFileMapper extends Mapper
private String fileNameKey; // 被打包的小文件名作为key,表示为Text对象
private String classNameKey; // 当前文档所在的分类名
@Override// 重新实现setup方法,进行map任务的初始化设置
protectedvoid setup(Context context) throws IOException, InterruptedException {
InputSplit split = context.getInputSplit(); // 从context获取split
Path path = ((FileSplit) split).getPath(); // 从split获取文件路径
fileNameKey = path.getName(); // 将文件路径实例化为key对象
classNameKey = path.getParent().getName();
}
@Override// 实现map方法
protectedvoid map(NullWritable key, BytesWritable value, Context context)
throws IOException, InterruptedException {
// 注意sequencefile的key和value (key:分类,文档名 value:文档的内容)
context.write(new Text(classNameKey + "/" + fileNameKey), value);
}
}
}
4.2 WholeFileInputFormat.java 其中关键代码如下:
publicclass WholeFileInputFormat extends FileInputFormat
/**
* 方法描述:递归遍历输入目录下的所有文件
* 备注:该写FileInputFormat,使支持多层目录的输入
* @authormeify DateTime 2015年11月3日下午2:37:49
* @param fs
* @param path
*/
void search(FileSystem fs, Path path) {
try {
if (fs.isFile(path)) {
fileStatus.add(fs.getFileStatus(path));
} elseif (fs.isDirectory(path)) {
FileStatus[] fileStatus = fs.listStatus(path);
for (inti = 0; i < fileStatus.length; i++) {
FileStatus fileStatu = fileStatus[i];
search(fs, fileStatu.getPath());
}
}
} catch (IOException e) {
e.printStackTrace();
}
}
@Override
public RecordReader
throws IOException, InterruptedException {
WholeFileRecordReader reader = new WholeFileRecordReader();
reader.initialize(split, context);
returnreader;
}
@Override
protected List
FileSystem fs = FileSystem.get(job.getConfiguration());
// 输入根目录
String rootDir = job.getConfiguration().get("mapred.input.dir", "");
// 递归获取输入目录下的所有文件
search(fs, new Path(rootDir));
returnthis.fileStatus;
}
}
4.3 WholeFileRecordReader.java 其中关键代码如下:
publicclass WholeFileRecordReader extends RecordReader
private FileSplit fileSplit; //保存输入的分片,它将被转换成一条( key, value)记录
private Configuration conf; //配置对象
private BytesWritable value = new BytesWritable(); //value对象,内容为空
privatebooleanprocessed = false; //布尔变量记录记录是否被处理过
@Override
publicboolean nextKeyValue() throws IOException, InterruptedException {
if (!processed) { //如果记录没有被处理过
//从fileSplit对象获取split的字节数,创建byte数组contents
byte[] contents = newbyte[(int) fileSplit.getLength()];
Path file = fileSplit.getPath(); //从fileSplit对象获取输入文件路径
FileSystem fs = file.getFileSystem(conf); //获取文件系统对象
FSDataInputStream in = null; //定义文件输入流对象
try {
in = fs.open(file); //打开文件,返回文件输入流对象
//从输入流读取所有字节到contents
IOUtils.readFully(in, contents, 0, contents.length); value.set(contents, 0, contents.length); //将contens内容设置到value对象中
} finally {
IOUtils.closeStream(in); //关闭输入流
}
processed = true; //将是否处理标志设为true,下次调用该方法会返回false
returntrue;
}
returnfalse; //如果记录处理过,返回false,表示split处理完毕
}
}
4.4 DocCount.java 其中Map,Reduce关键代码如下:
publicclass DocCount extends Configured implements Tool{
publicstaticclass Map extends Mapper
@Override
publicvoid map(Text key, BytesWritable value, Context context) {
try {
String currentKey = key.toString();
String[] arr = currentKey.split("/");
String className = arr[0];
String fileName = arr[1];
System.out.println(className + "," + fileName);
context.write(new Text(className), new IntWritable(1));
} catch (IOException e) {
e.printStackTrace();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
publicstaticclass Reduce extends Reducer
private IntWritable result = new IntWritable();
publicvoid reduce(Text key, Iterable
intsum = 0;
for (IntWritable val : values) {
sum ++;
}
result.set(sum);
context.write(key, result); // 输出结果key: 分类 , value: 文档个数
}
}
}
4.5 WordCount.java 其中Map,Reduce关键代码如下:
publicclass WordCount extends Configured implements Tool{
publicstaticclass Map extends Mapper
@Override
publicvoid map(Text key, BytesWritable value, Context context) {
try {
String[] arr = key.toString().split("/");
String className = arr[0];
String fileName = arr[1];
value.setCapacity(value.getSize()); // 剔除多余空间
// 文本内容
String content = new String(value.getBytes(), 0, value.getLength());
StringTokenizer itr = new StringTokenizer(content);
while (itr.hasMoreTokens()) {
String word = itr.nextToken();
if(StringUtil.isValidWord(word))
{
System.out.println(className + "/" + word);
context.write(new Text(className + "/" + word), new IntWritable(1));
}
}
} catch (IOException e) {
e.printStackTrace();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
publicstaticclass Reduce extends Reducer
private IntWritable result = new IntWritable();
publicvoid reduce(Text key, Iterable
intsum = 1; // 注意这里单词的个数从1开始计数
for (IntWritable val : values) {
sum ++;
}
result.set(sum);
context.write(key, result); // 输出结果key: 分类/ 单词 , value: 频次
}
}
}
4.6 DocClassification.java 其中Map,Reduce关键代码如下:
publicclass DocClassification extends Configured implements Tool {
// 所有分类集合
privatestatic List
// 所有分类的先验概率(其中的概率取对数log)
privatestatic HashMap
// 所有单词在各个分类中的出现的频次
privatestatic HashMap
// 分类下的所有单词出现的总频次
privatestatic HashMap
privatestatic Configuration conf = new Configuration();
static {
// 初始化分类先验概率词典
initClassProMap("hdfs://192.168.226.129:9000/user/hadoop/doc");
// 初始化单词在各个分类中的条件概率词典
initClassWordProMap("hdfs://192.168.226.129:9000/user/hadoop/word");
}
publicstaticclass Map extends Mapper
@Override
publicvoid map(Text key, BytesWritable value, Context context) {
String fileName = key.toString();
value.setCapacity(value.getSize()); // 剔除多余空间
String content = new String(value.getBytes(), 0, value.getLength());
try {
for (String className : classList) {
doubleresult = Math.log(classProMap.get(className));
StringTokenizer itr = new StringTokenizer(content);
while (itr.hasMoreTokens()) {
String word = itr.nextToken();
if (StringUtil.isValidWord(word)) {
intwordSum = 1;
if(classWordNumMap.get(className + "/" + word) != null){
wordSum = classWordNumMap.get(className + "/" + word);
}
intclassWordSum = classWordSumMap.get(className);
doublepro_class_word = Math.log(((double)wordSum)/classWordSum);
result += pro_class_word;
}
}
// 输出的形式 key:文件名 value:分类名/概率
context.write(new Text(fileName), new Text(className + "/" + String.valueOf(result)));
}
} catch (IOException e) {
e.printStackTrace();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
publicstaticclass Reduce extends Reducer
publicvoid reduce(Text key, Iterable
String fileName = key.toString().split("/")[1];
doublemaxPro = Math.log(Double.MIN_VALUE);
String maxClassName = "unknown";
for (Text value : values) {
String[] arr = value.toString().split("/");
String className = arr[0];
doublepro = Double.valueOf(arr[1]);
if (pro > maxPro) {
maxPro = pro;
maxClassName = className;
}
}
System.out.println("fileName:" + fileName + ",belong class:" + maxClassName);
// 输出 key:文件名 value:所属分类名以及概率
context.write(new Text(fileName), new Text(maxClassName + ",pro=" + maxPro));
}
}
}
四、数据集说明
训练集:CHINA 文档数255
INDIA 文档数326
TAIWAN 文档数43.
测试集:CHINA 文档个数15
INDIA 文档个数20
TAIWAN 文档个数15
5.1训练数据集打包程序
Map任务个数624(所有小文件的个数) Reduce任务个数1
截图如下
5.2训练文档统计程序
Map任务个数1(输入为1个SequencedFile) Reduce任务个数1
5.3训练单词统计程序
Map任务个数1(输入为1个SequencedFile) Reduce任务个数1
5.4测试数据集打包程序
Map任务个数50(测试数据集小文件个数为50) Reduce任务个数1
5.5测试文档归类程序
Map任务个数1(输入为1个SequencedFile) Reduce任务个数1
测试集文档归类结果截图如下:
针对CHINA 、TAIWAN、 INDIA三个分类下的测试文档进行测试结果如下表所示:
类别(国家) |
正确率 |
召回率 |
F1值 |
CHINA |
18.4% |
46.667% |
26.38% |
INDIA |
42.1% |
80% |
55.67% |
TAIWAN |
39.47% |
100% |
56.60% |