FPGrowth Implementation

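In association rule mining, the most classic algorithm is Apriori. Its fatal weakness is that it has to scan the transaction database many times. Various methods have since been proposed for pruning the dataset to reduce I/O cost, and Professor Jiawei Han's FP-Tree algorithm is one of the most efficient of them.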

Support and Confidence

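Strictly speaking, both Apriori and FP-Tree are algorithms for finding frequent itemsets, that is, itemsets with a relatively high "support". The concepts of support and confidence are explained below.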

Suppose the transaction database is:

A  E  F  G

A  F  G

A  B  E  F  G

E  F  G

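The support count of {A,F,G} is 3, so its support is 3/4.

The support count of {F,G} is 4, so its support is 4/4.

The support count of {A} is 3, so its support is 3/4.

The confidence of {F,G}=>{A} is the support count of {A,F,G} divided by the support count of {F,G}, i.e. 3/4.

The confidence of {A}=>{F,G} is the support count of {A,F,G} divided by the support count of {A}, i.e. 3/3.

Strong association rule mining looks for all patterns whose support reaches a minimum threshold and whose confidence reaches a given threshold.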
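The following small Java sketch makes the definitions concrete by recomputing the numbers above from the four-transaction example. The class and method names (`SupportConfidence`, `supportCount`, `support`, `confidence`) are made up for illustration. Transactions are modeled as `Set<String>`.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical helper: support count, support and confidence over in-memory transactions.
class SupportConfidence {

    // Number of transactions that contain every item of the itemset.
    static int supportCount(List<Set<String>> db, Set<String> itemset) {
        int n = 0;
        for (Set<String> t : db) {
            if (t.containsAll(itemset)) {
                n++;
            }
        }
        return n;
    }

    // Support as a fraction of all transactions.
    static double support(List<Set<String>> db, Set<String> itemset) {
        return (double) supportCount(db, itemset) / db.size();
    }

    // confidence(lhs => rhs) = supportCount(lhs and rhs together) / supportCount(lhs)
    static double confidence(List<Set<String>> db, Set<String> lhs, Set<String> rhs) {
        Set<String> both = new HashSet<>(lhs);
        both.addAll(rhs);
        return (double) supportCount(db, both) / supportCount(db, lhs);
    }

    public static void main(String[] args) {
        List<Set<String>> db = List.of(
                Set.of("A", "E", "F", "G"),
                Set.of("A", "F", "G"),
                Set.of("A", "B", "E", "F", "G"),
                Set.of("E", "F", "G"));
        System.out.println(supportCount(db, Set.of("A", "F", "G")));       // 3
        System.out.println(support(db, Set.of("F", "G")));                 // 1.0
        System.out.println(confidence(db, Set.of("F", "G"), Set.of("A"))); // 0.75
        System.out.println(confidence(db, Set.of("A"), Set.of("F", "G"))); // 1.0
    }
}
```

Note that confidence is not symmetric. {F,G}=>{A} and {A}=>{F,G} share the same numerator but divide by different support counts.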

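FP-Tree Algorithm

Let's walk through an example to explain the complete FP-Tree algorithm in detail.

The transaction database is as follows; each line is one shopping record:

milk,eggs,bread,chips

eggs,popcorn,chips,beer

eggs,bread,chips

milk,eggs,bread,popcorn,chips,beer

milk,bread,beer

eggs,bread,beer

milk,bread,chips

milk,eggs,bread,butter,chips

milk,eggs,butter,chips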

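Our goal is to find which items tend to be bought together. For example, if people who buy chips usually also buy eggs, then [chips, eggs] is a frequent pattern.

Step 1 of the FP-Tree algorithm: scan the transaction database, sort the items by descending frequency, and remove every item whose frequency is below the minimum support MinSup. (This is the first scan of the database.)

chips:7  eggs:7  bread:7  milk:6  beer:4        (here MinSup = 3)

This result is the set of frequent 1-itemsets, denoted F1. Popcorn and butter occur only twice each, so they are removed.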
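Here is a minimal sketch of step 1. It assumes the transactions are held in memory as lists of item names, with English names standing in for the original items. The names `FrequentOneItemsets`, `f1` and `MARKET` are hypothetical. The article does not say how ties among the three items with count 7 are ordered (in the code below that order comes from `HashMap` iteration). This sketch breaks ties alphabetically, so it yields bread, chips, eggs rather than chips, eggs, bread. Any fixed tie order works, as long as every record is later sorted with the same order.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical sketch of step 1: count each item, keep those with count >= minSup,
// and sort by descending count (ties alphabetical, an assumption of this sketch).
class FrequentOneItemsets {

    static List<Map.Entry<String, Integer>> f1(List<List<String>> db, int minSup) {
        Map<String, Integer> counts = new TreeMap<>();      // keys kept in alphabetical order
        for (List<String> record : db) {
            for (String item : record) {
                counts.merge(item, 1, Integer::sum);
            }
        }
        List<Map.Entry<String, Integer>> result = new ArrayList<>();
        for (Map.Entry<String, Integer> e : counts.entrySet()) {
            if (e.getValue() >= minSup) {
                result.add(Map.entry(e.getKey(), e.getValue()));
            }
        }
        // List.sort is stable, so equal counts stay in alphabetical order.
        result.sort((a, b) -> Integer.compare(b.getValue(), a.getValue()));
        return result;
    }

    static final List<List<String>> MARKET = List.of(
            List.of("milk", "eggs", "bread", "chips"),
            List.of("eggs", "popcorn", "chips", "beer"),
            List.of("eggs", "bread", "chips"),
            List.of("milk", "eggs", "bread", "popcorn", "chips", "beer"),
            List.of("milk", "bread", "beer"),
            List.of("eggs", "bread", "beer"),
            List.of("milk", "bread", "chips"),
            List.of("milk", "eggs", "bread", "butter", "chips"),
            List.of("milk", "eggs", "butter", "chips"));

    public static void main(String[] args) {
        System.out.println(f1(MARKET, 3)); // [bread=7, chips=7, eggs=7, milk=6, beer=4]
    }
}
```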

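Step 2: reorder the items of every purchase record according to the order in F1, dropping items that are not in F1. (This is the second and last scan of the database.)

chips,eggs,bread,milk

chips,eggs,beer

chips,eggs,bread

chips,eggs,bread,milk,beer

bread,milk,beer

eggs,bread,beer

chips,bread,milk

chips,eggs,bread,milk

chips,eggs,milk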
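Step 2 amounts to ranking each frequent item by its position in F1, then sorting every record by that rank and dropping items that are not in F1. The sketch below uses a hypothetical `ReorderByF1` helper. It builds the rank map once and reuses it for every record. F1 is passed in explicitly, in the article's order.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper for step 2: drop non-F1 items and sort the rest by F1 rank.
class ReorderByF1 {

    // Position of each frequent item in F1 (0 = most frequent); build once, reuse per record.
    static Map<String, Integer> ranks(List<String> f1Order) {
        Map<String, Integer> rank = new HashMap<>();
        for (int i = 0; i < f1Order.size(); i++) {
            rank.put(f1Order.get(i), i);
        }
        return rank;
    }

    static List<String> reorder(List<String> record, Map<String, Integer> rank) {
        List<String> out = new ArrayList<>();
        for (String item : record) {
            if (rank.containsKey(item)) {       // popcorn and butter are not in F1
                out.add(item);
            }
        }
        out.sort((a, b) -> Integer.compare(rank.get(a), rank.get(b)));
        return out;
    }

    public static void main(String[] args) {
        Map<String, Integer> rank = ranks(List.of("chips", "eggs", "bread", "milk", "beer"));
        System.out.println(reorder(List.of("milk", "eggs", "bread", "popcorn", "chips", "beer"), rank));
        // [chips, eggs, bread, milk, beer]
    }
}
```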

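Step 3: insert the records from step 2 into the FP-Tree one by one. Initially the suffix pattern is empty.

After inserting the first record (chips, eggs, bread, milk):

(Figure 1: the FP-Tree after the first insertion)

Inserting the second record (chips, eggs, beer):

(Figure 2: the FP-Tree after the second insertion)

Inserting the record (bread, milk, beer); it does not start with chips, so it opens a new branch under the root:

(Figure 3: the FP-Tree after inserting (bread, milk, beer))

You can probably see the pattern by now. The final FP-Tree is:

(Figure 4: the final FP-Tree and its header table)

The column on the left of the figure is the header table. Nodes with the same name in the tree are linked together, and each of these linked lists starts at the corresponding entry in the header table.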
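The insertion rule can be sketched as follows. Walk down from the root while the next item already exists as a child, incrementing counts along the shared prefix. Create new nodes for the rest of the record, and append each new node to its item's node-link list. This is a hypothetical standalone version (`FpTreeInsert`), not the article's `TreeNode`/`FPTree` classes. It keeps a tail pointer per item, so appending to a node-link list does not require walking the list. It also takes a `count` argument, so the same method can insert weighted prefix paths.

```java
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical standalone FP-tree: prefix-sharing insertion plus header-table node-links.
class FpTreeInsert {

    static final class Node {
        final String item;                 // null for the virtual root
        final Node parent;
        int count;
        final Map<String, Node> children = new LinkedHashMap<>();
        Node nextSame;                     // next node with the same item (node-link)

        Node(String item, Node parent) {
            this.item = item;
            this.parent = parent;
        }
    }

    final Node root = new Node(null, null);
    final Map<String, Node> header = new LinkedHashMap<>(); // item -> first node of its node-link list
    private final Map<String, Node> tail = new HashMap<>();  // item -> last node, for O(1) appends

    // Insert a record already sorted in F1 order, as if it occurred `count` times.
    void insert(List<String> sortedRecord, int count) {
        Node cur = root;
        for (String item : sortedRecord) {
            Node child = cur.children.get(item);
            if (child == null) {
                // The prefix is no longer shared: create a node and append it to the node-links.
                child = new Node(item, cur);
                cur.children.put(item, child);
                Node last = tail.get(item);
                if (last == null) {
                    header.put(item, child);
                } else {
                    last.nextSame = child;
                }
                tail.put(item, child);
            }
            child.count += count;          // shared or new, the node gains `count`
            cur = child;
        }
    }

    // Summing counts along an item's node-link list gives the item's support.
    int linkedCount(String item) {
        int sum = 0;
        for (Node n = header.get(item); n != null; n = n.nextSame) {
            sum += n.count;
        }
        return sum;
    }

    int linkLength(String item) {
        int len = 0;
        for (Node n = header.get(item); n != null; n = n.nextSame) {
            len++;
        }
        return len;
    }

    static final List<List<String>> SORTED_MARKET = List.of(
            List.of("chips", "eggs", "bread", "milk"),
            List.of("chips", "eggs", "beer"),
            List.of("chips", "eggs", "bread"),
            List.of("chips", "eggs", "bread", "milk", "beer"),
            List.of("bread", "milk", "beer"),
            List.of("eggs", "bread", "beer"),
            List.of("chips", "bread", "milk"),
            List.of("chips", "eggs", "bread", "milk"),
            List.of("chips", "eggs", "milk"));

    public static void main(String[] args) {
        FpTreeInsert tree = new FpTreeInsert();
        for (List<String> r : SORTED_MARKET) {
            tree.insert(r, 1);
        }
        tree.root.children.forEach((item, node) -> System.out.println(item + ":" + node.count));
        System.out.println("milk: " + tree.linkLength("milk") + " nodes, total " + tree.linkedCount("milk"));
    }
}
```

With the nine reordered records, the root gets three children: chips:7, bread:1 and eggs:1. Milk's node-link list holds 4 nodes whose counts add up to 6, which matches milk:6 in F1.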

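If the FP-Tree is empty (it contains only the virtual root node), the FP-Growth function returns.

Otherwise, output each item in the header table combined with the current suffix pattern (postModel, called postPattern in the code). The support is that item's count in the header table. At the top level the suffix pattern is empty, and the code below skips this output, so single items are not reported.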

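Step 4: find the frequent patterns in the FP-Tree.

Go through each item in the header table (we take "milk:6" as the example) and perform the following steps for each one:

(1) Find all "milk" nodes in the FP-Tree and walk up through their ancestors. This gives 4 paths:

chips:7, eggs:6, milk:1

chips:7, eggs:6, bread:4, milk:3

chips:7, bread:1, milk:1

bread:1, milk:1

On each path, set the count of every node to the count of that path's milk node:

chips:1, eggs:1, milk:1

chips:3, eggs:3, bread:3, milk:3

chips:1, bread:1, milk:1

bread:1, milk:1

Every path ends in milk, so drop milk to obtain the conditional pattern base (CPB) of milk. The suffix pattern is now (milk).

chips:1, eggs:1

chips:3, eggs:3, bread:3

chips:1, bread:1

bread:1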
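Each tree path with count c stands for c reordered records that share that prefix. The same conditional pattern base can therefore also be read straight off the reordered records from step 2: take the items in front of milk in every record that contains milk, and add up identical prefixes. The sketch below (a hypothetical `ConditionalPatternBase` class) does exactly that. It is meant only as a cross-check of the tree walk, not a replacement for it, since the point of the FP-Tree is to avoid rescanning the records. Prefixes are listed from the root down.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical cross-check: the conditional pattern base read off the F1-ordered records.
class ConditionalPatternBase {

    static Map<List<String>, Integer> of(List<List<String>> sortedRecords, String item) {
        Map<List<String>, Integer> cpb = new LinkedHashMap<>();
        for (List<String> r : sortedRecords) {
            int pos = r.indexOf(item);
            // pos == -1: item absent; pos == 0: empty prefix, which contributes nothing
            if (pos > 0) {
                cpb.merge(new ArrayList<>(r.subList(0, pos)), 1, Integer::sum);
            }
        }
        return cpb;
    }

    static final List<List<String>> SORTED_MARKET = List.of(
            List.of("chips", "eggs", "bread", "milk"),
            List.of("chips", "eggs", "beer"),
            List.of("chips", "eggs", "bread"),
            List.of("chips", "eggs", "bread", "milk", "beer"),
            List.of("bread", "milk", "beer"),
            List.of("eggs", "bread", "beer"),
            List.of("chips", "bread", "milk"),
            List.of("chips", "eggs", "bread", "milk"),
            List.of("chips", "eggs", "milk"));

    public static void main(String[] args) {
        System.out.println(of(SORTED_MARKET, "milk"));
        // {[chips, eggs, bread]=3, [bread]=1, [chips, bread]=1, [chips, eggs]=1}
    }
}
```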

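(2) Treat this conditional pattern base as a new transaction database, go back and build a new FP-Tree from it (the code first rebuilds the header table from the CPB, so the items are recounted and reordered), and recurse. For milk, the CPB gives chips:5, bread:5 and eggs:4. All of these are at least MinSup, so this level outputs (chips, milk):5, (bread, milk):5 and (eggs, milk):4.

If this is still unclear, you can refer to this blog post, or go straight to the core code: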

public void FPGrowth(List<List<String>> transRecords,
        List<String> postPattern,Context context) throws IOException, InterruptedException {
    // Build the header table, which is also the set of frequent 1-itemsets
    ArrayList<TreeNode> HeaderTable = buildHeaderTable(transRecords);
    // Build the FP-Tree
    TreeNode treeRoot = buildFPTree(transRecords, HeaderTable);
    // Return if the FP-Tree is empty
    if (treeRoot.getChildren()==null || treeRoot.getChildren().size() == 0)
        return;
    // Output each header-table item + postPattern
    if(postPattern!=null){
        for (TreeNode header : HeaderTable) {
            String outStr=header.getName();
            int count=header.getCount();
            for (String ele : postPattern)
                outStr+="\t" + ele;
            context.write(new IntWritable(count), new Text(outStr));
        }
    }
    // Find the conditional pattern base of each header-table item and recurse
    for (TreeNode header : HeaderTable) {
        // Extend the suffix pattern by one item
        List<String> newPostPattern = new LinkedList<String>();
        newPostPattern.add(header.getName());
        if (postPattern != null)
            newPostPattern.addAll(postPattern);
        // Collect header's conditional pattern base (CPB) into newTransRecords
        List<List<String>> newTransRecords = new LinkedList<List<String>>();
        TreeNode backnode = header.getNextHomonym();
        while (backnode != null) {
            int counter = backnode.getCount();
            List<String> prenodes = new ArrayList<String>();
            TreeNode parent = backnode;
            // Walk up backnode's ancestors and put them into prenodes
            while ((parent = parent.getParent()).getName() != null) {
                prenodes.add(parent.getName());
            }
            while (counter-- > 0) {
                newTransRecords.add(prenodes);
            }
            backnode = backnode.getNextHomonym();
        }
        // Recurse
        FPGrowth(newTransRecords, newPostPattern,context);
    }
}
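In the core code above, each prefix path is appended to newTransRecords `counter` times. A path ending in a node with count 1000 therefore becomes 1000 identical list entries. A variant that passes (prefix path, count) pairs down the recursion avoids that blow-up. The following self-contained sketch implements that variant under a few assumptions:

- the class name `WeightedFpGrowth` is hypothetical;
- node-links are kept in a `Map<String, List<Node>>` instead of `nextHomonym` pointers;
- ties in the header table are broken by item name;
- frequent single items are reported too (the article's code skips them because `postPattern` is null at the top level).

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

// Hypothetical sketch: FP-Growth that passes (prefix path, count) pairs down the recursion.
class WeightedFpGrowth {
    // One conditional-pattern-base entry: a prefix path and how many times it occurs.
    static final class Weighted {
        final List<String> items;
        final int count;
        Weighted(List<String> items, int count) { this.items = items; this.count = count; }
    }

    static final class Node {
        final String item;                                  // null for the virtual root
        final Node parent;
        final Map<String, Node> children = new HashMap<>();
        int count;
        Node(String item, Node parent) { this.item = item; this.parent = parent; }
    }

    // Returns every frequent itemset (as a sorted set) with its support count.
    static Map<Set<String>, Integer> mine(List<Weighted> db, int minSup) {
        Map<Set<String>, Integer> out = new LinkedHashMap<>();
        grow(db, new ArrayList<>(), minSup, out);
        return out;
    }

    private static void grow(List<Weighted> db, List<String> suffix, int minSup,
                             Map<Set<String>, Integer> out) {
        // Header table: weighted support of each item in this (conditional) database.
        Map<String, Integer> counts = new HashMap<>();
        for (Weighted w : db)
            for (String item : w.items)
                counts.merge(item, w.count, Integer::sum);
        List<String> header = new ArrayList<>();
        for (Map.Entry<String, Integer> e : counts.entrySet())
            if (e.getValue() >= minSup) header.add(e.getKey());
        if (header.isEmpty()) return;                        // empty FP-tree: stop here
        // Descending support, ties broken by item name.
        header.sort((a, b) -> counts.get(a).equals(counts.get(b))
                ? a.compareTo(b) : Integer.compare(counts.get(b), counts.get(a)));
        // Build the FP-tree; nodeLinks stands in for the header table's node-link lists.
        Map<String, Integer> rank = new HashMap<>();
        for (int i = 0; i < header.size(); i++) rank.put(header.get(i), i);
        Map<String, List<Node>> nodeLinks = new HashMap<>();
        Node root = new Node(null, null);
        for (Weighted w : db) {
            List<String> sorted = new ArrayList<>();
            for (String item : w.items) if (rank.containsKey(item)) sorted.add(item);
            sorted.sort((a, b) -> Integer.compare(rank.get(a), rank.get(b)));
            Node cur = root;
            for (String item : sorted) {
                Node child = cur.children.get(item);
                if (child == null) {
                    child = new Node(item, cur);
                    cur.children.put(item, child);
                    nodeLinks.computeIfAbsent(item, k -> new ArrayList<>()).add(child);
                }
                child.count += w.count;                      // weighted, not one-by-one
                cur = child;
            }
        }
        for (String item : header) {
            // Output item + suffix; its support is the item's count in this database.
            List<String> newSuffix = new ArrayList<>();
            newSuffix.add(item);
            newSuffix.addAll(suffix);
            out.put(new TreeSet<>(newSuffix), counts.get(item));
            // Conditional pattern base: each node's ancestors, weighted by the node's count.
            List<Weighted> cpb = new ArrayList<>();
            for (Node n : nodeLinks.get(item)) {
                List<String> prefix = new ArrayList<>();
                for (Node p = n.parent; p.item != null; p = p.parent) prefix.add(p.item);
                if (!prefix.isEmpty()) cpb.add(new Weighted(prefix, n.count));
            }
            grow(cpb, newSuffix, minSup, out);
        }
    }

    static List<Weighted> market() {
        List<Weighted> db = new ArrayList<>();
        for (String line : List.of("milk,eggs,bread,chips", "eggs,popcorn,chips,beer",
                "eggs,bread,chips", "milk,eggs,bread,popcorn,chips,beer", "milk,bread,beer",
                "eggs,bread,beer", "milk,bread,chips", "milk,eggs,bread,butter,chips",
                "milk,eggs,butter,chips"))
            db.add(new Weighted(List.of(line.split(",")), 1));
        return db;
    }

    public static void main(String[] args) {
        mine(market(), 3).forEach((pattern, support) -> System.out.println(support + "\t" + pattern));
    }
}
```

On the nine-transaction example with MinSup = 3, this should report the five frequent single items plus the 13 multi-item patterns shown in the program output later in this article. Itemsets are printed as sorted sets, so the item order within a pattern may differ from that listing.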

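When the FP-Tree is already a single path, there is no need to call FPGrowth recursively. Just output every combination of the nodes on that path, each followed by postModel. For example, when the FP-Tree is:

(Figure 5: a single-path FP-Tree, root -> A:3 -> B:3)

we can output directly:

3  A+postModel

3  B+postModel

3  A+B+postModel

and we are done.

If we instead follow the approach in the code above, we first output:

3  A+postModel

3  B+postModel

then prepend B to postModel and build a new FP-Tree. That tree contains only A, so we output

3  A+(B+postModel)

Both approaches give the same result, but rebuilding the FP-Tree costs more computation.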
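The single-path shortcut is just subset enumeration. Every node on the path already has count >= MinSup, because only frequent items enter the (conditional) FP-Tree. So every non-empty subset of the path's nodes, plus the suffix, is frequent. Its support is the smallest count among the chosen nodes, which is the count of the deepest one, since counts never increase going down a path.

Below is a minimal sketch with a hypothetical `SinglePathPatterns` class, using a bitmask to enumerate subsets. It is only sensible for short paths: a path of n nodes yields 2^n - 1 patterns, and the `int` mask limits n to 30 here. "S" stands in for postModel.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper for the single-path case: every non-empty subset of the
// path's nodes, followed by the suffix pattern, with its support.
class SinglePathPatterns {

    // path: items from the root down; counts: the matching node counts; path length <= 30.
    static Map<List<String>, Integer> combinations(List<String> path, List<Integer> counts,
                                                   List<String> suffix) {
        Map<List<String>, Integer> out = new LinkedHashMap<>();
        int n = path.size();
        for (int mask = 1; mask < (1 << n); mask++) {
            List<String> pattern = new ArrayList<>();
            int support = Integer.MAX_VALUE;
            for (int i = 0; i < n; i++) {
                if ((mask & (1 << i)) != 0) {
                    pattern.add(path.get(i));
                    // minimum over chosen nodes = count of the deepest chosen node
                    support = Math.min(support, counts.get(i));
                }
            }
            pattern.addAll(suffix);
            out.put(pattern, support);
        }
        return out;
    }

    public static void main(String[] args) {
        // The example above: a path A:3 -> B:3.
        System.out.println(combinations(List.of("A", "B"), List.of(3, 3), List.of("S")));
        // {[A, S]=3, [B, S]=3, [A, B, S]=3}
    }
}
```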

Java Implementation

FP-Tree node definition

package fptree;
  
import java.util.ArrayList;
import java.util.List;
  
public class TreeNode implements Comparable<TreeNode> {
  
    private String name; // node name (the item)
    private int count; // count
    private TreeNode parent; // parent node
    private List<TreeNode> children; // child nodes
    private TreeNode nextHomonym; // next node with the same name (node-link)
  
    public TreeNode() {
  
    }
  
    public TreeNode(String name) {
        this.name = name;
    }
  
    public String getName() {
        return name;
    }
  
    public void setName(String name) {
        this.name = name;
    }
  
    public int getCount() {
        return count;
    }
  
    public void setCount(int count) {
        this.count = count;
    }
  
    public TreeNode getParent() {
        return parent;
    }
  
    public void setParent(TreeNode parent) {
        this.parent = parent;
    }
  
    public List<TreeNode> getChildren() {
        return children;
    }
  
    public void addChild(TreeNode child) {
        if (this.getChildren() == null) {
            List<TreeNode> list = new ArrayList<TreeNode>();
            list.add(child);
            this.setChildren(list);
        } else {
            this.getChildren().add(child);
        }
    }
  
    public TreeNode findChild(String name) {
        List<TreeNode> children = this.getChildren();
        if (children != null) {
            for (TreeNode child : children) {
                if (child.getName().equals(name)) {
                    return child;
                }
            }
        }
        return null;
    }
  
    public void setChildren(List<TreeNode> children) {
        this.children = children;
    }
  
    public void printChildrenName() {
        List<TreeNode> children = this.getChildren();
        if (children != null) {
            for (TreeNode child : children) {
                System.out.print(child.getName() + " ");
            }
        } else {
            System.out.print("null");
        }
    }
  
    public TreeNode getNextHomonym() {
        return nextHomonym;
    }
  
    public void setNextHomonym(TreeNode nextHomonym) {
        this.nextHomonym = nextHomonym;
    }
  
    public void countIncrement(int n) {
        this.count += n;
    }
  
    @Override
    public int compareTo(TreeNode arg0) {
        // Reversed comparison, so Collections.sort()/Arrays.sort() order nodes by descending count;
        // Integer.compare also avoids the overflow a plain subtraction could hit
        return Integer.compare(arg0.getCount(), this.count);
    }
}

Mining frequent patterns
package fptree;
 
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
 
public class FPTree {
 
    private int minSuport;
 
    public int getMinSuport() {
        return minSuport;
    }
 
    public void setMinSuport(int minSuport) {
        this.minSuport = minSuport;
    }
 
    // Read transaction records from one or more files
    public List<List<String>> readTransRocords(String... filenames) {
        List<List<String>> transaction = null;
        if (filenames.length > 0) {
            transaction = new LinkedList<List<String>>();
            for (String filename : filenames) {
                try {
                    FileReader fr = new FileReader(filename);
                    BufferedReader br = new BufferedReader(fr);
                    try {
                        String line;
                        List<String> record;
                        while ((line = br.readLine()) != null) {
                            if(line.trim().length()>0){
                                String str[] = line.split(",");
                                record = new LinkedList<String>();
                                for (String w : str)
                                    record.add(w);
                                transaction.add(record);
                            }
                        }
                    } finally {
                        br.close();
                    }
                } catch (IOException ex) {
                    System.out.println("Read transaction records failed: "
                            + ex.getMessage());
                    System.exit(1);
                }
            }
        }
        return transaction;
    }
 
    // FP-Growth algorithm
    public void FPGrowth(List<List<String>> transRecords,
            List<String> postPattern) {
        // Build the header table, which is also the set of frequent 1-itemsets
        ArrayList<TreeNode> HeaderTable = buildHeaderTable(transRecords);
        // Build the FP-Tree
        TreeNode treeRoot = buildFPTree(transRecords, HeaderTable);
        // Return if the FP-Tree is empty
        if (treeRoot.getChildren()==null || treeRoot.getChildren().size() == 0)
            return;
        // Output each header-table item + postPattern
        if(postPattern!=null){
            for (TreeNode header : HeaderTable) {
                System.out.print(header.getCount() + "\t" + header.getName());
                for (String ele : postPattern)
                    System.out.print("\t" + ele);
                System.out.println();
            }
        }
        // Find the conditional pattern base of each header-table item and recurse
        for (TreeNode header : HeaderTable) {
            // Extend the suffix pattern by one item
            List<String> newPostPattern = new LinkedList<String>();
            newPostPattern.add(header.getName());
            if (postPattern != null)
                newPostPattern.addAll(postPattern);
            // Collect header's conditional pattern base (CPB) into newTransRecords
            List<List<String>> newTransRecords = new LinkedList<List<String>>();
            TreeNode backnode = header.getNextHomonym();
            while (backnode != null) {
                int counter = backnode.getCount();
                List<String> prenodes = new ArrayList<String>();
                TreeNode parent = backnode;
                // Walk up backnode's ancestors and put them into prenodes
                while ((parent = parent.getParent()).getName() != null) {
                    prenodes.add(parent.getName());
                }
                while (counter-- > 0) {
                    newTransRecords.add(prenodes);
                }
                backnode = backnode.getNextHomonym();
            }
            // Recurse
            FPGrowth(newTransRecords, newPostPattern);
        }
    }
 
    // Build the header table, which is also the set of frequent 1-itemsets
    public ArrayList<TreeNode> buildHeaderTable(List<List<String>> transRecords) {
        ArrayList<TreeNode> F1 = null;
        if (transRecords.size() > 0) {
            F1 = new ArrayList<TreeNode>();
            Map<String, TreeNode> map = new HashMap<String, TreeNode>();
            // Count the support of every item in the transaction database
            for (List<String> record : transRecords) {
                for (String item : record) {
                    if (!map.keySet().contains(item)) {
                        TreeNode node = new TreeNode(item);
                        node.setCount(1);
                        map.put(item, node);
                    } else {
                        map.get(item).countIncrement(1);
                    }
                }
            }
            // Add items whose support is >= minSup to F1
            Set<String> names = map.keySet();
            for (String name : names) {
                TreeNode tnode = map.get(name);
                if (tnode.getCount() >= minSuport) {
                    F1.add(tnode);
                }
            }
            Collections.sort(F1);
            return F1;
        } else {
            return null;
        }
    }
 
    // Build the FP-Tree
    public TreeNode buildFPTree(List<List<String>> transRecords,
            ArrayList<TreeNode> F1) {
        TreeNode root = new TreeNode(); // create the (virtual) root node
        for (List<String> transRecord : transRecords) {
            LinkedList<String> record = sortByF1(transRecord, F1);
            TreeNode subTreeRoot = root;
            TreeNode tmpRoot = null;
            if (root.getChildren() != null) {
                while (!record.isEmpty()
                        && (tmpRoot = subTreeRoot.findChild(record.peek())) != null) {
                    tmpRoot.countIncrement(1);
                    subTreeRoot = tmpRoot;
                    record.poll();
                }
            }
            addNodes(subTreeRoot, record, F1);
        }
        return root;
    }
 
    // Sort a transaction's items by descending frequency, i.e. by their order in F1
    public LinkedList<String> sortByF1(List<String> transRecord,
            ArrayList<TreeNode> F1) {
        Map<String, Integer> map = new HashMap<String, Integer>();
        for (String item : transRecord) {
            // F1 is already sorted by descending count, so index i is the item's rank; items not in F1 are dropped
            for (int i = 0; i < F1.size(); i++) {
                TreeNode tnode = F1.get(i);
                if (tnode.getName().equals(item)) {
                    map.put(item, i);
                }
            }
        }
        ArrayList<Entry<String, Integer>> al = new ArrayList<Entry<String, Integer>>(
                map.entrySet());
        Collections.sort(al, new Comparator<Map.Entry<String, Integer>>() {
            @Override
            public int compare(Entry<String, Integer> arg0,
                    Entry<String, Integer> arg1) {
                // ascending F1 index = descending frequency
                return arg0.getValue() - arg1.getValue();
            }
        });
        LinkedList<String> rest = new LinkedList<String>();
        for (Entry<String, Integer> entry : al) {
            rest.add(entry.getKey());
        }
        return rest;
    }
 
    // Insert record into the tree as descendants of ancestor
    public void addNodes(TreeNode ancestor, LinkedList<String> record,
            ArrayList<TreeNode> F1) {
        if (record.size() > 0) {
            while (record.size() > 0) {
                String item = record.poll();
                TreeNode leafnode = new TreeNode(item);
                leafnode.setCount(1);
                leafnode.setParent(ancestor);
                ancestor.addChild(leafnode);
 
                for (TreeNode f1 : F1) {
                    if (f1.getName().equals(item)) {
                        while (f1.getNextHomonym() != null) {
                            f1 = f1.getNextHomonym();
                        }
                        f1.setNextHomonym(leafnode);
                        break;
                    }
                }
 
                addNodes(leafnode, record, F1);
            }
        }
    }
 
    public static void main(String[] args) {
        FPTree fptree = new FPTree();
        fptree.setMinSuport(3);
        List<List<String>> transRecords = fptree
                .readTransRocords("/home/orisun/test/market");
        fptree.FPGrowth(transRecords, null);
    }
}

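Input file

milk,eggs,bread,chips
eggs,popcorn,chips,beer
eggs,bread,chips
milk,eggs,bread,popcorn,chips,beer
milk,bread,beer
eggs,bread,beer
milk,bread,chips
milk,eggs,bread,butter,chips
milk,eggs,butter,chips

Output

6    chips    eggs
5    chips    bread
5    eggs    bread
4    chips    eggs    bread
5    chips    milk
5    bread    milk
4    eggs    milk
4    chips    bread    milk
4    chips    eggs    milk
3    bread    eggs    milk
3    chips    bread    eggs    milk
3    eggs    beer
3    bread    beer

Items with equal counts (chips, eggs and bread all have count 7) are ordered by HashMap iteration order in buildHeaderTable. With different item strings the columns may therefore come out in a different order, but the itemsets and counts do not change.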

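Implementing it with Hadoop

The code above puts the whole transaction database into a single List<List<String>> and passes it to FPGrowth. In practice this is not workable, because memory cannot hold the whole transaction database; we may need to read records one at a time from a relational database to build the FP-Tree. Either way, the FP-Tree itself must reside in memory, so what if it does not fit? FPGrowth is also still very time-consuming, so how can it be sped up? The solution is divide and conquer: compute in parallel.

We split the original transaction database into N parts, run FPGrowth mining on N nodes in parallel, and finally merge the association rules. The key question is how to "partition" the data so that no association rule is missed; see this blog post. To enable parallel computation, a "redundant" partitioning is used here, i.e. the union of the parts is larger than the original set. The rules found this way are also redundant. For example, node 1 might produce the rule (6: beer, diapers) while node 2 produces (3: diapers, beer). The rule from node 2 is clearly redundant, and a follow-up step is needed to remove such redundant rules.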
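The article does not spell out its partitioning rule here; it defers to the linked post. The sketch below shows one simple scheme with the properties described. It is an assumption, not necessarily the author's exact method.

After reordering each record by F1, send, for every item x in the record, the prefix ending at x to shard x. In Hadoop this would be the map step, keyed by x. Shards overlap, but every frequent itemset whose last item in F1 order is x has its full support visible in shard x, so nothing is lost. Other shards may see the same itemset with a smaller count, which is exactly the (6: beer, diapers) versus (3: diapers, beer) situation. Keeping the largest count for each unordered itemset removes the duplicates. The names `RedundantShards`, `shard`, `dedupe` and `SORTED_MARKET` are hypothetical.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

// Hypothetical sketch of redundant sharding plus duplicate-rule removal.
class RedundantShards {

    // Map step: record [a, b, c] goes to shard a as [a], to shard b as [a, b]
    // and to shard c as [a, b, c]; records must already be in F1 order.
    static Map<String, List<List<String>>> shard(List<List<String>> sortedRecords) {
        Map<String, List<List<String>>> shards = new LinkedHashMap<>();
        for (List<String> r : sortedRecords) {
            for (int i = 0; i < r.size(); i++) {
                shards.computeIfAbsent(r.get(i), k -> new ArrayList<>())
                      .add(new ArrayList<>(r.subList(0, i + 1)));
            }
        }
        return shards;
    }

    // Merge step: the same itemset in any order is kept once, with its largest count,
    // which is the count seen by the shard of its last item.
    static Map<Set<String>, Integer> dedupe(Map<List<String>, Integer> patterns) {
        Map<Set<String>, Integer> out = new LinkedHashMap<>();
        for (Map.Entry<List<String>, Integer> e : patterns.entrySet()) {
            out.merge(new TreeSet<>(e.getKey()), e.getValue(), Integer::max);
        }
        return out;
    }

    static final List<List<String>> SORTED_MARKET = List.of(
            List.of("chips", "eggs", "bread", "milk"),
            List.of("chips", "eggs", "beer"),
            List.of("chips", "eggs", "bread"),
            List.of("chips", "eggs", "bread", "milk", "beer"),
            List.of("bread", "milk", "beer"),
            List.of("eggs", "bread", "beer"),
            List.of("chips", "bread", "milk"),
            List.of("chips", "eggs", "bread", "milk"),
            List.of("chips", "eggs", "milk"));

    public static void main(String[] args) {
        shard(SORTED_MARKET).forEach((item, recs) ->
                System.out.println(item + " <- " + recs.size() + " records"));
        Map<List<String>, Integer> found = new LinkedHashMap<>();
        found.put(List.of("beer", "diapers"), 6);   // e.g. reported by node 1
        found.put(List.of("diapers", "beer"), 3);   // e.g. reported by node 2
        System.out.println(dedupe(found));          // {[beer, diapers]=6}
    }
}
```

On the nine reordered records, the shards hold 31 prefixes in total, compared with 9 original records. That gap is the redundancy the text mentions, traded for shards that can be mined independently.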

Code:

Record.java

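package fptree;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

import org.apache.hadoop.io.WritableComparable;

// A list of items (a transaction prefix or a mined pattern) usable as a MapReduce key or value.
// Comparison, equality and hashing ignore item order, so "beer diapers" equals "diapers beer".
public class Record implements WritableComparable<Record>{

    LinkedList<String> list;

    public Record(){
        list=new LinkedList<String>();
    }

    public Record(String[] arr){
        list=new LinkedList<String>();
        for(int i=0;i<arr.length;i++)
            list.add(arr[i]);
    }

    // Join items with tabs; an empty list gives "" instead of throwing
    private static String join(List<String> items){
        StringBuilder sb=new StringBuilder();
        boolean first=true;
        for(String item:items){
            if(!first)
                sb.append('\t');
            sb.append(item);
            first=false;
        }
        return sb.toString();
    }

    // Items in sorted order: an order-insensitive key for compareTo, equals and hashCode
    private String sortedKey(){
        List<String> copy=new ArrayList<String>(list);
        Collections.sort(copy);
        return join(copy);
    }

    @Override
    public String toString(){
        return join(list);
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        list.clear();
        String line=in.readUTF();
        if(line.isEmpty())
            return;
        String []arr=line.split("\\s+");
        for(int i=0;i<arr.length;i++)
            list.add(arr[i]);
    }

    @Override
    public void write(DataOutput out) throws IOException {
        out.writeUTF(this.toString());
    }

    // Compare sorted copies instead of sorting the lists in place,
    // so comparing two records never reorders their items as a side effect
    @Override
    public int compareTo(Record obj) {
        return sortedKey().compareTo(obj.sortedKey());
    }

    @Override
    public boolean equals(Object o) {
        return o instanceof Record && sortedKey().equals(((Record) o).sortedKey());
    }

    // The default HashPartitioner uses hashCode(); without this override, equal records
    // could be sent to different reducers when the second job runs with several reducers
    @Override
    public int hashCode() {
        return sortedKey().hashCode();
    }
}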

DC_FPTree.java

package fptree;
 
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
 
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
 
public class DC_FPTree extends Configured implements Tool {
 
    private static final int GroupNum = 10;
    private static final int minSuport=6;
 
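    public static class GroupMapper extends
            Mapper<LongWritable, Text, IntWritable, Record> {
        // Seed for shuffling F1. Any fixed value works (12345 is an arbitrary choice); what
        // matters is that every map task uses the same one, so all tasks build the same groups.
        private static final long GROUP_SEED = 12345L;

        List<String> freq = new LinkedList<String>(); // frequent 1-itemsets (F1)
        List<List<String>> freq_group = new LinkedList<List<String>>(); // F1 split into GroupNum groups
        // Position of each frequent item in the F1 file: a global item order shared by all map tasks
        Map<String, Integer> rank = new HashMap<String, Integer>();

        @Override
        public void setup(Context context) throws IOException {
            // Read F1 from HDFS; the item name is the first column of each line
            FileSystem fs = FileSystem.get(context.getConfiguration());
            Path freqFile = new Path("/user/orisun/input/F1");
            FSDataInputStream in = fs.open(freqFile);
            InputStreamReader isr = new InputStreamReader(in);
            BufferedReader br = new BufferedReader(isr);
            try {
                String line;
                while ((line = br.readLine()) != null) {
                    line = line.trim();
                    if (line.isEmpty())
                        continue;
                    String word = line.split("\\s+")[0];
                    rank.put(word, freq.size());
                    freq.add(word);
                }
            } finally {
                br.close();
            }
            // Split F1 into GroupNum groups. The shuffle (presumably meant to balance the groups)
            // must be deterministic: with an unseeded shuffle each map task would build different
            // groups, and group i would mean different items in different tasks.
            Collections.shuffle(freq, new java.util.Random(GROUP_SEED));
            int cap = freq.size() / GroupNum; // items per group before the remainder is spread
            for (int i = 0; i < GroupNum; i++) {
                List<String> list = new LinkedList<String>();
                for (int j = 0; j < cap; j++) {
                    list.add(freq.get(i * cap + j));
                }
                freq_group.add(list);
            }
            // Hand out the leftover items, one per group
            int remainder = freq.size() % GroupNum;
            int base = GroupNum * cap;
            for (int i = 0; i < remainder; i++) {
                freq_group.get(i).add(freq.get(base + i));
            }
        }

        @Override
        public void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            Record record = new Record();
            LinkedList<String> list = record.list;
            // Keep only frequent items, each once ...
            for (String item : value.toString().trim().split("\\s+")) {
                if (rank.containsKey(item) && !list.contains(item))
                    list.add(item);
            }
            // ... and sort them in the global F1-file order, so that every transaction is cut
            // into prefixes the same way, whatever the item order in the input line
            Collections.sort(list, new Comparator<String>() {
                @Override
                public int compare(String a, String b) {
                    return rank.get(a) - rank.get(b);
                }
            });
            // Groups that have already received a prefix of this transaction
            BitSet bs = new BitSet(freq_group.size());
            // Walk from the last item backwards. The first time an item of group i is seen, send
            // the current prefix (ending with that item) to group i; shorter prefixes are not needed.
            while (list.size() > 0) {
                String item = list.peekLast(); // current last item of the prefix
                int i = 0;
                for (; i < freq_group.size(); i++) {
                    if (bs.get(i))
                        continue;
                    if (freq_group.get(i).contains(item)) {
                        bs.set(i);
                        break;
                    }
                }
                if (i < freq_group.size()) { // found a group that has not been served yet
                    context.write(new IntWritable(i), record);
                }
                list.pollLast();
            }
        }
    }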
     
    public static class FPReducer extends Reducer<IntWritable,Record,IntWritable,Text>{
        @Override
        public void reduce(IntWritable key,Iterable<Record> values,Context context)throws IOException,InterruptedException{
            // The prefixes sent to this group form its local transaction database. Hadoop reuses
            // the same Record instance for every value, so copy the items out of it.
            List<List<String>> trans=new LinkedList<List<String>>();
            for(Record record:values){
                trans.add(new LinkedList<String>(record.list));
            }
            FPGrowth(trans, null,context);
        }
        // The FP-Growth algorithm; mined patterns are written to the reducer's output
        public void FPGrowth(List<List<String>> transRecords,
                List<String> postPattern,Context context) throws IOException, InterruptedException {
            // Build the header table, which is also the list of frequent 1-itemsets
            ArrayList<TreeNode> HeaderTable = buildHeaderTable(transRecords);
            // Build the FP-Tree
            TreeNode treeRoot = buildFPTree(transRecords, HeaderTable);
            // Return if the FP-Tree is empty (it has only the root)
            if (treeRoot.getChildren()==null || treeRoot.getChildren().size() == 0)
                return;
            // Output each header-table item + postPattern; the item's count is the pattern's support.
            // At the top level postPattern is null, so single items are not written.
            if(postPattern!=null){
                for (TreeNode header : HeaderTable) {
                    String outStr=header.getName();
                    int count=header.getCount();
                    for (String ele : postPattern)
                        outStr+="\t" + ele;
                    context.write(new IntWritable(count), new Text(outStr));
                }
            }
            // For each header-table item, build its conditional pattern base and recurse
            for (TreeNode header : HeaderTable) {
                // Extend the suffix pattern by one item
                List<String> newPostPattern = new LinkedList<String>();
                newPostPattern.add(header.getName());
                if (postPattern != null)
                    newPostPattern.addAll(postPattern);
                // Collect header's conditional pattern base (CPB) into newTransRecords
                List<List<String>> newTransRecords = new LinkedList<List<String>>();
                TreeNode backnode = header.getNextHomonym();
                while (backnode != null) {
                    int counter = backnode.getCount();
                    List<String> prenodes = new ArrayList<String>();
                    TreeNode parent = backnode;
                    // Collect backnode's ancestors; the loop stops at the root, whose name is null
                    while ((parent = parent.getParent()).getName() != null) {
                        prenodes.add(parent.getName());
                    }
                    // The path occurs backnode.getCount() times, so add it that many times
                    while (counter-- > 0) {
                        newTransRecords.add(prenodes);
                    }
                    backnode = backnode.getNextHomonym();
                }
                // Recurse on the conditional pattern base
                FPGrowth(newTransRecords, newPostPattern,context);
            }
        }
 
        // Build the header table, which is also the list of frequent 1-itemsets
        // (sorted by descending count via TreeNode's compareTo)
        public ArrayList<TreeNode> buildHeaderTable(List<List<String>> transRecords) {
            ArrayList<TreeNode> F1 = null;
            if (transRecords.size() > 0) {
                F1 = new ArrayList<TreeNode>();
                Map<String, TreeNode> map = new HashMap<String, TreeNode>();
                // Count the support of every item in the transaction database
                for (List<String> record : transRecords) {
                    for (String item : record) {
                        if (!map.containsKey(item)) {
                            TreeNode node = new TreeNode(item);
                            node.setCount(1);
                            map.put(item, node);
                        } else {
                            map.get(item).countIncrement(1);
                        }
                    }
                }
                // Keep the items whose support is at least minSuport
                Set<String> names = map.keySet();
                for (String name : names) {
                    TreeNode tnode = map.get(name);
                    if (tnode.getCount() >= minSuport) {
                        F1.add(tnode);
                    }
                }
                Collections.sort(F1);
                return F1;
            } else {
                return null;
            }
        }
 
        // Build the FP-Tree
        public TreeNode buildFPTree(List<List<String>> transRecords,
                ArrayList<TreeNode> F1) {
            TreeNode root = new TreeNode(); // the (nameless) root of the tree
            for (List<String> transRecord : transRecords) {
                LinkedList<String> record = sortByF1(transRecord, F1);
                TreeNode subTreeRoot = root;
                TreeNode tmpRoot = null;
                // Follow the path shared with earlier transactions, incrementing its counts
                if (root.getChildren() != null) {
                    while (!record.isEmpty()
                            && (tmpRoot = subTreeRoot.findChild(record.peek())) != null) {
                        tmpRoot.countIncrement(1);
                        subTreeRoot = tmpRoot;
                        record.poll();
                    }
                }
                // Attach the rest of the transaction as a new branch
                addNodes(subTreeRoot, record, F1);
            }
            return root;
        }
 
        // Reorder a transaction by descending item frequency (the order of F1),
        // dropping items that are not in F1
        public LinkedList<String> sortByF1(List<String> transRecord,
                ArrayList<TreeNode> F1) {
            Map<String, Integer> map = new HashMap<String, Integer>();
            for (String item : transRecord) {
                // F1 is already sorted by descending count, so an item's index in F1 is its rank
                for (int i = 0; i < F1.size(); i++) {
                    TreeNode tnode = F1.get(i);
                    if (tnode.getName().equals(item)) {
                        map.put(item, i);
                        break;
                    }
                }
            }
            ArrayList<Entry<String, Integer>> al = new ArrayList<Entry<String, Integer>>(
                    map.entrySet());
            Collections.sort(al, new Comparator<Map.Entry<String, Integer>>() {
                @Override
                public int compare(Entry<String, Integer> arg0,
                        Entry<String, Integer> arg1) {
                    // Ascending F1 index = descending frequency
                    return arg0.getValue() - arg1.getValue();
                }
            });
            LinkedList<String> rest = new LinkedList<String>();
            for (Entry<String, Integer> entry : al) {
                rest.add(entry.getKey());
            }
            return rest;
        }
 
        // Insert the remaining items of record as a chain of new nodes below ancestor, and
        // append each new node to its item's node-link (homonym) list in the header table
        public void addNodes(TreeNode ancestor, LinkedList<String> record,
                ArrayList<TreeNode> F1) {
            while (record.size() > 0) {
                String item = record.poll();
                TreeNode leafnode = new TreeNode(item);
                leafnode.setCount(1);
                leafnode.setParent(ancestor);
                ancestor.addChild(leafnode);

                // Walk to the end of this item's homonym chain and link the new node there
                for (TreeNode f1 : F1) {
                    if (f1.getName().equals(item)) {
                        while (f1.getNextHomonym() != null) {
                            f1 = f1.getNextHomonym();
                        }
                        f1.setNextHomonym(leafnode);
                        break;
                    }
                }

                // The next item becomes a child of the node just created
                ancestor = leafnode;
            }
        }
    }
     
    public static class InverseMapper extends
            Mapper<LongWritable, Text, Record, IntWritable> {
        @Override
        public void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            String []arr=value.toString().split("\\s+");
            int count=Integer.parseInt(arr[0]);
            Record record=new Record();
            for(int i=1;i<arr.length;i++){
                record.list.add(arr[i]);
            }
            context.write(record, new IntWritable(count));
        }
    }
     
    public static class MaxReducer extends Reducer<Record,IntWritable,IntWritable,Record>{
        public void reduce(Record key,Iterable<IntWritable> values,Context context)throws IOException,InterruptedException{
            int max=-1;
            for(IntWritable value:values){
                int i=value.get();
                if(i>max)
                    max=i;
            }
            context.write(new IntWritable(max), key);
        }
    }
 
 
    @Override
    public int run(String[] arg0) throws Exception {
        Configuration conf=getConf();
        conf.set("mapred.task.timeout", "6000000"); // 6,000,000 ms = 100 minutes per task
        FileSystem fs=FileSystem.get(conf);

        // Job 1: route transaction prefixes to groups and run FP-Growth inside each group
        Job job=new Job(conf);
        job.setJarByClass(DC_FPTree.class);

        FileInputFormat.setInputPaths(job, "/user/orisun/input/data");
        Path outDir=new Path("/user/orisun/output");
        fs.delete(outDir,true);
        FileOutputFormat.setOutputPath(job, outDir);

        job.setMapperClass(GroupMapper.class);
        job.setReducerClass(FPReducer.class);

        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(Record.class);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(Text.class);

        // Do not start the second job if the first one failed
        if(!job.waitForCompletion(true))
            return 1;

        // Job 2: merge patterns found by several groups, keeping the largest count
        job=new Job(conf);
        job.setJarByClass(DC_FPTree.class);

        FileInputFormat.setInputPaths(job, "/user/orisun/output/part-r-*");
        Path outDir2=new Path("/user/orisun/output2");
        fs.delete(outDir2,true);
        FileOutputFormat.setOutputPath(job, outDir2);

        job.setMapperClass(InverseMapper.class);
        job.setReducerClass(MaxReducer.class);

        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        job.setMapOutputKeyClass(Record.class);
        job.setMapOutputValueClass(IntWritable.class);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(Record.class);

        return job.waitForCompletion(true)?0:1;
    }
 
    public static void main(String[] args) throws Exception{
        int res=ToolRunner.run(new Configuration(), new DC_FPTree(), args);
        System.exit(res);
    }
}
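The second job (InverseMapper plus MaxReducer) removes the redundant patterns produced by different groups. Below is a minimal in-memory sketch of the same idea. It assumes the first job's `count<TAB>item<TAB>item...` line format, sorts each pattern's items to get an order-insensitive key, and keeps the largest count per key. The class and method names are hypothetical, and `String.join` needs Java 8.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// In-memory sketch of the dedup step: merge duplicate patterns and keep the largest count.
class MaxDedupSketch {
    // Each line looks like the first job's output: "count<TAB>item<TAB>item..."
    static Map<String, Integer> dedup(List<String> lines) {
        Map<String, Integer> best = new TreeMap<String, Integer>();
        for (String line : lines) {
            String[] arr = line.trim().split("\\s+");
            int count = Integer.parseInt(arr[0]);
            // Sort the items so that "beer diapers" and "diapers beer" share one key,
            // the same effect Record's order-insensitive comparison has in the Hadoop job
            List<String> items = new ArrayList<String>(Arrays.asList(arr).subList(1, arr.length));
            Collections.sort(items);
            String key = String.join(" ", items);
            Integer old = best.get(key);
            if (old == null || count > old)
                best.put(key, count);
        }
        return best;
    }

    public static void main(String[] args) {
        List<String> fromGroups = Arrays.asList(
                "6\tbeer\tdiapers",  // found by one group
                "3\tdiapers\tbeer",  // same pattern, found by another group with a smaller count
                "4\teggs\tmilk");
        System.out.println(dedup(fromGroups)); // prints {beer diapers=6, eggs milk=4}
    }
}
```

Taking the maximum is the right merge rule here. Each group can only undercount a pattern, and the group that owns the pattern's last item sees the full count.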

Conclusion

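In practice, association rule mining may be less useful than people expect. On one hand, the support-confidence framework produces too many rules, and not every rule is useful. On the other hand, most association rules do not hold as widely as the classic "beer and diapers" story suggests. Association analysis takes skill, and sometimes stricter statistical methods are needed to keep the number of rules under control.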
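The code in this post stops at frequent itemsets and their counts. The hedged sketch below shows the next step. It computes rule confidence using the definition given earlier in the post (count(X ∪ Y) / count(X)), taking counts from F1 (chips 7, eggs 7, beer 4) and from the output above. It also computes lift, one commonly used extra measure. Lift is an addition here, not something the original code computes, and it is just one example of the stricter checks the conclusion mentions. The class name is hypothetical.

```java
// Hedged sketch: turning mined itemset counts into rules. Lift is not part of the original post.
class RuleSketch {
    // confidence(X => Y) = count(X ∪ Y) / count(X), as defined earlier in the post
    static double confidence(int countXY, int countX) {
        return (double) countXY / countX;
    }

    // lift(X => Y) = confidence(X => Y) / (count(Y) / n); a value near 1 means X says little
    // about Y, and below 1 means Y is bought less often than usual when X is bought
    static double lift(int countXY, int countX, int countY, int n) {
        return confidence(countXY, countX) / ((double) countY / n);
    }

    public static void main(String[] args) {
        int n = 9; // transactions in the sample file
        // {chips} => {eggs}: count({chips, eggs}) = 6, count(chips) = 7, count(eggs) = 7
        System.out.printf("chips => eggs: conf=%.3f lift=%.3f%n", confidence(6, 7), lift(6, 7, 7, n));
        // {beer} => {eggs}: count({eggs, beer}) = 3, count(beer) = 4, count(eggs) = 7
        System.out.printf("beer => eggs: conf=%.3f lift=%.3f%n", confidence(3, 4), lift(3, 4, 7, n));
    }
}
```

With these numbers, chips => eggs has confidence 6/7 ≈ 0.857 and lift 54/49 ≈ 1.102. By contrast, beer => eggs reaches 75% confidence but has a lift of 27/28 ≈ 0.964, below 1, because eggs appear in 7 of the 9 transactions anyway. That gap is one reason confidence alone tends to let through too many rules.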

Originally published on cnblogs (华夏35度): http://www.cnblogs.com/zhangchaoyang, author: Orisun

