输入:
牛奶,鸡蛋,面包,薯片 鸡蛋,爆米花,薯片,啤酒 鸡蛋,面包,薯片 牛奶,鸡蛋,面包,爆米花,薯片,啤酒 牛奶,面包,啤酒 鸡蛋,面包,啤酒 牛奶,面包,薯片 牛奶,鸡蛋,面包,黄油,薯片 牛奶,鸡蛋,黄油,薯片
输出:
啤酒,鸡蛋 3 啤酒,面包 3 牛奶,鸡蛋 4 牛奶,鸡蛋,面包 3 牛奶,鸡蛋,面包,薯片 3 牛奶,鸡蛋,薯片 4 牛奶,面包 5 牛奶,面包,薯片 4 牛奶,薯片 5 鸡蛋,面包 5 鸡蛋,面包,薯片 4 鸡蛋,薯片 6 面包,薯片 5
节点定义:
import org.apache.commons.lang.StringUtils; public class TreeNode { private TreeNode parent; private String name; private int count; private Set<TreeNode> children; public TreeNode(TreeNode parent,String name,int count){ this.count = count; this.parent = parent; this.name = name; } public TreeNode(String name,int count){ this.name = name; this.count = count; } /** * 当前节点计数+i * @param i */ public void incrementCount(int i){ this.count = count + i; } /** * 父节点是否包含子节点包含则返回,否则返回null * @param key * @return */ public TreeNode findChild(String key){ if(this.children == null){ return null; } for(TreeNode child:this.children){ if(StringUtils.equals(child.name,key)){ return child; } } return null; } /** * 给父节点增加一个子节点 * @param child * @return */ public TreeNode addChild(TreeNode child){ if(this.children == null){ this.children = new HashSet<TreeNode>(); } this.children.add(child); return child; } public boolean isEmpty(){ return this.children==null || this.children.size()==0; } public TreeNode getParent() { return parent; } public void setParent(TreeNode parent) { this.parent = parent; } public String getName() { return name; } public void setName(String name) { this.name = name; } public int getCount() { return count; } public void setCount(int count) { this.count = count; } public Set<TreeNode> getChildren() { return children; } public void setChildren(Set<TreeNode> children) { this.children = children; } }
挖掘算法:
import org.apache.commons.io.FileUtils; import org.apache.commons.lang.StringUtils; public class FpTree { private static int support = 3; public static void main(String[] args) throws IOException{ //从文件中读取事物数据集 String file = "D:\\R\\aprior.txt"; Iterator<String> lineIte = FileUtils.lineIterator(new File(file)); List<List<String>> transactions = new ArrayList<List<String>>(); while(lineIte.hasNext()){ String line = lineIte.next(); if(StringUtils.isNotEmpty(line)){ String[] subjects = line.split(","); List<String> list = new ArrayList<String>(Arrays.asList(subjects)); transactions.add(list); } } //初始一个频繁模式集 List<String> frequences = new LinkedList<String>(); //开始递归 digTree(transactions,frequences); } public static void digTree(List<List<String>> transactions, List<String> frequences){ //扫描事物数据集,排序 final Map<String,Integer> sortedMap = scanAndSort(transactions); //没有数据是支持最小支持度了,可以停止了 if(sortedMap.size() == 0){ return; } Map<String,List<TreeNode>> index = new HashMap<String,List<TreeNode>>(); TreeNode root = buildTree(transactions,index,sortedMap); //否则开始从排序最低的项开始 抽出条件模式基,递归挖掘 List<String> headTable = new ArrayList<String>(sortedMap.keySet()); Collections.sort(headTable,new Comparator<String>(){ @Override public int compare(String o1, String o2) { int i = sortedMap.get(o2)-sortedMap.get(o1); return i != 0 ? i : o1.compareTo(o2); }}); //从项头表最后一项开始挖掘 for(int i=headTable.size()-1;i>=0;i--){ String subject = headTable.get(i); List<List<String>> frequentModeBases = extract(index.get(subject),root); LinkedList<String> nextFrequences = new LinkedList<String>(frequences); nextFrequences.add(subject); if(nextFrequences.size()>1){ System.out.println(StringUtils.join(nextFrequences,",")+"\t"+sortedMap.get(subject)); } digTree(frequentModeBases,nextFrequences); } } /** * 挖掘一个项上面的频繁模式基 * @param list * @param root * @return */ public static List<List<String>> extract(List<TreeNode> list,TreeNode root){ List<List<String>> returnList = new ArrayList<List<String>>(); for(TreeNode node:list){ TreeNode parent = node.getParent(); if(parent.getCount() != -1){ ArrayList<String> tranc = new ArrayList<String>(); while(parent.getCount() != -1){ tranc.add(parent.getName()); parent = parent.getParent(); } for(int i=0;i<node.getCount();i++){ returnList.add(tranc); } } } return returnList; } /** * 构建pf树 * @param file * @param index * @param sortedMap * @return * @throws IOException */ public static TreeNode buildTree(List<List<String>> transactions, Map<String,List<TreeNode>> index, final Map<String,Integer> sortedMap){ TreeNode root = new TreeNode(null,"root",-1); for(List<String> subjects:transactions){ Iterator<String> ite = subjects.iterator(); while(ite.hasNext()){ String subject = ite.next(); if(!sortedMap.containsKey(subject)){ ite.remove(); } } Collections.sort(subjects,new Comparator<String>(){ @Override public int compare(String o1, String o2) { int i = sortedMap.get(o2)-sortedMap.get(o1); return i != 0 ? i : o1.compareTo(o2); }}); TreeNode current = root; for(int i=0;i<subjects.size();i++){ String subject = subjects.get(i); TreeNode next = current.findChild(subject); if(next == null){ TreeNode newNode = new TreeNode(current,subject,1); current.addChild(newNode); current = newNode; List<TreeNode> thisIndex = index.get(subject); if(thisIndex == null){ thisIndex = new ArrayList<TreeNode>(); index.put(subject, thisIndex); } thisIndex.add(newNode); }else{ next.incrementCount(1); current = next; } } } return root; } /** * 扫描排序 * @param file * @return * @throws IOException */ public static Map<String,Integer> scanAndSort(List<List<String>> transactions){ Map<String,Integer> map = new HashMap<String,Integer>(); //空的就不扫了 if(transactions.size()==0){ return map; } for(List<String> basket:transactions){ for(String subject:basket){ Integer count = map.get(subject); if (count == null) { map.put(subject, 1); } else { map.put(subject, count + 1); } } } Iterator<Entry<String,Integer>> ite = map.entrySet().iterator(); while(ite.hasNext()){ Entry<String,Integer> entry = ite.next(); if(entry.getValue() < support){ ite.remove(); } } return map; } }