Java 实现 FP-Growth 算法

输入(事务数据集:每行一条事务,事务中的各项以英文逗号分隔):

牛奶,鸡蛋,面包,薯片
鸡蛋,爆米花,薯片,啤酒
鸡蛋,面包,薯片
牛奶,鸡蛋,面包,爆米花,薯片,啤酒
牛奶,面包,啤酒
鸡蛋,面包,啤酒
牛奶,面包,薯片
牛奶,鸡蛋,面包,黄油,薯片
牛奶,鸡蛋,黄油,薯片

输出(最小支持度为 3;只列出包含至少两项的频繁项集,项集与其支持度计数之间以制表符分隔):

啤酒,鸡蛋    3
啤酒,面包    3
牛奶,鸡蛋    4
牛奶,鸡蛋,面包    3
牛奶,鸡蛋,面包,薯片    3
牛奶,鸡蛋,薯片    4
牛奶,面包    5
牛奶,面包,薯片    4
牛奶,薯片    5
鸡蛋,面包    5
鸡蛋,面包,薯片    4
鸡蛋,薯片    6
面包,薯片    5


FP 树节点定义(TreeNode):

import org.apache.commons.lang.StringUtils;

public class TreeNode {
    
    private TreeNode parent;
    private String name;
    private int count;
    private Set<TreeNode> children;
    
    public TreeNode(TreeNode parent,String name,int count){
        this.count = count;
        this.parent = parent;
        this.name = name;
    }
    
    public TreeNode(String name,int count){
        this.name = name;
        this.count = count;
    }
    /**
     * 当前节点计数+i
     * @param i
     */
    public void incrementCount(int i){
        this.count = count + i;
    }
    /**
     * 父节点是否包含子节点包含则返回,否则返回null
     * @param key
     * @return
     */
    public TreeNode findChild(String key){
        if(this.children == null){
            return null;
        }
        for(TreeNode child:this.children){
            if(StringUtils.equals(child.name,key)){
                return child;
            }
        }
        return null;
    }
    /**
     * 给父节点增加一个子节点
     * @param child
     * @return
     */
    public TreeNode addChild(TreeNode child){
        if(this.children == null){
            this.children = new HashSet<TreeNode>();
        }
        this.children.add(child);
        return child;
    }
    public boolean isEmpty(){
        return this.children==null || this.children.size()==0;
    }
    
    public TreeNode getParent() {
        return parent;
    }
    public void setParent(TreeNode parent) {
        this.parent = parent;
    }
    public String getName() {
        return name;
    }
    public void setName(String name) {
        this.name = name;
    }
    public int getCount() {
        return count;
    }
    public void setCount(int count) {
        this.count = count;
    }

    public Set<TreeNode> getChildren() {
        return children;
    }

    public void setChildren(Set<TreeNode> children) {
        this.children = children;
    }
    
}

FP-Growth 挖掘算法(FpTree):

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.StringUtils;

public class FpTree {
    
    private static int support = 3;
    
    public static void main(String[] args) throws IOException{
        //从文件中读取事物数据集
        String file = "D:\\R\\aprior.txt";
        Iterator<String> lineIte = FileUtils.lineIterator(new File(file));
        List<List<String>> transactions = new ArrayList<List<String>>();
        while(lineIte.hasNext()){
            String line = lineIte.next();
            if(StringUtils.isNotEmpty(line)){
                String[] subjects = line.split(",");
                List<String> list = new ArrayList<String>(Arrays.asList(subjects));
                transactions.add(list);
            }
        }
        //初始一个频繁模式集
        List<String> frequences = new LinkedList<String>();
        //开始递归
        digTree(transactions,frequences);
    }
    
