Apriori算法原理:
如果某个项集是频繁的,那么它所有的子集也是频繁的。如果一个项集是非频繁的,那么它所有的超集也是非频繁的。
示意图
图一:
图二:
package cn.ffr.frequent.apriori; import java.io.BufferedReader; import java.io.InputStreamReader; import java.net.URL; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; /** * Apriori的核心代码实现 * @author [email protected] */ public class Apriori { public static final String STRING_SPLIT = ","; /** * 主要的计算方法 * @param data 数据集 * @param minSupport 最小支持度 * @param maxLoop 最大执行次数,设NULL为获取最终结果 * @param containSet 结果中必须包含的子集 * @return */ public Map<String, Double> compute(List<String[]> data, Double minSupport, Integer maxLoop, String[] containSet){ //校验 if(data == null || data.size() <= 0){ return null; } //初始化 Map<String, Double> result = new HashMap<String, Double>(); Object[] itemSet = getDataUnitSet(data); int loop = 0; //核心循环处理过程 while(true){ //重要步骤一:合并,产生新的频繁集 Set<String> keys = combine(result.keySet(), itemSet); result.clear();//移除之前的结果 for(String key : keys){ result.put(key, computeSupport(data, key.split(STRING_SPLIT))); } //重要步骤二:修剪,去除支持度小于条件的。 cut(result, minSupport, containSet); loop++; //输出计算过程 System.out.println("loop ["+loop+"], result : "+result); //循环结束条件 if(result.size() <= 0){ break; } if(maxLoop != null && maxLoop > 0 && loop >= maxLoop){//可控制循环执行次数 break; } } return result; } /** * 计算子集的支持度 * * 支持度 = 子集在数据集中的数据项 / 总的数据集的数据项 * * 数据项的意思是一条数据。 * @param data 数据集 * @param subSet 子集 * @return */ public Double computeSupport(List<String[]> data, String[] subSet){ Integer value = 0; for(int i = 0; i < data.size(); i++){ if(contain(data.get(i), subSet)){ value ++; } } return value*1.0/data.size(); } /** * 获得初始化唯一的数据集,用于初始化 * @param data * @return */ public Object[] getDataUnitSet(List<String[]> data){ List<String> uniqueKeys = new ArrayList<String>(); for(String[] dat : data){ for(String da : dat){ if(!uniqueKeys.contains(da)){ uniqueKeys.add(da); } } } return uniqueKeys.toArray(); } /** * 合并src和target来获取频繁集 * 增加频繁集的计算维度 * @param src * @param target * @return */ public Set<String> combine(Set<String> src, Object[] target){ Set<String> dest = new HashSet<String>(); if(src == null || src.size() <= 0){ for(Object t : target){ dest.add(t.toString()); } return dest; } for(String s : src){ for(Object t : target){ if(s.indexOf(t.toString())<0){ String key = s+STRING_SPLIT+t; if(!contain(dest, key)){ dest.add(key); } } } } return dest; } /** * dest集中是否包含了key * @param dest * @param key * @return */ public boolean contain(Set<String> dest, String key){ for(String d : dest){ if(equal(d.split(STRING_SPLIT), key.split(STRING_SPLIT))){ return true; } } return false; } /** * 移除结果中,支持度小于所需要的支持度的结果。 * @param result * @param minSupport * @return */ public Map<String, Double> cut(Map<String, Double> result, Double minSupport, String[] containSet){ for(Object key : result.keySet().toArray()){//防止 java.util.ConcurrentModificationException,使用keySet().toArray() if(minSupport != null && minSupport > 0 && minSupport < 1 && result.get(key) < minSupport){//比较支持度 result.remove(key); } if(containSet != null && containSet.length > 0 && !contain(key.toString().split(STRING_SPLIT), containSet)){ result.remove(key); } } return result; } /** * src中是否包含dest,需要循环遍历查询 * @param src * @param dest * @return */ public static boolean contain(String[] src, String[] dest){ for(int i = 0; i < dest.length; i++){ int j = 0; for(; j < src.length; j++){ if(src[j].equals(dest[i])){ break; } } if(j == src.length){ return false;//can not find } } return true; } /** * src是否与dest相等 * @param src * @param dest * @return */ public boolean equal(String[] src, String[] dest){ if(src.length == dest.length && contain(src, dest)){ return true; } return false; } /** * 主测试方法 * 测试方法,挨个去掉注释,进行测试。 */ public static void main(String[] args) throws Exception{ //test 1 // List<String[]> data = loadSmallData(); // Long start = System.currentTimeMillis(); // Map<String, Double> result = new Apriori().compute(data, 0.5, 3, null);//求支持度大于指定值 // Long end = System.currentTimeMillis(); // System.out.println("Apriori Result [costs:"+(end-start)+"ms]: "); // for(String key : result.keySet()){ // System.out.println("\tFrequent Set=["+key+"] & Support=["+result.get(key)+"];"); // } //test 2 // List<String[]> data = loadMushRoomData(); // Long start = System.currentTimeMillis(); // Map<String, Double> result = new Apriori().compute(data, 0.3, 4, new String[]{"2"});//求支持度大于指定值 // Long end = System.currentTimeMillis(); // System.out.println("Apriori Result [costs:"+(end-start)+"ms]: "); // for(String key : result.keySet()){ // System.out.println("\tFrequent Set=["+key+"] & Support=["+result.get(key)+"];"); // } //test 3 List<String[]> data = loadChessData(); Long start = System.currentTimeMillis(); Map<String, Double> result = new Apriori().compute(data, 0.95, 3, null);//求支持度大于指定值 Long end = System.currentTimeMillis(); System.out.println("Apriori Result [costs:"+(end-start)+"ms]: "); for(String key : result.keySet()){ System.out.println("\tFrequent Set=["+key+"] & Support=["+result.get(key)+"];"); } } /* * SmallData: minSupport 0.5, maxLoop 3, containSet null, [costs: 16ms] * MushRoomData: minSupport 0.3, maxLoop 4, containSet {"2"}, [costs: 103250ms] * ChessData: minSupport 0.95, maxLoop 34, containSet {null, [costs: 9718ms] */ //测试数据集-1 public static List<String[]> loadSmallData() throws Exception{ List<String[]> data = new ArrayList<String[]>(); data.add(new String[]{"d1","d3","d4"}); data.add(new String[]{"d2","d3","d5"}); data.add(new String[]{"d1","d2","d3","d5"}); data.add(new String[]{"d2","d5"}); return data; } //测试数据集-2 public static List<String[]> loadMushRoomData() throws Exception{ String link = "http://fimi.ua.ac.be/data/mushroom.dat"; URL url = new URL(link); BufferedReader reader = new BufferedReader(new InputStreamReader(url.openStream())); String temp = reader.readLine(); List<String[]> result = new ArrayList<String[]>(); int lineNumber = 0; while(temp != null){ System.out.println("reading data... [No."+(++lineNumber)+"]"); String[] item = temp.split(" "); result.add(item); temp = reader.readLine(); } reader.close(); return result; } //测试数据集-3 public static List<String[]> loadChessData() throws Exception{ String link = "http://fimi.ua.ac.be/data/chess.dat"; URL url = new URL(link); BufferedReader reader = new BufferedReader(new InputStreamReader(url.openStream())); String temp = reader.readLine(); List<String[]> result = new ArrayList<String[]>(); int lineNumber = 0; while(temp != null){ System.out.println("reading data... [No."+(++lineNumber)+"]"); String[] item = temp.split(" "); result.add(item); temp = reader.readLine(); } reader.close(); return result; } }
算法原理: