之前写的代码都是在单机上跑的，发现现在 Hadoop 很流行，所以又试着用 Hadoop MapReduce 来实现决策树的构建。因为接触 Hadoop 不多，所以写得不好，勿怪。
看了一下 Mahout 处理决策树和随机森林的过程，大体是：Job 只有一个 Mapper，在 map 方法里面做数据的转换和收集工作，然后在 cleanup 方法里面构建决策树，再将决策树序列化到 HDFS 上；分类样本数据集的时候，再从 HDFS 上取回决策树结构。大体来说，Mahout 决策树的构建过程好像并没有结合分布式计算，不过我并没有仔仔细细地去研读 Mahout 的源码，所以也可能是我没发现。下面是我实现的一个简单的 Hadoop 版本决策树，用的是 C4.5 算法，通过 MapReduce 去计算增益率。最后生成的决策树并未保存在 HDFS 上，后面有时间再考虑吧。下面是具体代码实现：
public class DecisionTreeC45Job extends AbstractJob { /** 对数据集做准备工作,主要就是将填充好默认值的数据集再次传到HDFS上*/ public String prepare(Data trainData) { String path = FileUtils.obtainRandomTxtPath(); DataHandler.writeData(path, trainData); System.out.println(path); String name = path.substring(path.lastIndexOf(File.separator) + 1); String hdfsPath = HDFSUtils.HDFS_TEMP_INPUT_URL + name; HDFSUtils.copyFromLocalFile(conf, path, hdfsPath); return hdfsPath; } /** 选择最佳属性,读取MapReduce计算后产生的文件,取增益率最大*/ public AttributeGainWritable chooseBestAttribute(String output) { AttributeGainWritable maxAttribute = null; Path path = new Path(output); try { FileSystem fs = path.getFileSystem(conf); Path[] paths = HDFSUtils.getPathFiles(fs, path); ShowUtils.print(paths); double maxGainRatio = 0.0; SequenceFile.Reader reader = null; for (Path p : paths) { reader = new SequenceFile.Reader(fs, p, conf); Text key = (Text) ReflectionUtils.newInstance( reader.getKeyClass(), conf); AttributeGainWritable value = new AttributeGainWritable(); while (reader.next(key, value)) { double gainRatio = value.getGainRatio(); if (gainRatio >= maxGainRatio) { maxGainRatio = gainRatio; maxAttribute = value; } value = new AttributeGainWritable(); } IOUtils.closeQuietly(reader); } System.out.println("output: " + path.toString()); HDFSUtils.delete(conf, path); System.out.println("hdfs delete file : " + path.toString()); } catch (IOException e) { e.printStackTrace(); } return maxAttribute; } /** 构造决策树 */ public Object build(String input, Data data) { Object preHandleResult = preHandle(data); if (null != preHandleResult) return preHandleResult; String output = HDFSUtils.HDFS_TEMP_OUTPUT_URL; HDFSUtils.delete(conf, new Path(output)); System.out.println("delete output path : " + output); String[] paths = new String[]{input, output}; //通过MapReduce计算增益率 CalculateC45GainRatioMR.main(paths); AttributeGainWritable bestAttr = chooseBestAttribute(output); String attribute = bestAttr.getAttribute(); System.out.println("best attribute: " + 
attribute); System.out.println("isCategory: " + bestAttr.isCategory()); if (bestAttr.isCategory()) { return attribute; } String[] splitPoints = bestAttr.obtainSplitPoints(); System.out.print("splitPoints: "); ShowUtils.print(splitPoints); TreeNode treeNode = new TreeNode(attribute); String[] attributes = data.getAttributesExcept(attribute); //分割数据集,并将分割后的数据集传到HDFS上 DataSplit dataSplit = DataHandler.split(new Data( data.getInstances(), attribute, splitPoints)); for (DataSplitItem item : dataSplit.getItems()) { String path = item.getPath(); String name = path.substring(path.lastIndexOf(File.separator) + 1); String hdfsPath = HDFSUtils.HDFS_TEMP_INPUT_URL + name; HDFSUtils.copyFromLocalFile(conf, path, hdfsPath); treeNode.setChild(item.getSplitPoint(), build(hdfsPath, new Data(attributes, item.getInstances()))); } return treeNode; } /** 分类,根据决策树节点判断测试样本集的类型,并将结果上传到HDFS上*/ private void classify(TreeNode treeNode, String trainSet, String testSet, String output) { OutputStream out = null; BufferedWriter writer = null; try { Path trainSetPath = new Path(trainSet); FileSystem trainFS = trainSetPath.getFileSystem(conf); Path[] trainHdfsPaths = HDFSUtils.getPathFiles(trainFS, trainSetPath); FSDataInputStream trainFSInputStream = trainFS.open(trainHdfsPaths[0]); Data trainData = DataLoader.load(trainFSInputStream, true); Path testSetPath = new Path(testSet); FileSystem testFS = testSetPath.getFileSystem(conf); Path[] testHdfsPaths = HDFSUtils.getPathFiles(testFS, testSetPath); FSDataInputStream fsInputStream = testFS.open(testHdfsPaths[0]); Data testData = DataLoader.load(fsInputStream, true); DataHandler.fill(testData.getInstances(), trainData.getAttributes(), 0); Object[] results = (Object[]) treeNode.classify(testData); ShowUtils.print(results); DataError dataError = new DataError(testData.getCategories(), results); dataError.report(); String path = FileUtils.obtainRandomTxtPath(); out = new FileOutputStream(new File(path)); writer = new BufferedWriter(new 
OutputStreamWriter(out)); StringBuilder sb = null; for (int i = 0, len = results.length; i < len; i++) { sb = new StringBuilder(); sb.append(i+1).append("\t").append(results[i]); writer.write(sb.toString()); writer.newLine(); } writer.flush(); Path outputPath = new Path(output); FileSystem fs = outputPath.getFileSystem(conf); if (!fs.exists(outputPath)) { fs.mkdirs(outputPath); } String name = path.substring(path.lastIndexOf(File.separator) + 1); HDFSUtils.copyFromLocalFile(conf, path, output + File.separator + name); } catch (IOException e) { e.printStackTrace(); } finally { IOUtils.closeQuietly(out); IOUtils.closeQuietly(writer); } } public void run(String[] args) { try { if (null == conf) conf = new Configuration(); String[] inputArgs = new GenericOptionsParser( conf, args).getRemainingArgs(); if (inputArgs.length != 3) { System.out.println("error, please input three path."); System.out.println("1. trainset path."); System.out.println("2. testset path."); System.out.println("3. result output path."); System.exit(2); } Path input = new Path(inputArgs[0]); FileSystem fs = input.getFileSystem(conf); Path[] hdfsPaths = HDFSUtils.getPathFiles(fs, input); FSDataInputStream fsInputStream = fs.open(hdfsPaths[0]); Data trainData = DataLoader.load(fsInputStream, true); /** 填充缺失属性的默认值*/ DataHandler.fill(trainData, 0); String hdfsInput = prepare(trainData); TreeNode treeNode = (TreeNode) build(hdfsInput, trainData); TreeNodeHelper.print(treeNode, 0, null); classify(treeNode, inputArgs[0], inputArgs[1], inputArgs[2]); } catch (Exception e) { e.printStackTrace(); } } public static void main(String[] args) { DecisionTreeC45Job job = new DecisionTreeC45Job(); long startTime = System.currentTimeMillis(); job.run(args); long endTime = System.currentTimeMillis(); System.out.println("spend time: " + (endTime - startTime)); } }
CalculateC45GainRatioMR 的具体实现：
public class CalculateC45GainRatioMR { private static void configureJob(Job job) { job.setJarByClass(CalculateC45GainRatioMR.class); job.setMapperClass(CalculateC45GainRatioMapper.class); job.setMapOutputKeyClass(Text.class); job.setMapOutputValueClass(AttributeWritable.class); job.setReducerClass(CalculateC45GainRatioReducer.class); job.setOutputKeyClass(Text.class); job.setOutputValueClass(AttributeGainWritable.class); job.setInputFormatClass(TextInputFormat.class); job.setOutputFormatClass(SequenceFileOutputFormat.class); } public static void main(String[] args) { Configuration configuration = new Configuration(); try { String[] inputArgs = new GenericOptionsParser( configuration, args).getRemainingArgs(); if (inputArgs.length != 2) { System.out.println("error, please input two path. input and output"); System.exit(2); } Job job = new Job(configuration, "Decision Tree"); FileInputFormat.setInputPaths(job, new Path(inputArgs[0])); FileOutputFormat.setOutputPath(job, new Path(inputArgs[1])); configureJob(job); System.out.println(job.waitForCompletion(true) ? 
0 : 1); } catch (Exception e) { e.printStackTrace(); } } } class CalculateC45GainRatioMapper extends Mapper<LongWritable, Text, Text, AttributeWritable> { @Override protected void setup(Context context) throws IOException, InterruptedException { super.setup(context); } @Override protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { String line = value.toString(); StringTokenizer tokenizer = new StringTokenizer(line); Long id = Long.parseLong(tokenizer.nextToken()); String category = tokenizer.nextToken(); boolean isCategory = true; while (tokenizer.hasMoreTokens()) { isCategory = false; String attribute = tokenizer.nextToken(); String[] entry = attribute.split(":"); context.write(new Text(entry[0]), new AttributeWritable(id, category, entry[1])); } if (isCategory) { context.write(new Text(category), new AttributeWritable(id, category, category)); } } @Override protected void cleanup(Context context) throws IOException, InterruptedException { super.cleanup(context); } } class CalculateC45GainRatioReducer extends Reducer<Text, AttributeWritable, Text, AttributeGainWritable> { @Override protected void setup(Context context) throws IOException, InterruptedException { super.setup(context); } @Override protected void reduce(Text key, Iterable<AttributeWritable> values, Context context) throws IOException, InterruptedException { String attributeName = key.toString(); double totalNum = 0.0; Map<String, Map<String, Integer>> attrValueSplits = new HashMap<String, Map<String, Integer>>(); Iterator<AttributeWritable> iterator = values.iterator(); boolean isCategory = false; while (iterator.hasNext()) { AttributeWritable attribute = iterator.next(); String attributeValue = attribute.getAttributeValue(); if (attributeName.equals(attributeValue)) { isCategory = true; break; } Map<String, Integer> attrValueSplit = attrValueSplits.get(attributeValue); if (null == attrValueSplit) { attrValueSplit = new HashMap<String, 
Integer>(); attrValueSplits.put(attributeValue, attrValueSplit); } String category = attribute.getCategory(); Integer categoryNum = attrValueSplit.get(category); attrValueSplit.put(category, null == categoryNum ? 1 : categoryNum + 1); totalNum++; } if (isCategory) { System.out.println("is Category"); int sum = 0; iterator = values.iterator(); while (iterator.hasNext()) { iterator.next(); sum += 1; } System.out.println("sum: " + sum); context.write(key, new AttributeGainWritable(attributeName, sum, true, null)); } else { double gainInfo = 0.0; double splitInfo = 0.0; for (Map<String, Integer> attrValueSplit : attrValueSplits.values()) { double totalCategoryNum = 0; for (Integer categoryNum : attrValueSplit.values()) { totalCategoryNum += categoryNum; } double entropy = 0.0; for (Integer categoryNum : attrValueSplit.values()) { double p = categoryNum / totalCategoryNum; entropy -= p * (Math.log(p) / Math.log(2)); } double dj = totalCategoryNum / totalNum; gainInfo += dj * entropy; splitInfo -= dj * (Math.log(dj) / Math.log(2)); } double gainRatio = splitInfo == 0.0 ? 0.0 : gainInfo / splitInfo; StringBuilder splitPoints = new StringBuilder(); for (String attrValue : attrValueSplits.keySet()) { splitPoints.append(attrValue).append(","); } splitPoints.deleteCharAt(splitPoints.length() - 1); System.out.println("attribute: " + attributeName); System.out.println("gainRatio: " + gainRatio); System.out.println("splitPoints: " + splitPoints.toString()); context.write(key, new AttributeGainWritable(attributeName, gainRatio, false, splitPoints.toString())); } } @Override protected void cleanup(Context context) throws IOException, InterruptedException { super.cleanup(context); } }