数据挖掘-关联分析-Apriori算法Java实现 支持度+置信度

apriori算法是最基本的发现频繁项集的算法,它的名字也体现了它的思想——先验,采用逐层搜索迭代的方法,挖掘任何可能的项集,k项集用于挖掘k+1项集。

先验性质

频繁项集的所有非空子集也一定是频繁的

该性质体现了项集挖掘中的反单调性,如果k项集不是频繁的,那么k+1项集一定也不是。基于这一点,算法的基本思想为:

step 1:连接

    为了搜索k项集,将k-1项集自连接产生候选k项集,称为候选集。

    为了有效的实现连接,首先对每一项进行排序。其次,若满足连接的条件,则进行连接。

    连接的条件,前k-2项相同,k-1项不同

step 2:剪枝

    k项集的每一个k-1项子集都存在与k-1项集,并且支持度满足最小支持度阀值。

伪代码:

C<k>:candidata itemset of size k
L<k>:frequent itemset of size k
L<1>=frequent items
 for(k=1;L<k>!=null;k++)
    C<k+1>=candidates generated from L<k>
    for transaction t in dataset
        increment the count of all candidates in C<k+1> that are contained in t     L<k+1>=candidates in C<k+1> with support>=min_support
return
Java代码实现方式:

抽象了一个项集实体类,并实现是否可以合并的方法,这个方法最初是使用TreeSet.headSet来实现的,但是在测试时发现性能瓶颈都产生在这个方法上,并造成OOM,很是费解,待研究清楚后总结一下。

/**
 * 
 */
package org.waitingfortime.datamining.association;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/**
 * @author mazhiyuan
 * 
 */
public class Apriori {
	private int minNum;// 最小支持数
	private List<Set<Integer>> records;
	private String output;
	private List<List<ItemSet>> result = new ArrayList<List<ItemSet>>();

	public Apriori(double minDegree, String input, String output) {
		this.output = output;
		init(input);
		if (records.size() == 0) {
			System.err.println("不符合计算条件。退出!");
			System.exit(1);
		}
		minNum = (int) (minDegree * records.size());
	}

	private void init(String path) {
		// TODO Auto-generated method stub
		records = new ArrayList<Set<Integer>>();
		try {
			BufferedReader br = new BufferedReader(new FileReader(
					new File(path)));

			String line = null;
			Set<Integer> record;
			while ((line = br.readLine()) != null) {
				if (!"".equals(line.trim())) {
					record = new TreeSet<Integer>();
					String[] items = line.split(" ");
					for (String item : items) {
						record.add(Integer.valueOf(item));
					}
					records.add(record);
				}
			}

			br.close();
		} catch (IOException e) {
			System.err.println("读取事务文件失败。");
		}
	}

	private List<ItemSet> first() {
		// TODO Auto-generated method stub
		List<ItemSet> first = new ArrayList<ItemSet>();
		Map<Integer, Integer> _first = new HashMap<Integer, Integer>();
		for (Set<Integer> si : records)
			for (Integer i : si) {
				if (_first.get(i) == null)
					_first.put(i, 1);
				else
					_first.put(i, _first.get(i) + 1);
			}

		for (Integer i : _first.keySet())
			if (_first.get(i) >= minNum)
				first.add(new ItemSet(i, _first.get(i)));

		return first;
	}

	private void loop(List<ItemSet> items) {
		// TODO Auto-generated method stub
		List<ItemSet> copy = new ArrayList<ItemSet>(items);
		List<ItemSet> res = new ArrayList<ItemSet>();
		int size = items.size();

		// 连接
		for (int i = 0; i < size; i++)
			for (int j = i + 1; j < size; j++)
				if (copy.get(i).isMerge(copy.get(j))) {
					ItemSet is = new ItemSet(copy.get(i));
					is.merge(copy.get(j).item.last());
					res.add(is);
				}
		// 剪枝
		pruning(copy, res);

		if (res.size() != 0) {
			result.add(res);
			loop(res);
		}
	}

	private void pruning(List<ItemSet> pre, List<ItemSet> res) {
		// TODO Auto-generated method stub
		// step 1 k项集的子集属于k-1项集
		Iterator<ItemSet> ir = res.iterator();
		while (ir.hasNext()) {
			// 获取所有k-1项子集
			ItemSet now = ir.next();
			List<List<Integer>> ss = subSet(now);
			// 判断是否在pre集中
			boolean flag = false;
			for (List<Integer> li : ss) {
				if (flag)
					break;
				for (ItemSet pis : pre) {
					if (pis.item.containsAll(li)) {
						flag = false;
						break;
					}
					flag = true;
				}
			}
			if (flag) {
				ir.remove();
				continue;
			}
			// step 2 支持度
			int i = 0;
			for (Set<Integer> sr : records) {
				if (sr.containsAll(now.item))
					i++;

				now.value = i;
			}
			if (now.value < minNum)
				ir.remove();
		}
	}

	private List<List<Integer>> subSet(ItemSet is) {
		// TODO Auto-generated method stub
		List<Integer> li = new ArrayList<Integer>(is.item);
		List<List<Integer>> res = new ArrayList<List<Integer>>();
		for (int i = 0, j = li.size(); i < j; i++) {
			List<Integer> _li = new ArrayList<Integer>(li);
			_li.remove(i);
			res.add(_li);
		}
		return res;
	}

	private void output() throws FileNotFoundException {
		if (result.size() == 0) {
			System.err.println("无结果集。退出!");
			return;
		}
		FileOutputStream out = new FileOutputStream(output);
		PrintStream ps = new PrintStream(out);
		for (List<ItemSet> li : result) {
			ps.println("=============频繁"+li.get(0).item.size()+"项集=============");
			for (ItemSet is : li)
				ps.println(is.item + " : " + is.value);
			ps.println("=====================================");
		}
	}

	/**
	 * @param args
	 * @throws FileNotFoundException
	 */
	public static void main(String[] args) throws FileNotFoundException {
		// TODO Auto-generated method stub
		long begin = System.currentTimeMillis();
		Apriori apriori = new Apriori(0.25,
				"/home/mazhiyuan/code/mushroom.dat",
				"/home/mazhiyuan/code/mout.data");
		// apriori.first();//频繁1项集
		apriori.loop(apriori.first());
		apriori.output();
		System.out.println((System.currentTimeMillis()) - begin);
	}
}

class ItemSet {
	TreeSet<Integer> item;
	int value;

	ItemSet(ItemSet is) {
		this.item = new TreeSet<Integer>(is.item);
	}

	ItemSet() {
		item = new TreeSet<Integer>();
	}

	ItemSet(int i, int v) {
		this();
		merge(i);
		setValue(v);
	}

	void setValue(int i) {
		this.value = i;
	}

	void merge(int i) {
		item.add(i);
	}

	boolean isMerge(ItemSet other) {
		if (other == null || other.item.size() != item.size())
			return false;
		// 前k-1项相同,最后一项不同,满足连接条件
		/*
		 * Iterator<Integer> i = item.headSet(item.last()).iterator();
		 * Iterator<Integer> o =
		 * other.item.headSet(other.item.last()).iterator(); while (i.hasNext() &&
		 * o.hasNext()) if (i.next() != o.next()) return false;
		 */
		Iterator<Integer> i = item.iterator();
		Iterator<Integer> o = other.item.iterator();
		int n = item.size();
		while (i.hasNext() && o.hasNext() && --n > 0)
			if (i.next() != o.next())
				return false;

		return !(item.last() == other.item.last());
	}
}
使用mushroom数据集,整个运行时间只有大概6s,性能还算满意。

这个代码只是计算了频繁项集,还没有计算关联规则和置信度,稍后补上。

=========补充了关联规则的生成======== 比想象的要麻烦一点

关联规则可以是双向的,confidence(A-->B)=P(A|B)=support(A&B)/support(A)

所以在计算k项集的关联规则时,其分母都是k项集的支持度,分子为k-1项集的支持度,以及对应1项集的支持度

/**
 * 
 */
package org.waitingfortime.datamining.association;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/**
 * @author mazhiyuan
 * 
 */
public class Apriori {
	private int minNum;// 最小支持数
	private double minCon;// 最小置信度
	private List<Set<Integer>> records;// 原始数据
	private String output;// 输出路径
	private List<List<ItemSet>> result = new ArrayList<List<ItemSet>>();// 频繁项集结果
	private List<ItemSet> fth;// 频繁1项集

	public Apriori(double minDegree, double minCon, String input, String output) {
		this.output = output;
		this.minCon = minCon;
		init(input);
		if (records.size() == 0) {
			System.err.println("不符合计算条件。退出!");
			System.exit(1);
		}
		minNum = (int) (minDegree * records.size());
	}

	private void init(String path) {
		// TODO Auto-generated method stub
		records = new ArrayList<Set<Integer>>();
		try {
			BufferedReader br = new BufferedReader(new FileReader(
					new File(path)));

			String line = null;
			Set<Integer> record;
			while ((line = br.readLine()) != null) {
				if (!"".equals(line.trim())) {
					record = new TreeSet<Integer>();
					String[] items = line.split(" ");
					for (String item : items) {
						record.add(Integer.valueOf(item));
					}
					records.add(record);
				}
			}

			br.close();
		} catch (IOException e) {
			System.err.println("读取事务文件失败。");
		}
	}

	private void first() {
		// TODO Auto-generated method stub
		fth = new ArrayList<ItemSet>();
		Map<Integer, Integer> first = new HashMap<Integer, Integer>();
		for (Set<Integer> si : records)
			for (Integer i : si) {
				if (first.get(i) == null)
					first.put(i, 1);
				else
					first.put(i, first.get(i) + 1);
			}

		for (Integer i : first.keySet())
			if (first.get(i) >= minNum)
				fth.add(new ItemSet(i, first.get(i)));

	}

	private void loop(List<ItemSet> items) {
		// TODO Auto-generated method stub
		List<ItemSet> copy = new ArrayList<ItemSet>(items);
		List<ItemSet> res = new ArrayList<ItemSet>();
		int size = items.size();

		// 连接
		for (int i = 0; i < size; i++)
			for (int j = i + 1; j < size; j++)
				if (copy.get(i).isMerge(copy.get(j))) {
					ItemSet is = new ItemSet(copy.get(i));
					is.merge(copy.get(j).item.last());
					res.add(is);
				}
		// 剪枝
		pruning(copy, res);

		if (res.size() != 0) {
			result.add(res);
			loop(res);
		}
	}

	private void pruning(List<ItemSet> pre, List<ItemSet> res) {
		// TODO Auto-generated method stub
		// step 1 k项集的子集属于k-1项集
		Iterator<ItemSet> ir = res.iterator();
		while (ir.hasNext()) {
			// 获取所有k-1项子集
			ItemSet now = ir.next();
			Map<Integer, List<Integer>> ss = subSet(now);
			// 判断是否在pre集中
			boolean flag = false;
			for (List<Integer> li : ss.values()) {
				if (flag)
					break;
				for (ItemSet pis : pre) {
					if (pis.item.containsAll(li)) {
						flag = false;
						break;
					}
					flag = true;
				}
			}
			if (flag) {
				ir.remove();
				continue;
			}
			// step 2 支持度
			int i = 0;
			for (Set<Integer> sr : records) {
				if (sr.containsAll(now.item))
					i++;

				now.support = i;
			}
			if (now.support < minNum) {
				ir.remove();
				continue;
			}
			// 产生关联规则
			double deno = now.support;
			for (Map.Entry<Integer, List<Integer>> me : ss.entrySet()) {
				ItemCon ic = new ItemCon(me.getKey(), me.getValue());
				int nume = 0;

				for (ItemSet f : fth)
					if (f.item.contains(me.getKey())) {
						nume = f.support;
						break;
					}
				if (deno / nume > minCon) {
					now.calcon(ic);
					ic.setC1(deno / nume);
				}
				for (ItemSet pis : pre)
					if (pis.item.size() == me.getValue().size()
							&& pis.item.containsAll(me.getValue())) {
						nume = pis.support;
						break;
					}
				if (deno / nume > minCon)
					ic.setC2(deno / nume);
			}
		}
	}

	private Map<Integer, List<Integer>> subSet(ItemSet is) {
		// TODO Auto-generated method stub
		List<Integer> li = new ArrayList<Integer>(is.item);
		Map<Integer, List<Integer>> res = new HashMap<Integer, List<Integer>>();
		for (int i = 0, j = li.size(); i < j; i++) {
			List<Integer> _li = new ArrayList<Integer>(li);
			_li.remove(i);
			res.put(li.get(i), _li);
		}
		return res;
	}

	private void output() throws FileNotFoundException {
		if (result.size() == 0) {
			System.err.println("无结果集。退出!");
			return;
		}
		FileOutputStream out = new FileOutputStream(output);
		PrintStream ps = new PrintStream(out);
		for (List<ItemSet> li : result) {
			ps.println("=============频繁" + li.get(0).item.size()
					+ "项集=============");
			for (ItemSet is : li) {
				ps.println(is.item + " : " + is.support);
				ps.println();
				if (is.ics.size() != 0) {
					ps.println("******关联规则******");
					for (ItemCon ic : is.ics) {
						ps.println(ic.i + " ---> " + ic.li + " con: "
								+ ic.confidence1);
						if (ic.confidence2 > minCon)
							ps.println(ic.li + " ---> " + ic.i + " con: "
									+ ic.confidence2);
					}
					ps.println("******************");
					ps.println();
				}
			}
			ps.println("=====================================");
		}

		ps.close();
	}

	/**
	 * @param args
	 * @throws FileNotFoundException
	 */
	public static void main(String[] args) throws FileNotFoundException {
		// TODO Auto-generated method stub
		long begin = System.currentTimeMillis();
		Apriori apriori = new Apriori(0.25, 0.5,
				"/home/mazhiyuan/code/mushroom.dat",
				"/home/mazhiyuan/code/mout.data");
		// apriori.first();//频繁1项集
		apriori.first();
		apriori.loop(apriori.fth);

		apriori.output();
		System.out.println("共耗时:" + ((System.currentTimeMillis()) - begin)
				+ "ms");
	}
}

class ItemSet {
	TreeSet<Integer> item;
	int support;
	List<ItemCon> ics = new ArrayList<ItemCon>(); // 关联规则结果

	ItemSet(ItemSet is) {
		this.item = new TreeSet<Integer>(is.item);
	}

	ItemSet() {
		item = new TreeSet<Integer>();
	}

	ItemSet(int i, int v) {
		this();
		merge(i);
		setValue(v);
	}

	void setValue(int i) {
		this.support = i;
	}

	void merge(int i) {
		item.add(i);
	}

	void calcon(ItemCon ic) {
		ics.add(ic);
	}

	boolean isMerge(ItemSet other) {
		if (other == null || other.item.size() != item.size())
			return false;
		// 前k-1项相同,最后一项不同,满足连接条件
		/*
		 * Iterator<Integer> i = item.headSet(item.last()).iterator();
		 * Iterator<Integer> o =
		 * other.item.headSet(other.item.last()).iterator(); while (i.hasNext()
		 * && o.hasNext()) if (i.next() != o.next()) return false;
		 */
		Iterator<Integer> i = item.iterator();
		Iterator<Integer> o = other.item.iterator();
		int n = item.size();
		while (i.hasNext() && o.hasNext() && --n > 0)
			if (i.next() != o.next())
				return false;

		return !(item.last() == other.item.last());
	}
}

class ItemCon {
	Integer i;
	List<Integer> li;
	double confidence1;
	double confidence2;

	ItemCon(Integer i, List<Integer> li) {
		this.i = i;
		this.li = li;
	}

	void setC1(double c1) {
		this.confidence1 = c1;
	}

	void setC2(double c2) {
		this.confidence2 = c2;
	}
}

Apriori算法本身的性能就是一大问题,产生太多的候选集,FP-TREE算法规避了这一问题,使得频繁项集的挖掘性能提高了至少一个量级,下一篇重点介绍这个算法。

你可能感兴趣的:(java,Apriori,关联分析,频繁项集)