Association rules over two-item sets: a Hadoop implementation

I have recently been reading Mahout's association-rule source code and finding it quite a headache. I originally planned to write a series analyzing that source, but further in it got rather tangled (it is admittedly somewhat complex), so I decided to first implement the simplest case: association rules over two-item sets.

The idea of the algorithm again follows the picture from the previous post:

(Figure 1: the algorithm workflow diagram referenced above)

The implementation here is divided into five steps (a minimal single-machine sketch of the whole flow follows the list):

  1. Count the occurrences of each item in the raw input;
  2. Sort the items by count in descending order, dropping items whose count is below the threshold, to produce the frequence list file;
  3. Reorder and prune each raw input transaction according to the frequence list file;
  4. Generate the two-item-set rules;
  5. Count the occurrences of each two-item-set rule and drop the rules below the threshold.
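
Before diving into the MapReduce jobs, here is a minimal, single-machine sketch of the five steps on a few made-up transactions. The class name TwoItemsetDemo, the sample transactions, and the threshold of 2 are invented for illustration; they are not part of the Hadoop code below.

import java.util.*;

public class TwoItemsetDemo {
    public static void main(String[] args) {
        // three made-up transactions; items are comma-separated in the real input
        List<String[]> transactions = Arrays.asList(
                new String[]{"p", "x", "z", "y"},
                new String[]{"x", "y", "a"},
                new String[]{"p", "x", "y"});
        int minSupport = 2;   // hypothetical threshold

        // step 1: count every item
        Map<String, Integer> counts = new HashMap<>();
        for (String[] t : transactions)
            for (String item : t) counts.merge(item, 1, Integer::sum);

        // step 2: keep items at or above the threshold, ordered by descending count (the frequence list)
        List<String> fList = new ArrayList<>();
        counts.entrySet().stream()
              .filter(e -> e.getValue() >= minSupport)
              .sorted((a, b) -> b.getValue() - a.getValue())
              .forEach(e -> fList.add(e.getKey()));

        // steps 3-5: reorder/prune each transaction by the frequence list,
        // then emit and count every ordered pair, dropping rare pairs
        Map<String, Integer> pairCounts = new HashMap<>();
        for (String[] t : transactions) {
            Set<String> present = new HashSet<>(Arrays.asList(t));
            List<String> ordered = new ArrayList<>();
            for (String item : fList) if (present.contains(item)) ordered.add(item);
            for (int i = 0; i < ordered.size(); i++)
                for (int j = i + 1; j < ordered.size(); j++)
                    pairCounts.merge(ordered.get(i) + "," + ordered.get(j), 1, Integer::sum);
        }
        pairCounts.forEach((pair, c) -> { if (c >= minSupport) System.out.println(pair + "\t" + c); });
    }
}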

The first MapReduce job covers steps 1 and 2; the code is as follows:

GetFList.java:

package org.fansy.date1108.fpgrowth.twodimension;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.regex.Pattern;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

// comparator that orders "item<splitter>count" strings by count descending, ties broken by item
class MyComparator implements Comparator<String>{
    private String splitter=",";
    public MyComparator(String splitter){
        this.splitter=splitter;
    }
    @Override
    public int compare(String o1, String o2) {
        String[] str1=o1.split(splitter);
        String[] str2=o2.split(splitter);
        int num1=Integer.parseInt(str1[1]);
        int num2=Integer.parseInt(str2[1]);
        if(num1>num2){
            return -1;
        }else if(num1<num2){
            return 1;
        }else{
            return str1[0].compareTo(str2[0]);
        }
    }
}

public class GetFList {
    /**
     *  the program is based on the picture
     */
    // Mapper: emit <item, 1> for every item in a transaction (plain word count)
    public static class  MapperGF extends Mapper<LongWritable ,Text ,Text,IntWritable>{
        private Pattern splitter=Pattern.compile("[ ]*[ ,|\t]");
        private final IntWritable newvalue=new IntWritable(1);
        public void map(LongWritable key,Text value,Context context) throws IOException, InterruptedException{
            String [] items=splitter.split(value.toString());
            for(String item:items){
                context.write(new Text(item), newvalue);
            }
        }
    }
    // Reducer: sum the counts for each item
    public static class ReducerGF extends Reducer<Text,IntWritable,Text ,IntWritable>{
        public void reduce(Text key,Iterable<IntWritable> value,Context context) throws IOException, InterruptedException{
            int temp=0;
            for(IntWritable v:value){
                temp+=v.get();
            }
            context.write(key, new IntWritable(temp));
        }
    }
    public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
        if(args.length!=3){
            System.out.println("Usage: <input><output><min_support>");
            System.exit(1);
        }
        String input=args[0];
        String output=args[1];
        int minSupport=0;
        try {
            minSupport=Integer.parseInt(args[2]);
        } catch (NumberFormatException e) {
            // fall back to a default threshold when the argument is not a number
            minSupport=3;
        }
        Configuration conf=new Configuration();
        String temp=args[1]+"_temp";
        Job job=new Job(conf,"the get flist job");
        job.setJarByClass(GetFList.class);
        job.setMapperClass(MapperGF.class);
        job.setCombinerClass(ReducerGF.class);
        job.setReducerClass(ReducerGF.class);
        job.setOutputKeyClass(Text.class);
        job.setOutputValueClass(IntWritable.class);
        FileInputFormat.setInputPaths(job, new Path(input));
        FileOutputFormat.setOutputPath(job, new Path(temp));
        boolean succeed=job.waitForCompletion(true);
        if(succeed){
            //  read the temp output and write the data to the final output
            List<String> list=readFList(temp+"/part-r-00000",minSupport);
            System.out.println("the frequence list has generated ... ");
            // generate the frequence file
            generateFList(list,output);
            System.out.println("the frequence file has generated ... ");
        }else{
            System.out.println("the job is failed");
            System.exit(1);
        }
    }
    //  read the temp output and return the frequence list, ordered by count descending
    public static List<String> readFList(String input,int minSupport) throws IOException{
        // read the hdfs file
        Configuration conf=new Configuration();
        Path path=new Path(input);
        FileSystem fs=FileSystem.get(path.toUri(),conf);
        FSDataInputStream in1=fs.open(path);
        PriorityQueue<String> queue=new PriorityQueue<String>(15,new MyComparator("\t"));
        InputStreamReader isr1=new InputStreamReader(in1);
        BufferedReader br=new BufferedReader(isr1);
        String line;
        while((line=br.readLine())!=null){
            int num=0;
            try {
                num=Integer.parseInt(line.split("\t")[1]);
            } catch (NumberFormatException e) {
                num=0;
            }
            if(num>minSupport){
                queue.add(line);
            }
        }
        br.close();
        isr1.close();
        in1.close();
        List<String> list=new ArrayList<String>();
        while(!queue.isEmpty()){
            list.add(queue.poll());
        }
        return list;
    }
    // write the frequence list to the frequence file on HDFS
    public static void generateFList(List<String> list,String output) throws IOException{
        Configuration conf=new Configuration();
        Path path=new Path(output);
        FileSystem fs=FileSystem.get(path.toUri(),conf);
        FSDataOutputStream writer=fs.create(path);
        Iterator<String> i=list.iterator();
        while(i.hasNext()){
            writer.writeBytes(i.next()+"\n");  // the final line also gets a trailing \n, which ideally would be omitted
        }
        writer.close();
    }
}

Step 1 is essentially the simplest word-count program, while step 2 involves reading from and writing to HDFS. When sorting to generate the frequence list file, a PriorityQueue is used, together with a custom class (MyComparator) that defines the ordering rule.
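
As a minimal sketch of that ordering (assuming it lives in the same package as MyComparator above; the item names and counts are made up), polling the queue yields counts in descending order with ties broken alphabetically:

import java.util.PriorityQueue;

public class FListOrderDemo {
    public static void main(String[] args) {
        // entries are "item\tcount" lines, exactly as produced by the word-count job
        PriorityQueue<String> queue = new PriorityQueue<String>(15, new MyComparator("\t"));
        queue.add("bread\t4");
        queue.add("milk\t7");
        queue.add("beer\t4");
        while (!queue.isEmpty()) {
            System.out.println(queue.poll());   // milk 7, then beer 4, then bread 4
        }
    }
}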

The second MapReduce job covers step 3; the code is as follows:

SortAndCut.java:

package org.fansy.date1108.fpgrowth.twodimension;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.regex.Pattern;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

public class SortAndCut {
    /**
     *  sort and cut the items
     */
    public static class M extends Mapper<LongWritable,Text,NullWritable,Text>{
        // keeps the frequence list in its original (descending-count) order
        private LinkedHashSet<String> list=new LinkedHashSet<String>();
        private Pattern splitter=Pattern.compile("[ ]*[ ,|\t]");

        public void setup(Context context) throws IOException{
            // read the frequence file from HDFS once per mapper
            String input=context.getConfiguration().get("FLIST");
            FileSystem fs=FileSystem.get(URI.create(input),context.getConfiguration());
            Path path=new Path(input);
            FSDataInputStream in1=fs.open(path);
            InputStreamReader isr1=new InputStreamReader(in1);
            BufferedReader br=new BufferedReader(isr1);
            String line;
            while((line=br.readLine())!=null){
                String[] str=line.split("\t");
                if(str.length>0){
                    list.add(str[0]);
                }
            }
        }
        // map: output the transaction's items in frequence-list order, dropping infrequent items
        public void map(LongWritable key,Text value,Context context) throws IOException, InterruptedException{
            String [] items=splitter.split(value.toString());
            Set<String> set=new HashSet<String>();
            set.clear();
            for(String s:items){
                set.add(s);
            }
            Iterator<String> iter=list.iterator();
            StringBuffer sb=new StringBuffer();
            sb.setLength(0);
            int num=0;
            while(iter.hasNext()){
                String item=iter.next();
                if(set.contains(item)){
                    sb.append(item+",");
                    num++;
                }
            }
            if(num>0){
                context.write(NullWritable.get(), new Text(sb.toString()));
            }
        }
    }
    public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
        if(args.length!=3){
            System.out.println("Usage: <input><output><fListPath>");
            System.exit(1);
        }
        String input=args[0];
        String output=args[1];
        String fListPath=args[2];
        Configuration conf=new Configuration();
        conf.set("FLIST", fListPath);
        Job job=new Job(conf,"the sort and cut the items job");
        job.setJarByClass(SortAndCut.class);
        job.setMapperClass(M.class);
        job.setNumReduceTasks(0);   // map-only job
        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(Text.class);
        FileInputFormat.setInputPaths(job, new Path(input));
        FileOutputFormat.setOutputPath(job, new Path(output));
        boolean succeed=job.waitForCompletion(true);
        if(succeed){
            System.out.println(job.getJobName()+" succeed ... ");
        }
    }
}

In this stage, the Mapper's setup reads the frequence file into a LinkedHashSet (which preserves the original insertion order). In map, for each transaction the LinkedHashSet is walked in order, but only the items that actually appear in that transaction are output, so every emitted transaction ends up sorted by frequency and pruned of infrequent items.
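
A minimal single-machine sketch of that per-transaction sort-and-prune; the f-list order and the transaction here are made up:

import java.util.*;

public class SortAndCutDemo {
    public static void main(String[] args) {
        // hypothetical frequence-list order: z, x, p, y (descending count)
        LinkedHashSet<String> fList = new LinkedHashSet<>(Arrays.asList("z", "x", "p", "y"));
        String[] transaction = {"y", "q", "p", "z"};   // q is infrequent and will be dropped
        Set<String> present = new HashSet<>(Arrays.asList(transaction));
        StringBuilder sb = new StringBuilder();
        for (String item : fList) {
            if (present.contains(item)) {
                sb.append(item).append(",");
            }
        }
        System.out.println(sb);   // prints "z,p,y," — f-list order, infrequent item removed
    }
}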

The third MapReduce job covers steps 4 and 5; the code is as follows:

OutRules.java:

package org.fansy.date1108.fpgrowth.twodimension;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;
import java.util.Stack;
import java.util.TreeSet;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

public class OutRules {

    public static class M extends Mapper<LongWritable,Text,Text,Text>{
        public void map(LongWritable key,Text value,Context context) throws IOException, InterruptedException{
            String str=value.toString();
            String[] s=str.split(",");
            if(s.length<=1){
                return;
            }
            Stack<String> stack=new Stack<String>();
            for(int i=0;i<s.length;i++){
                stack.push(s[i]);
            }
            // emit <last item, everything before it>, peeling one item off the end each time
            int num=str.length();
            while(stack.size()>1){
                String item=stack.pop();
                num=num-item.length()-1;   // drop the popped item and its separator (also correct for multi-character items)
                context.write(new Text(item),new Text(str.substring(0,num)));
            }
        }
    }
    // Reducer: count how often each item co-occurs with the key and keep the frequent ones
    public static class R extends Reducer<Text ,Text,Text,Text>{
        private int minConfidence=0;
        public void setup(Context context){
            String str=context.getConfiguration().get("MIN");
            try {
                minConfidence=Integer.parseInt(str);
            } catch (NumberFormatException e) {
                // fall back to a default threshold when the argument is not a number
                minConfidence=3;
            }
        }
        public void reduce(Text key,Iterable<Text> values,Context context) throws IOException, InterruptedException{
            HashMap<String,Integer> hm=new HashMap<String ,Integer>();
            for(Text v:values){
                String[] str=v.toString().split(",");
                for(int i=0;i<str.length;i++){
                    if(hm.containsKey(str[i])){
                        int temp=hm.get(str[i]);
                        hm.put(str[i], temp+1);
                    }else{
                        hm.put(str[i], 1);
                    }
                }
            }
            //  end of counting; now order the items by count in descending order
            TreeSet<String> sss=new TreeSet<String>(new MyComparator(" "));
            Iterator<Entry<String,Integer>> iter=hm.entrySet().iterator();
            while(iter.hasNext()){
                Entry<String,Integer> k=iter.next();
                if(k.getValue()>minConfidence&&!key.toString().equals(k.getKey())){
                    sss.add(k.getKey()+" "+k.getValue());
                }
            }
            Iterator<String> iters=sss.iterator();
            StringBuffer sb=new StringBuffer();
            while(iters.hasNext()){
                sb.append(iters.next()+"|");
            }
            context.write(key, new Text(":\t"+sb.toString()));
        }
    }
    public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
        if(args.length!=3){
            System.out.println("Usage: <input><output><min_confidence>");
            System.exit(1);
        }
        String input=args[0];
        String output=args[1];
        String minConfidence=args[2];
        Configuration conf=new Configuration();
        conf.set("MIN", minConfidence);
        Job job=new Job(conf,"the out rules job");
        job.setJarByClass(OutRules.class);
        job.setMapperClass(M.class);
        job.setNumReduceTasks(1);
        job.setReducerClass(R.class);
        job.setOutputKeyClass(Text.class);
        job.setOutputValueClass(Text.class);
        FileInputFormat.setInputPaths(job, new Path(input));
        FileOutputFormat.setOutputPath(job, new Path(output));
        boolean succeed=job.waitForCompletion(true);
        if(succeed){
            System.out.println(job.getJobName()+" succeed ... ");
        }
    }
}

In the map phase, a Stack and string operations are used to achieve something like the following:

input:p,x,z,y,a,b
output:
b:p,x,z,y,a
a:p,x,z,y
y:p,x,z
z:p,x
x:p
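
The same peeling can be sketched outside Hadoop. This hypothetical snippet uses the length-based bookkeeping from the mapper, so multi-character items (such as the numeric IDs in the sample output further down) are cut off cleanly:

import java.util.Stack;

public class PrefixEmitDemo {
    public static void main(String[] args) {
        String line = "39,48,32,41,";              // a pruned transaction with a trailing comma
        String[] items = line.split(",");
        Stack<String> stack = new Stack<>();
        for (String s : items) stack.push(s);
        int end = line.length();
        while (stack.size() > 1) {
            String item = stack.pop();
            end -= item.length() + 1;              // drop the popped item and its comma
            System.out.println(item + "\t" + line.substring(0, end));
        }
        // prints: 41 -> 39,48,32,   32 -> 39,48,   48 -> 39,
    }
}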

The reduce phase simply counts how many times each item appears alongside the key, using a HashMap. Since it is nicer if the output is sorted by these counts in descending order, a TreeSet (with the MyComparator defined earlier) is used as well.
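
A minimal sketch of that reduce-side counting and ordering, again assuming it sits in the same package as MyComparator; the key "b" and the antecedent lists are made up:

import java.util.*;

public class RuleCountDemo {
    public static void main(String[] args) {
        // hypothetical antecedent lists received for key "b"
        String[] values = {"p,x,z", "p,x", "x,z"};
        int minConfidence = 1;

        Map<String, Integer> counts = new HashMap<>();
        for (String v : values)
            for (String item : v.split(","))
                counts.merge(item, 1, Integer::sum);

        // TreeSet ordered by count descending, as in the reducer
        TreeSet<String> ordered = new TreeSet<>(new MyComparator(" "));
        for (Map.Entry<String, Integer> e : counts.entrySet())
            if (e.getValue() > minConfidence)
                ordered.add(e.getKey() + " " + e.getValue());

        System.out.println("b :\t" + String.join("|", ordered));   // e.g. x 3|p 2|z 2
    }
}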

The formats of all the output files above are just string concatenation, so they can be changed to whatever layout you need.

For example, my output looks like this:

0   :   39 125|48 99|32 44|41 37|38 26|310 17|5 16|65 14|1 13|89 13|1144 12|225 12|60 11|604 11|
1327 10|237 10|101 9|147 9|270 9|533 9|9 9|107 8|11 8|117 8|170 8|271 8|334 8|549 8|62 8|812 8|10 7|
1067 7|12925 7|23 7|255 7|279 7|548 7|783 7|14098 6|2 6|208 6|22 6|36 6|413 6|789 6|824 6|961 6|110 5|
120 5|12933 5|201 5|2238 5|2440 5|2476 5|251 5|286 5|2879 5|3 5|4105 5|415 5|438 5|467 5|475 5|479 5|49 5|
592 5|675 5|715 5|740 5|791 5|830 5|921 5|9555 5|976 5|979 5|1001 4|1012 4|1027 4|1055 4|1146 4|12 4|13334 4|
136 4|1393 4|16 4|1600 4|165 4|167 4|1819 4|1976 4|2051 4|2168 4|2215 4|2284 4|2353 4|2524 4|261 4|267 4|269 4|
27 4|2958 4|297 4|3307 4|338 4|420 4|4336 4|4340 4|488 4|4945 4|5405 4|58 4|589 4|75 4|766 4|795 4|809 4|880 4|8978 4|916 4|94 4|956 4|

The item before the colon is the first element of the rule; after the colon, 39 is the second item and the number that follows it is how often <0,39> appears, i.e. 125 times; likewise <0,48> appears 99 times.
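
If you want to read such a line back into explicit pairs, a hypothetical parse (the line content is shortened from the sample above) could look like this:

public class ParseRuleLine {
    public static void main(String[] args) {
        // TextOutputFormat writes key, a tab, then the ":\t..." value
        String line = "0\t:\t39 125|48 99|32 44|";
        String[] kv = line.split(":\t");
        String left = kv[0].trim();                          // the rule's first item, "0"
        for (String entry : kv[1].split("\\|")) {
            if (entry.isEmpty()) continue;
            String[] p = entry.split(" ");
            System.out.println("<" + left + "," + p[0] + "> appears " + p[1] + " times");
        }
    }
}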

To sum up, Mahout's source code really is hard to chew through, so it helps to become thoroughly familiar with the algorithm first; reading the source afterwards should then be somewhat easier.

http://blog.csdn.net/fansy1990/article/details/8160956
