mahout关联规则源码分析 Part 1

最近看了关联规则的相关算法,着重看了mahout的具体实现,mahout官网上面给出了好多算法,具体网址如下:https://cwiki.apache.org/confluence/display/MAHOUT/Parallel+Frequent+Pattern+Mining

先说下命令行运行关联规则,关联规则的算法在mahout-core-0,7.jar包下面,命令行运行如下:

fansy@fansypc:~/hadoop-1.0.2$ bin/hadoop jar ../mahout-pure-0.7/core/target/mahout-core-0.7.jar
org.apache.mahout.fpm.pfpgrowth.FPGrowthDriver -i input/retail.dat -o date1101/fpgrowthdriver00 -s 2 -method mapreduce -regex '[\ ]'
12/11/01 16:31:39 INFO common.AbstractJob:
Command line arguments: {--encoding=[UTF-8], --endPhase=[2147483647], 
--input=[input/retail.dat], --maxHeapSize=[50], --method=[mapreduce], --minSupport=[2], --numGroups=[1000], 
--numTreeCacheEntries=[5], --output=[date1101/fpgrowthdriver00], --splitterPattern=[[\ ]], --startPhase=[0], --tempDir=[temp]}
最后的 -regex '[\ ]' 一定是需要的对于输入数据 retail.dat来说,因为mahout默认的item的分隔符是没有空格的;

而且这里只讨论 并行的程序,所以使用 -method mapreduce

下面分析源码:

在分析源码之前,先看一张图:

mahout关联规则源码分析 Part 1_第1张图片

这张图很好的说明了mahout实现关联规则思想,或者说是流程;

首先,读入数据,比如上图的5个transactions(事务),接着根据一张总表(这张总表是每个item的次数从大到小的一个排列,同时这张表还去除了出现次数小于min_support的item)把这些transactions 去除一些项目并按照总表的顺序排序,得到另外的一个transaction A,接着map的输出就是根据transaction A输出规则,从出现次数最小的item开始输出直到出现次数第二大的item。

Reduce收集map输出相同的key值,把他们的value值放一个集合set 中,然后在统计这些集合中item出现的次数,如果次数大于min_confidence(本例中为3),那么就输出key和此item的规则;

命令行运行时可以看到三个MR,即可以把关联规则的算法分为三部分,但是个人觉得可以分为四个部分,其中的一部分就是总表的获得;鉴于目前本人只看了一个MR和总表的获得部分的源码,今天就只分享这两个部分;

贴代码先,基本都是源码来的,只是稍微改了下:

第一个MR的驱动程序:PFGrowth_ParallelCounting.java:

package org.fansy.date1101.pfgrowth;
import java.io.IOException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.common.HadoopUtil;
public class PFGrowth_ParallelCounting {
	public boolean runParallelCountingJob(String input,String output) throws IOException, ClassNotFoundException, InterruptedException{
		Configuration conf=new Configuration();
		Job job = new Job(conf, "Parallel Counting Driver running over input: " + input);
	    job.setJarByClass(PFGrowth_ParallelCounting.class);
	    job.setMapperClass(PFGrowth_ParallelCountingM.class);
	    job.setCombinerClass(PFGrowth_ParallelCountingR.class);
	    job.setReducerClass(PFGrowth_ParallelCountingR.class);
	    job.setOutputFormatClass(SequenceFileOutputFormat.class); // get rid of this line you can get the text file
	    job.setOutputKeyClass(Text.class);
	    job.setOutputValueClass(LongWritable.class);    
	    FileInputFormat.setInputPaths(job,new Path( input));
	    Path outPut=new Path(output,"parallelcounting");
	    HadoopUtil.delete(conf, outPut);
	    FileOutputFormat.setOutputPath(job, outPut);	    
	    boolean succeeded = job.waitForCompletion(true);
	    if (!succeeded) {
	      throw new IllegalStateException("Job failed!");
	    }	
		return succeeded;
	}
}

第一个MR的M:PFGrowth_ParallelCountingM.java:

package org.fansy.date1101.pfgrowth;
import java.io.IOException;
import java.util.regex.Pattern;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
public class PFGrowth_ParallelCountingM extends Mapper<LongWritable,Text,Text,LongWritable> {
	 private static final LongWritable ONE = new LongWritable(1);
	  private Pattern splitter=Pattern.compile("[ ,\t]*[ ,|\t][ ,\t]*");
	  @Override
	  protected void map(LongWritable offset, Text input, Context context) throws IOException,
	                                                                      InterruptedException {
	    String[] items = splitter.split(input.toString());
	    for (String item : items) {
	      if (item.trim().isEmpty()) {
	        continue;
	      }
	      context.setStatus("Parallel Counting Mapper: " + item);
	      context.write(new Text(item), ONE);
	    }
	  }  
}
上面的代码中的间隔符号修改了源码,加上了空格;

第一个MR的R:PFGrowth_ParallelCountingR.java:

package org.fansy.date1101.pfgrowth;
import java.io.IOException;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;
public class PFGrowth_ParallelCountingR extends Reducer<Text,LongWritable,Text,LongWritable>{
	protected void reduce(Text key, Iterable<LongWritable> values, Context context) throws IOException,
    		InterruptedException {
		long sum = 0;
		for (LongWritable value : values) {
		context.setStatus("Parallel Counting Reducer :" + key);
		sum += value.get();
		}
		context.setStatus("Parallel Counting Reducer: " + key + " => " + sum);
		context.write(key, new LongWritable(sum));
	}
}
其实第一个MR还是比较好理解的,M分解每个transaction的item,然后输出<item_id ,1>,然后R针对每个item_id 把value值相加求和,这个和wordcount的例子是一样的,当然这里也可以加combine操作的。

接着是总表的获得:

PFGrowth_Driver.java ,同时这个程序也调用第一个MR,也就是说可以直接运行这个文件就可以同时运行第一个MR和获得总表了。

package org.fansy.date1101.pfgrowth;
import java.io.IOException;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.Parameters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import com.google.common.collect.Lists;
class MyComparator implements Comparator<Pair<String,Long>>{
	 @Override
     public int compare(Pair<String,Long> o1, Pair<String,Long> o2) {
       int ret = o2.getSecond().compareTo(o1.getSecond());
       if (ret != 0) {
         return ret;
       }
       return o1.getFirst().compareTo(o2.getFirst());
     }	
}
public class PFGrowth_Driver {
	public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException{
		if(args.length!=3){
			System.out.println("wrong input args");
			System.out.println("usage: <intput><output><minsupport>");
			System.exit(-1);
		}
		// set parameters
		Parameters params=new Parameters();
		params.set("INPUT", args[0]);
		params.set("OUTPUT", args[1]);
		params.set("MIN_SUPPORT", args[2]);
		// get parameters
		String input=params.get("INPUT");
		String output=params.get("OUTPUT");
		//  run the first job
		PFGrowth_ParallelCounting ppc=new PFGrowth_ParallelCounting();
		ppc.runParallelCountingJob(input, output);	
		//  read input and set the fList
		 List<Pair<String,Long>> fList = readFList(params);
		 Configuration conf=new Configuration();
		 saveFList(fList, params, conf);		 
	}	
	/**
	   * Serializes the fList and returns the string representation of the List
	   * 
	   * @return Serialized String representation of List
	   */
	  public static void saveFList(Iterable<Pair<String,Long>> flist, Parameters params, Configuration conf)
	    throws IOException {
	    Path flistPath = new Path(params.get("OUTPUT"), "fList");
	    FileSystem fs = FileSystem.get(flistPath.toUri(), conf);
	    flistPath = fs.makeQualified(flistPath);
	    HadoopUtil.delete(conf, flistPath);
	    SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, flistPath, Text.class, LongWritable.class);
	    try {
	      for (Pair<String,Long> pair : flist) {
	        writer.append(new Text(pair.getFirst()), new LongWritable(pair.getSecond()));
	      }
	    } finally {
	      writer.close();
	    }
	    DistributedCache.addCacheFile(flistPath.toUri(), conf);
	  }
	public static List<Pair<String,Long>> readFList(Parameters params) {
	    int minSupport = Integer.valueOf(params.get("MIN_SUPPORT"));
	    Configuration conf = new Configuration();    
	    Path parallelCountingPath = new Path(params.get("OUTPUT"),"parallelcounting");
	    //  add MyComparator
	    PriorityQueue<Pair<String,Long>> queue = new PriorityQueue<Pair<String,Long>>(11,new MyComparator());
	    // sort according to the occur times from large to small 

	    for (Pair<Text,LongWritable> record
	         : new SequenceFileDirIterable<Text,LongWritable>(new Path(parallelCountingPath, "part-*"),
	                                                        PathType.GLOB, null, null, true, conf)) {
	      long value = record.getSecond().get();
	      if (value >= minSupport) {   // get rid of the item which is below the minimum support
	        queue.add(new Pair<String,Long>(record.getFirst().toString(), value));
	      }
	    }
	    List<Pair<String,Long>> fList = Lists.newArrayList();
	    while (!queue.isEmpty()) {
	      fList.add(queue.poll());
	    }
	    return fList;
	  }	
}

第一个MR运行完毕后,调用readFList()函数,把第一个MR的输出按照item出现的次数从大到小放入一个列表List中,然后调用saveFList()函数把上面求得的List存入HDFS文件中,不过存入的格式是被序列话的,可以另外编写函数查看文件是否和自己的假设相同;

FList 文件反序列化如下:

mahout关联规则源码分析 Part 1_第2张图片





分享,快乐,成长





你可能感兴趣的:(Mahout)