    public static void digTree(List<List<String>> transactions,
            List<String> frequences){
        //扫描事物数据集,排序
        final Map<String,Integer> sortedMap = scanAndSort(transactions);
        //没有数据是支持最小支持度了,可以停止了
        if(sortedMap.size() == 0){
            return;
        }
        Map<String,List<TreeNode>> index = new HashMap<String,List<TreeNode>>();
        TreeNode root = buildTree(transactions,index,sortedMap);

        //否则开始从排序最低的项开始 抽出条件模式基,递归挖掘
        List<String> headTable = new ArrayList<String>(sortedMap.keySet());
        Collections.sort(headTable,new Comparator<String>(){
            @Override
            public int compare(String o1, String o2) {
                int i = sortedMap.get(o2)-sortedMap.get(o1); 
                return i != 0 ? i : o1.compareTo(o2);
            }});
        //从项头表最后一项开始挖掘
        for(int i=headTable.size()-1;i>=0;i--){
            String subject = headTable.get(i);
            List<List<String>> frequentModeBases = extract(index.get(subject),root);
            
            LinkedList<String> nextFrequences = new LinkedList<String>(frequences);
            nextFrequences.add(subject);
            if(nextFrequences.size()>1){
                System.out.println(StringUtils.join(nextFrequences,",")+"\t"+sortedMap.get(subject));
            }
            
            digTree(frequentModeBases,nextFrequences);
            
        }
    }
    
    /**
     * 挖掘一个项上面的频繁模式基
     * @param list
     * @param root
     * @return
     */
    public static List<List<String>> extract(List<TreeNode> list,TreeNode root){
        List<List<String>> returnList = new ArrayList<List<String>>();
        for(TreeNode node:list){
            TreeNode parent = node.getParent();
            if(parent.getCount() != -1){
                ArrayList<String> tranc = new ArrayList<String>();
                while(parent.getCount() != -1){
                    tranc.add(parent.getName());
                    parent = parent.getParent();
                }
                for(int i=0;i<node.getCount();i++){
                    returnList.add(tranc);
                }
            }
        }
        return returnList;
    }
    
    /**
     * 构建pf树
     * @param file
     * @param index
     * @param sortedMap
     * @return
     * @throws IOException
     */
    public static TreeNode buildTree(List<List<String>> transactions,
            Map<String,List<TreeNode>> index,
            final Map<String,Integer> sortedMap){
        TreeNode root = new TreeNode(null,"root",-1);
        for(List<String> subjects:transactions){
            Iterator<String> ite = subjects.iterator();
            while(ite.hasNext()){
                String subject = ite.next();
                if(!sortedMap.containsKey(subject)){
                    ite.remove();
                }
            }
            Collections.sort(subjects,new Comparator<String>(){
                @Override
                public int compare(String o1, String o2) {
                    int i = sortedMap.get(o2)-sortedMap.get(o1); 
                    return i != 0 ? i : o1.compareTo(o2);
                }});
            
            TreeNode current = root;
            for(int i=0;i<subjects.size();i++){
                String subject = subjects.get(i);
                TreeNode next = current.findChild(subject);
                if(next == null){
                    TreeNode newNode = new TreeNode(current,subject,1);
                    current.addChild(newNode);
                    current = newNode;
                    List<TreeNode> thisIndex = index.get(subject);
                    if(thisIndex == null){
                        thisIndex = new ArrayList<TreeNode>();
                        index.put(subject, thisIndex);
                    }
                    thisIndex.add(newNode);
                }else{
                    next.incrementCount(1);
                    current = next;
                }
            }
        }
        return root;
    }
    
    /**
     * 扫描排序
     * @param file
     * @return
     * @throws IOException
     */
    public static Map<String,Integer> scanAndSort(List<List<String>> transactions){
        Map<String,Integer> map = new HashMap<String,Integer>();
        //空的就不扫了
        if(transactions.size()==0){
            return map;
        }
        for(List<String> basket:transactions){
            for(String subject:basket){
                Integer count = map.get(subject);
                if (count == null) {
                    map.put(subject, 1);
                } else {
                    map.put(subject, count + 1);
                }
            }
        }
        Iterator<Entry<String,Integer>> ite = map.entrySet().iterator();
        while(ite.hasNext()){
            Entry<String,Integer> entry = ite.next();
            if(entry.getValue() < support){
                ite.remove();
            }
        }
        return map;
    }
}


你可能感兴趣的:(java 实现fpGrowth算法)