mahout关联规则源码分析 Part 1

最近看了关联规则的相关算法,着重看了mahout的具体实现,mahout官网上面给出了好多算法,具体网址如下:https://cwiki.apache.org/confluence/display/MAHOUT/Parallel+Frequent+Pattern+Mining 。

先说下命令行运行关联规则,关联规则的算法在mahout-core-0,7.jar包下面,命令行运行如下:

 

[java]  view plain copy
 
  1. fansy@fansypc:~/hadoop-1.0.2$ bin/hadoop jar ../mahout-pure-0.7/core/target/mahout-core-0.7.jar  
  2.  org.apache.mahout.fpm.pfpgrowth.FPGrowthDriver -i input/retail.dat -o date1101/fpgrowthdriver00 -s 2 -method mapreduce -regex '[\ ]'  
  3. 12/11/01 16:31:39 INFO common.AbstractJob:  
  4.  Command line arguments: {--encoding=[UTF-8], --endPhase=[2147483647],   
  5. --input=[input/retail.dat], --maxHeapSize=[50], --method=[mapreduce], --minSupport=[2], --numGroups=[1000],   
  6. --numTreeCacheEntries=[5], --output=[date1101/fpgrowthdriver00], --splitterPattern=[[\ ]], --startPhase=[0], --tempDir=[temp]}  

最后的 -regex '[\ ]' 一定是需要的对于输入数据 retail.dat来说,因为mahout默认的item的分隔符是没有空格的;

 

而且这里只讨论 并行的程序,所以使用 -method mapreduce

下面分析源码:

在分析源码之前,先看一张图:

mahout关联规则源码分析 Part 1_第1张图片

这张图很好的说明了mahout实现关联规则思想,或者说是流程;

首先,读入数据,比如上图的5个transactions(事务),接着根据一张总表(这张总表是每个item的次数从大到小的一个排列,同时这张表还去除了出现次数小于min_support的item)把这些transactions 去除一些项目并按照总表的顺序排序,得到另外的一个transaction A,接着map的输出就是根据transaction A输出规则,从出现次数最小的item开始输出直到出现次数第二大的item。

Reduce收集map输出相同的key值,把他们的value值放一个集合set 中,然后在统计这些集合中item出现的次数,如果次数大于min_confidence(本例中为3),那么就输出key和此item的规则;

命令行运行时可以看到三个MR,即可以把关联规则的算法分为三部分,但是个人觉得可以分为四个部分,其中的一部分就是总表的获得;鉴于目前本人只看了一个MR和总表的获得部分的源码,今天就只分享这两个部分;

贴代码先,基本都是源码来的,只是稍微改了下:

第一个MR的驱动程序:PFGrowth_ParallelCounting.java:

 

[java]  view plain copy
 
  1. package org.fansy.date1101.pfgrowth;  
  2. import java.io.IOException;  
  3. import org.apache.hadoop.conf.Configuration;  
  4. import org.apache.hadoop.fs.Path;  
  5. import org.apache.hadoop.io.LongWritable;  
  6. import org.apache.hadoop.io.Text;  
  7. import org.apache.hadoop.mapreduce.Job;  
  8. import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;  
  9. import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;  
  10. import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;  
  11. import org.apache.mahout.common.HadoopUtil;  
  12. public class PFGrowth_ParallelCounting {  
  13.     public boolean runParallelCountingJob(String input,String output) throws IOException, ClassNotFoundException, InterruptedException{  
  14.         Configuration conf=new Configuration();  
  15.         Job job = new Job(conf, "Parallel Counting Driver running over input: " + input);  
  16.         job.setJarByClass(PFGrowth_ParallelCounting.class);  
  17.         job.setMapperClass(PFGrowth_ParallelCountingM.class);  
  18.         job.setCombinerClass(PFGrowth_ParallelCountingR.class);  
  19.         job.setReducerClass(PFGrowth_ParallelCountingR.class);  
  20.         job.setOutputFormatClass(SequenceFileOutputFormat.class); //  get rid of this line you can get the text file  
  21.         job.setOutputKeyClass(Text.class);  
  22.         job.setOutputValueClass(LongWritable.class);      
  23.         FileInputFormat.setInputPaths(job,new Path( input));  
  24.         Path outPut=new Path(output,"parallelcounting");  
  25.         HadoopUtil.delete(conf, outPut);  
  26.         FileOutputFormat.setOutputPath(job, outPut);          
  27.         boolean succeeded = job.waitForCompletion(true);  
  28.         if (!succeeded) {  
  29.           throw new IllegalStateException("Job failed!");  
  30.         }     
  31.         return succeeded;  
  32.     }  
  33. }  

第一个MR的M:PFGrowth_ParallelCountingM.java:

 

 

[java]  view plain copy
 
  1. package org.fansy.date1101.pfgrowth;  
  2. import java.io.IOException;  
  3. import java.util.regex.Pattern;  
  4. import org.apache.hadoop.io.LongWritable;  
  5. import org.apache.hadoop.io.Text;  
  6. import org.apache.hadoop.mapreduce.Mapper;  
  7. public class PFGrowth_ParallelCountingM extends Mapper<LongWritable,Text,Text,LongWritable> {  
  8.      private static final LongWritable ONE = new LongWritable(1);  
  9.       private Pattern splitter=Pattern.compile("[ ,\t]*[ ,|\t][ ,\t]*");  
  10.       @Override  
  11.       protected void map(LongWritable offset, Text input, Context context) throws IOException,  
  12.                                                                           InterruptedException {  
  13.         String[] items = splitter.split(input.toString());  
  14.         for (String item : items) {  
  15.           if (item.trim().isEmpty()) {  
  16.             continue;  
  17.           }  
  18.           context.setStatus("Parallel Counting Mapper: " + item);  
  19.           context.write(new Text(item), ONE);  
  20.         }  
  21.       }    
  22. }  

上面的代码中的间隔符号修改了源码,加上了空格;

 

第一个MR的R:PFGrowth_ParallelCountingR.java:

 

[java]  view plain copy
 
  1. package org.fansy.date1101.pfgrowth;  
  2. import java.io.IOException;  
  3. import org.apache.hadoop.io.LongWritable;  
  4. import org.apache.hadoop.io.Text;  
  5. import org.apache.hadoop.mapreduce.Reducer;  
  6. public class PFGrowth_ParallelCountingR extends Reducer<Text,LongWritable,Text,LongWritable>{  
  7.     protected void reduce(Text key, Iterable<LongWritable> values, Context context) throws IOException,  
  8.             InterruptedException {  
  9.         long sum = 0;  
  10.         for (LongWritable value : values) {  
  11.         context.setStatus("Parallel Counting Reducer :" + key);  
  12.         sum += value.get();  
  13.         }  
  14.         context.setStatus("Parallel Counting Reducer: " + key + " => " + sum);  
  15.         context.write(key, new LongWritable(sum));  
  16.     }  
  17. }  

其实第一个MR还是比较好理解的,M分解每个transaction的item,然后输出<item_id ,1>,然后R针对每个item_id 把value值相加求和,这个和wordcount的例子是一样的,当然这里也可以加combine操作的。

 

接着是总表的获得:

PFGrowth_Driver.java ,同时这个程序也调用第一个MR,也就是说可以直接运行这个文件就可以同时运行第一个MR和获得总表了。

 

[java]  view plain copy
 
  1. package org.fansy.date1101.pfgrowth;  
  2. import java.io.IOException;  
  3. import java.util.Comparator;  
  4. import java.util.List;  
  5. import java.util.PriorityQueue;  
  6. import org.apache.hadoop.conf.Configuration;  
  7. import org.apache.hadoop.filecache.DistributedCache;  
  8. import org.apache.hadoop.fs.FileSystem;  
  9. import org.apache.hadoop.fs.Path;  
  10. import org.apache.hadoop.io.LongWritable;  
  11. import org.apache.hadoop.io.SequenceFile;  
  12. import org.apache.hadoop.io.Text;  
  13. import org.apache.mahout.common.HadoopUtil;  
  14. import org.apache.mahout.common.Pair;  
  15. import org.apache.mahout.common.Parameters;  
  16. import org.apache.mahout.common.iterator.sequencefile.PathType;  
  17. import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;  
  18. import com.google.common.collect.Lists;  
  19. class MyComparator implements Comparator<Pair<String,Long>>{  
  20.      @Override  
  21.      public int compare(Pair<String,Long> o1, Pair<String,Long> o2) {  
  22.        int ret = o2.getSecond().compareTo(o1.getSecond());  
  23.        if (ret != 0) {  
  24.          return ret;  
  25.        }  
  26.        return o1.getFirst().compareTo(o2.getFirst());  
  27.      }    
  28. }  
  29. public class PFGrowth_Driver {  
  30.     public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException{  
  31.         if(args.length!=3){  
  32.             System.out.println("wrong input args");  
  33.             System.out.println("usage: <intput><output><minsupport>");  
  34.             System.exit(-1);  
  35.         }  
  36.         // set parameters  
  37.         Parameters params=new Parameters();  
  38.         params.set("INPUT", args[0]);  
  39.         params.set("OUTPUT", args[1]);  
  40.         params.set("MIN_SUPPORT", args[2]);  
  41.         // get parameters  
  42.         String input=params.get("INPUT");  
  43.         String output=params.get("OUTPUT");  
  44.         //  run the first job  
  45.         PFGrowth_ParallelCounting ppc=new PFGrowth_ParallelCounting();  
  46.         ppc.runParallelCountingJob(input, output);    
  47.         //  read input and set the fList  
  48.          List<Pair<String,Long>> fList = readFList(params);  
  49.          Configuration conf=new Configuration();  
  50.          saveFList(fList, params, conf);           
  51.     }     
  52.     /** 
  53.        * Serializes the fList and returns the string representation of the List 
  54.        *  
  55.        * @return Serialized String representation of List 
  56.        */  
  57.       public static void saveFList(Iterable<Pair<String,Long>> flist, Parameters params, Configuration conf)  
  58.         throws IOException {  
  59.         Path flistPath = new Path(params.get("OUTPUT"), "fList");  
  60.         FileSystem fs = FileSystem.get(flistPath.toUri(), conf);  
  61.         flistPath = fs.makeQualified(flistPath);  
  62.         HadoopUtil.delete(conf, flistPath);  
  63.         SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, flistPath, Text.class, LongWritable.class);  
  64.         try {  
  65.           for (Pair<String,Long> pair : flist) {  
  66.             writer.append(new Text(pair.getFirst()), new LongWritable(pair.getSecond()));  
  67.           }  
  68.         } finally {  
  69.           writer.close();  
  70.         }  
  71.         DistributedCache.addCacheFile(flistPath.toUri(), conf);  
  72.       }  
  73.     public static List<Pair<String,Long>> readFList(Parameters params) {  
  74.         int minSupport = Integer.valueOf(params.get("MIN_SUPPORT"));  
  75.         Configuration conf = new Configuration();      
  76.         Path parallelCountingPath = new Path(params.get("OUTPUT"),"parallelcounting");  
  77.         //  add MyComparator  
  78.         PriorityQueue<Pair<String,Long>> queue = new PriorityQueue<Pair<String,Long>>(11,new MyComparator());  
  79.         // sort according to the occur times from large to small   
  80.   
  81.         for (Pair<Text,LongWritable> record  
  82.              : new SequenceFileDirIterable<Text,LongWritable>(new Path(parallelCountingPath, "part-*"),  
  83.                                                             PathType.GLOB, nullnulltrue, conf)) {  
  84.           long value = record.getSecond().get();  
  85.           if (value >= minSupport) {   // get rid of the item which is below the minimum support  
  86.             queue.add(new Pair<String,Long>(record.getFirst().toString(), value));  
  87.           }  
  88.         }  
  89.         List<Pair<String,Long>> fList = Lists.newArrayList();  
  90.         while (!queue.isEmpty()) {  
  91.           fList.add(queue.poll());  
  92.         }  
  93.         return fList;  
  94.       }   
  95. }  

第一个MR运行完毕后,调用readFList()函数,把第一个MR的输出按照item出现的次数从大到小放入一个列表List中,然后调用saveFList()函数把上面求得的List存入HDFS文件中,不过存入的格式是被序列话的,可以另外编写函数查看文件是否和自己的假设相同;

 

FList 文件反序列化如下:

mahout关联规则源码分析 Part 1_第2张图片

http://blog.csdn.net/fansy1990/article/details/8137942

你可能感兴趣的:(机器学习)