Java Code for the Apriori Algorithm on an Older Version of Hadoop

Apriori is a classic algorithm for association rule mining. Its logic is simple and well documented elsewhere, so it is not repeated here. The core of implementing Apriori on Hadoop is iteration: in each round the Map phase recognizes the candidate itemsets, the Combine and Reduce phases count how often each candidate occurs, and the itemsets meeting the minimum support are finally written out.
The process can be understood as follows (a small sketch of the candidate join in step 3 is given right after the list):

  1. Scan the input file and identify all candidate 1-itemsets A, B, C, ...
  2. The Reduce phase counts the frequency of candidates such as A, B, C and outputs the itemsets whose count is no less than the minimum support count.
  3. Once that MapReduce round has finished and its output has been saved, start the next Map phase: join the itemsets output by the previous round into new candidates such as AB, AC, ..., and scan the source data again, much like the Map phase of WordCount.
  4. The Combine and Reduce phases do a simple frequency count and output the itemsets that meet the minimum support.
  5. Repeat steps 3-4 until the output is empty.
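To make step 3 concrete, here is a minimal, self-contained sketch of the candidate join (an illustrative class of my own, not part of the project code): two frequent k-itemsets whose union contains exactly one extra item yield a (k+1)-candidate. The connectRecord method of the helper class below performs the same join on the itemsets saved by the previous round.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.TreeSet;

public class JoinSketch {

    // Join sorted frequent k-itemsets into (k+1)-candidates (illustration only).
    public static List<List<String>> join(List<List<String>> frequent) {
        List<List<String>> candidates = new ArrayList<List<String>>();
        for (int i = 0; i < frequent.size() - 1; i++) {
            for (int j = i + 1; j < frequent.size(); j++) {
                TreeSet<String> union = new TreeSet<String>(frequent.get(i));
                union.addAll(frequent.get(j));
                // Keep the union only if it adds exactly one new item and is not a duplicate.
                if (union.size() == frequent.get(i).size() + 1) {
                    List<String> candidate = new ArrayList<String>(union);
                    if (!candidates.contains(candidate)) {
                        candidates.add(candidate);
                    }
                }
            }
        }
        return candidates;
    }

    public static void main(String[] args) {
        // Frequent 1-itemsets {a}, {b}, {c} produce candidates {a,b}, {a,c}, {b,c}.
        List<List<String>> l1 = Arrays.asList(
                Arrays.asList("a"), Arrays.asList("b"), Arrays.asList("c"));
        System.out.println(join(l1)); // [[a, b], [a, c], [b, c]]
    }
}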

Following this logic, the implementation is given below. First comes the helper class Assitance.java, which joins itemsets and saves the output of each round:

package myapriori;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.LineReader;

public class Assitance {

    // Read the frequent itemsets saved by the previous round (a single file, or every
    // file in a directory) and join them into the next round's candidate itemsets.
    public static List<List<String>> getNextRecord(String nextrecord,
            String isDirectory) {
        boolean isdy = false;
        if (isDirectory.equals("true")) {
            isdy = true;
        }
        List<List<String>> result = new ArrayList<List<String>>();

        try {
            Path path = new Path(nextrecord);

            Configuration conf = new Configuration();

            FileSystem fileSystem = path.getFileSystem(conf);

            if (isdy) {
                FileStatus[] listFile = fileSystem.listStatus(path);
                for (int i = 0; i < listFile.length; i++) {
                    result.addAll(getNextRecord(listFile[i].getPath()
                            .toString(), "false"));
                }
                return result;
            }

            FSDataInputStream fsis = fileSystem.open(path);
            LineReader lineReader = new LineReader(fsis, conf);

            Text line = new Text();
            while (lineReader.readLine(line) > 0) {
                List<String> tempList = new ArrayList<String>();

                // Each saved line looks like "[a, b]<TAB>count"; keep only the part
                // before "]", strip brackets and tabs, then split the items on commas.
                String[] fields = line.toString()
                        .substring(0, line.toString().indexOf("]"))
                        .replaceAll("\\[", "").replaceAll("\\]", "")
                        .replaceAll("\t", "").split(",");
                for (int i = 0; i < fields.length; i++) {
                    tempList.add(fields[i].trim());
                }
                Collections.sort(tempList);
                result.add(tempList);

            }
            lineReader.close();
            result = connectRecord(result);
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }


        return result;
    }

    private static List<List<String>> connectRecord(List<List<String>> result) {

        List<List<String>> nextCandidateItemset = new ArrayList<List<String>>();
        for (int i = 0; i < result.size() - 1; i++) {

            HashSet<String> hsSet = new HashSet<String>();
            HashSet<String> hsSettemp = new HashSet<String>();
            for (int k = 0; k < result.get(i).size(); k++)
                // Row i of the frequent itemsets
                hsSet.add(result.get(i).get(k));
            int hsLength_before = hsSet.size(); // size before joining
            hsSettemp = (HashSet<String>) hsSet.clone();
            for (int h = i + 1; h < result.size(); h++) {
                // Join row i with row h (h > i); each join that adds exactly one
                // element forms one row of the new candidate itemsets.
                hsSet = (HashSet<String>) hsSettemp.clone(); // keep the base set unchanged
                for (int j = 0; j < result.get(h).size(); j++)
                    hsSet.add(result.get(h).get(j));
                int hsLength_after = hsSet.size();
                if (hsLength_before + 1 == hsLength_after
                        && isnotHave(hsSet, nextCandidateItemset)
                        && isSubSet(hsSet, result)) {
                    // Exactly one new element was added; keep the candidate if it is
                    // not already present and all of its subsets are frequent.
                    List<String> tempList = new ArrayList<String>(hsSet);
                    Collections.sort(tempList);
                    nextCandidateItemset.add(tempList);
                }

            }

        }
        return nextCandidateItemset;
    }

    private static boolean isSubSet(HashSet<String> hsSet,
            List<List<String>> result) {
        // Convert the candidate to a sorted list.
        List<String> tempList = new ArrayList<String>(hsSet);
        Collections.sort(tempList);

        // Build every subset obtained by removing one element from the candidate.
        List<List<String>> sublist = new ArrayList<List<String>>();
        for (int i = 0; i < tempList.size(); i++) {
            List<String> temp = new ArrayList<String>(tempList);
            temp.remove(i);
            sublist.add(temp);
        }
        // Apriori pruning: keep the candidate only if every such subset
        // is itself a frequent itemset from the previous round.
        return result.containsAll(sublist);
    }

    private static boolean isnotHave(HashSet<String> hsSet,
            List<List<String>> nextCandidateItemset) {
        // Returns true if this candidate has not been generated already.
        List<String> tempList = new ArrayList<String>(hsSet);
        Collections.sort(tempList);
        return !nextCandidateItemset.contains(tempList);
    }

    public static boolean SaveNextRecords(String outfile, String savefile, int count) {
        // Copy the job output into HDFS under savefile + count; return false when
        // the output was empty, which ends the outer iteration.
        boolean finish = false;
        try {
            Configuration conf = new Configuration();


            Path rPath = new Path(savefile+"/num"+count+"frequeceitems.data");
            FileSystem rfs = rPath.getFileSystem(conf);            
            FSDataOutputStream out = rfs.create(rPath);

            Path path = new Path(outfile);          
            FileSystem fileSystem = path.getFileSystem(conf);
            FileStatus[] listFile = fileSystem.listStatus(path);
            for (int i = 0; i < listFile.length; i++){
                 FSDataInputStream in = fileSystem.open(listFile[i].getPath());
                 //FSDataInputStream in2 = fileSystem.open(listFile[i].getPath());
                 int byteRead = 0;
                 byte[] buffer = new byte[256];
                 while ((byteRead = in.read(buffer)) > 0) {
                        out.write(buffer, 0, byteRead);
                        finish = true;
                    }
                 in.close();


            }
            out.close();
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        // After saving, delete the temporary output directory so the next round can reuse the path.
        try {
            deletePath(outfile);
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }       
        return finish;
    }
    public static void deletePath(String pathStr) throws IOException{
        Configuration conf = new Configuration();
        Path path = new Path(pathStr);
        FileSystem hdfs = path.getFileSystem(conf);
        hdfs.delete(path ,true);

    }

}


The above is the helper class. Next comes the MapReduce driver, AprioriMapReduce.java, which runs the rounds in a loop.

package myapriori;

import java.io.*;
import java.util.*;

import org.apache.hadoop.fs.Path;
import org.apache.hadoop.conf.*;
import org.apache.hadoop.io.*;
import org.apache.hadoop.mapred.*;
import org.apache.hadoop.util.*;

public class AprioriMapReduce extends Configured implements Tool {

    public static class Map extends MapReduceBase implements
            Mapper<LongWritable, Text, Text, IntWritable> {

        public List<List<String>> nextrecords = null;
        public ArrayList<String> allitems = new ArrayList<String>();

        private final static IntWritable one = new IntWritable(1);
        private Text word = new Text();
        private static String count = null;

        public void configure(JobConf job) {
            // "map.record.file" points at the itemsets saved by the previous round;
            // "map.record.isDirectory" is "true" only on the first pass over the raw input.
            String record = job.get("map.record.file");
            String isDirectory = job.get("map.record.isDirectory");
            // "count" reuses the isDirectory flag to switch the map() logic below.
            count = job.get("map.record.isDirectory");

            if (!isDirectory.equals("true")) {
                // Join the previous round's frequent itemsets into the new candidates.
                nextrecords = Assitance.getNextRecord(record, isDirectory);
            }

            if (nextrecords == null) {
                nextrecords = new ArrayList<List<String>>();
            }
            if (nextrecords.isEmpty()) {
                List<String> finish = new ArrayList<String>();
                finish.add("null");
                nextrecords.add(finish);
            }

        }

        @Override
        public void map(LongWritable key, Text value,
                OutputCollector<Text, IntWritable> output, Reporter report)
                throws IOException {
            // Each input line is "id<TAB>item1,item2,..."; drop the id column first.
            String line = value.toString().toLowerCase();
            int tcount = line.indexOf("\t");
            if (tcount >= 0) {
                line = line.substring(tcount).trim().replaceAll("\t", "");
            }
            String[] dd = line.split(",");
            if (!count.equals("false")) {
                // First pass: emit every single item as a 1-itemset candidate.
                for (String sd : dd) {
                    List<String> dstr = new ArrayList<String>();
                    dstr.add(sd);
                    word = new Text(dstr.toString());
                    output.collect(word, one);
                }
            } else {
                // Later passes: emit every candidate itemset contained in this transaction.
                List<String> dstr = new ArrayList<String>();
                for (String ss : dd) {
                    dstr.add(ss);
                }
                for (int i = 0; i < nextrecords.size(); i++) {
                    if (dstr.containsAll(nextrecords.get(i))) {
                        word = new Text(nextrecords.get(i).toString());
                        output.collect(word, one);
                    }
                }
            }
        }

    }

    public static class Combine extends MapReduceBase implements
            Reducer<Text, IntWritable, Text, IntWritable> {

        @Override
        public void reduce(Text key, Iterator<IntWritable> values,
                OutputCollector<Text, IntWritable> output, Reporter report)
                throws IOException {
            int sum = 0;
            while (values.hasNext()) {
                sum += values.next().get();
            }

            output.collect(key, new IntWritable(sum));

        }

    }

    public static class Reduce extends MapReduceBase implements
            Reducer<Text, IntWritable, Text, IntWritable> {
        private static int minnum = 0;

        @Override
        public void reduce(Text key, Iterator<IntWritable> values,
                OutputCollector<Text, IntWritable> output, Reporter report)
                throws IOException {
            int sum = 0;
            while (values.hasNext()) {
                sum += values.next().get();
            }
            if (sum >= minnum) {
                output.collect(key, new IntWritable(sum));
            }

        }

        public void configure(JobConf job) {
            // Minimum support count passed in from the command line.
            minnum = Integer.parseInt(job.get("map.record.supportnum"));
        }

    }

    @Override
    public int run(String[] args) throws Exception {
        JobConf conf = new JobConf(getConf(), AprioriMapReduce.class);
        conf.setJobName("apriori");

        conf.setMapperClass(Map.class);
        conf.setMapOutputKeyClass(Text.class);
        conf.setMapOutputValueClass(IntWritable.class);

        // conf.setCombinerClass(Reduce.class);

        conf.setCombinerClass(Combine.class);

        conf.setReducerClass(Reduce.class);
        conf.setOutputKeyClass(Text.class);
        conf.setOutputValueClass(IntWritable.class);

        // conf.setInputFormat(TextInputFormat.class);
        // conf.setOutputFormat(TextOutputFormat.class);

        FileInputFormat.setInputPaths(conf, new Path(args[0]));
        FileOutputFormat.setOutputPath(conf, new Path(args[1]));
        conf.set("map.items.file", args[0]);
        conf.set("map.record.file", args[5]);
        conf.set("map.record.supportnum", args[3]);
        conf.set("map.record.isDirectory", args[4]);

        JobClient.runJob(conf);
        return 0;
    }

    public static void main(String[] args) throws Exception {
        int res = 0;
        /*
         * Alternative: derive the minimum support count from a support ratio, e.g.
         * int itemsnum = Utils.CountItemsNum(args[0], args[args.length - 1]);
         * if (itemsnum < 1) { System.out.println("Invalid input file!"); System.exit(res); }
         * Float minnum = itemsnum * Float.parseFloat(args[args.length - 2]);
         */
        if (args.length < 5) {
            System.err.println("usage: <input> <output> <savefile> <min support num> true");
            System.exit(res);
        }

        int count = 0;
        boolean target = true;
        String lastarg[] = new String[args.length+1];
        for (int i = 0; i < args.length; i++) {
            lastarg[i] = args[i];
        }
        while (target) {
            if (count == 0) {
                // First round: the mapper scans the raw input and counts single items.
                lastarg[args.length] = args[0];
            } else {
                // Later rounds: feed in the frequent itemsets saved by the previous round.
                lastarg[args.length] = args[2] + "/num" + count + "frequeceitems.data";
            }

            count++;
            res = ToolRunner.run(new Configuration(), new AprioriMapReduce(),
                    lastarg);
            // Save this round's output; the loop stops once nothing meets the support.
            target = Assitance.SaveNextRecords(args[1], args[2], count);

            // From the second round on, the record file is a single file, not a directory.
            lastarg[4] = "false";

        }
        System.exit(res);
    }

}

Based on the driver, the job is invoked in the following form:

~> bin/hadoop jar myapriori.jar myapriori.AprioriMapReduce <hdfs path of the source data file (format shown below)> <hdfs path for the temporary job output> <hdfs path where the frequent itemsets are saved> <minimum support count> true   <- this trailing "true" is required

When actually running the command, the angle brackets above are not needed; they are only added to make the arguments easier to understand. A concrete example follows.
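For example, with hypothetical HDFS paths (adjust them to your own cluster) and a minimum support count of 3, the invocation might look like:

~> bin/hadoop jar myapriori.jar myapriori.AprioriMapReduce /user/demo/apriori/input /user/demo/apriori/tmpout /user/demo/apriori/freqitems 3 true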
The input file has the following format:

id goodsid


id1 apple,orange,banana,cola
id2 apple,strawberry
id3 banana,yogurt
id4 cola,yogurt
id5 orange,cola
... ...


The id and the itemset are separated by a tab, and the items within an itemset are separated by commas.
There is also a handy way to pull data in exactly this shape straight out of a structured data store (for example, Hive):

select id, concat_ws(',', collect_set(goodid))
from xxx
group by id
The concat_ws / collect_set combination performs a row-to-column transformation, collapsing the multiple (id, goodid) rows into a single comma-separated row per id.
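For instance, if the table xxx (the placeholder name in the query above) stores one (id, goodid) pair per row, the query collapses the rows per id:

rows in xxx:
id1 apple
id1 orange
id2 apple
id2 strawberry

query result:
id1 apple,orange
id2 apple,strawberry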

That is the whole implementation together with the supporting tricks. When the number of candidate itemsets grows beyond roughly 1,000, memory demand becomes high (every candidate is matched against every transaction in the mapper), so it is best to filter the candidates carefully; with a smaller candidate set the job also finishes quickly.

