Java数据挖掘,数据关联规则之PF-Growth算法,压缩结果集

FP-Growth算法:

首先,把所有的事务集的汉字词语都转化成数字,好处如下:

  1. 数字之间的比较远比字符串快;

  2. 减少内存的使用;

所以前提就是把所有的汉字词语转化为数字

eg:汉字词语

牛奶,鸡蛋,面包,薯片

鸡蛋,爆米花,薯片,啤酒

鸡蛋,面包,薯片

转化为一一对应的数字

1,2,3,4

2,5,4,6

2,3,4


FP-Growth的树结构的java代码如下,把String->integer:

import java.util.HashSet;
import java.util.Set;

public class TreeNode {
	private TreeNode parent;
	private int nameNO;
	private int count;
	private Set<TreeNode> children;

	public TreeNode(TreeNode parent, int nameNO, int count) {
		this.count = count;
		this.parent = parent;
		this.nameNO = nameNO;
	}

	public TreeNode(int nameNO, int count) {
		this.nameNO = nameNO;
		this.count = count;
	}

	/**
	 * 当前节点计数+i
	 * 
	 * @param i
	 */
	public void incrementCount(int i) {
		this.count = count + i;
	}

	/**
	 * 父节点是否包含子节点包含则返回,否则返回null
	 * 
	 * @param key
	 * @return
	 */
	public TreeNode findChild(int key) {
		if (this.children == null) {
			return null;
		}
		for (TreeNode child : this.children) {
			if (child.nameNO == key) {
				return child;
			}
		}
		return null;
	}

	/**
	 * 给父节点增加一个子节点
	 * 
	 * @param child
	 * @return
	 */
	public TreeNode addChild(TreeNode child) {
		if (this.children == null) {
			this.children = new HashSet<TreeNode>();
		}
		this.children.add(child);
		return child;
	}

	public boolean isEmpty() {
		return this.children == null || this.children.size() == 0;
	}

	public TreeNode getParent() {
		return parent;
	}

	public void setParent(TreeNode parent) {
		this.parent = parent;
	}

	public int getNameCount() {
		return nameNO;
	}

	public void setNameCount(int nameNO) {
		this.nameNO = nameNO;
	}

	public int getCount() {
		return count;
	}

	public void setCount(int count) {
		this.count = count;
	}

	public Set<TreeNode> getChildren() {
		return children;
	}

	public void setChildren(Set<TreeNode> children) {
		this.children = children;
	}
}

FP-Growth算法主要算法如下:

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.OutputStreamWriter;
import java.util.*;
import java.util.Map.Entry;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;

public class FrequentItemSets {
	private static Logger logger = Logger.getLogger(FrequentItemSets.class
			.getName());

	//设定支持度
	private static int ABSOLUTE_SUPPORT = 3;
	//用于输出频繁一项的标志位
	private static int flag = 1;
	private static BufferedWriter bw = null;
	private static File tempFile = null;
	private static File fileSortMap = null;
	private static BufferedWriter bwSortMap = null;
	private static List<Integer> frequentItemsData = new ArrayList<Integer>();

	public static void main(String[] args) throws Exception {
		FrequentItemSets frequentItemSets = new FrequentItemSets();
		frequentItemSets.beginFrequentItemSets(
				"C:\\Users\\angelo\\Desktop\\test\\1.txt",
				"C:\\Users\\angelo\\Desktop\\test\\OneItem.txt",
				"C:\\Users\\angelo\\Desktop\\test\\Items.txt");
	}

	@SuppressWarnings("unchecked")
	public void beginFrequentItemSets(String fromFilePath,
			String toFrequentOneItemFilePath, String frequentItemsSetDataFile)
			throws Exception {
		fileSortMap = new File(toFrequentOneItemFilePath);
		// 从文件中读取事物数据集,这个里面的汉字是经过UTF-8排序过的汉字
		FileOutputStream fos = new FileOutputStream(fileSortMap);
		bwSortMap = new BufferedWriter(new OutputStreamWriter(fos));
		Iterator<String> lineIte = FileUtils
				.lineIterator(new File(fromFilePath));
		List<List<Integer>> transactions = new ArrayList<List<Integer>>();
		while (lineIte.hasNext()) {
			String line = lineIte.next();
			if (StringUtils.isNotEmpty(line) && line.length() != 33) {
				String[] subjects = line.split(",");
				List<String> list = new ArrayList<String>(
						Arrays.asList(subjects));
				List<Integer> intList = new ArrayList<Integer>();
				for (String temp : list) {
					intList.add(Integer.parseInt(temp));
				}
				transactions.add(intList);
			}
		}
		// 初始一个频繁模式集
		List<Integer> frequences = new LinkedList<Integer>();

		tempFile = new File("C:\\Users\\angelo\\Desktop\\test\\Items2.txt");
		if (tempFile.exists()) {
			tempFile.delete();
		}
		tempFile.createNewFile();
		bw = new BufferedWriter(new FileWriter(tempFile));

		digTree(transactions, frequences);
		// set转list 在用UTF-8排序
		List<Integer> frequentItemsDataList = new ArrayList<Integer>(
				frequentItemsData);
		bw.write(listToString(frequentItemsDataList, ",") + "\n");
		bw.flush();
		bw.close();
	}

	public void digTree(List<List<Integer>> transactions,
			List<Integer> frequences) throws Exception {
		// 扫描事物数据集,排序
		final Map<Integer, Integer> sortedMap = scanAndSort(transactions);
		// 没有数据是支持最小支持度了,可以停止了
		if (sortedMap.size() == 0) {
			return;
		}
		Map<Integer, List<TreeNode>> index = new HashMap<Integer, List<TreeNode>>();
		TreeNode root = buildTree(transactions, index, sortedMap);
		// 否则开始从排序最低的项开始 抽出条件模式基,递归挖掘
		List<Integer> headTable = new ArrayList<Integer>(sortedMap.keySet());
		Collections.sort(headTable, new Comparator<Integer>() {
			@Override
			public int compare(Integer o1, Integer o2) {
				int i = sortedMap.get(o2) - sortedMap.get(o1);
				return i != 0 ? i : o1.compareTo(o2);
			}
		});

		//输出频繁一项集数据
		if (flag == 1) {
			for (Integer keyWord : headTable) {
				bwSortMap.write(keyWord + ",");
			}
			bwSortMap.flush();
			bwSortMap.close();
			flag++;
		}

		// 从项头表最后一项开始挖掘
		for (int i = headTable.size() - 1; i >= 0; i--) {
			Integer subject = headTable.get(i);
			List<List<Integer>> frequentModeBases = extract(index.get(subject),
					root);

			LinkedList<Integer> nextFrequences = new LinkedList<Integer>(
					frequences);
			nextFrequences.add(subject);
			if (nextFrequences.size() > 1) {
				try {
				        //重点:数据的压缩在这里
					List<Integer> tempList = new ArrayList<Integer>();
					for (Integer temp : nextFrequences) {
						tempList.add(temp);
					}
					if (frequentItemsData.size() == 0) {
						frequentItemsData.addAll(tempList);
					} else {
						List<Integer> tempFrequentList = new ArrayList<Integer>();
						tempFrequentList.addAll(frequentItemsData);

						List<Integer> saveTempList = new ArrayList<Integer>();
						saveTempList.addAll(tempList);
						tempFrequentList.removeAll(tempList);
						tempList.removeAll(frequentItemsData);
						if (tempFrequentList.size() == 0) {
							frequentItemsData.clear();
							frequentItemsData.addAll(saveTempList);
						} else if (tempList.size() == 0) {
							continue;
						} else {
							List<String> frequentItemsDataList = new ArrayList<String>();
							for (Integer tempInt : frequentItemsData) {
								frequentItemsDataList.add(tempInt + "");
							}
							Collections.sort(frequentItemsDataList,
									new SortChineseKeywords());
							bw.write(listToString(frequentItemsDataList, ",")
									+ "\n");
							bw.flush();
							frequentItemsData.clear();
							frequentItemsData.addAll(saveTempList);
						}
					}
				} catch (Exception ex) {
					logger.error(ex.getMessage());
				}
			}
			digTree(frequentModeBases, nextFrequences);
		}
	}

	public List<List<Integer>> extract(List<TreeNode> list, TreeNode root) {
		List<List<Integer>> returnList = new ArrayList<List<Integer>>();
		for (TreeNode node : list) {
			TreeNode parent = node.getParent();
			if (parent.getCount() != -1) {
				ArrayList<Integer> tranc = new ArrayList<Integer>();
				while (parent.getCount() != -1) {
					tranc.add(parent.getNameCount());
					parent = parent.getParent();
				}
				for (int i = 0; i < node.getCount(); i++) {
					returnList.add(tranc);
				}
			}
		}
		return returnList;
	}

	public TreeNode buildTree(List<List<Integer>> transactions,
			Map<Integer, List<TreeNode>> index,
			final Map<Integer, Integer> sortedMap) {
		TreeNode root = new TreeNode(null, -1, -1);
		for (List<Integer> subjects : transactions) {
			Iterator<Integer> ite = subjects.iterator();
			while (ite.hasNext()) {
				Integer subject = ite.next();
				if (!sortedMap.containsKey(subject)) {
					ite.remove();
				}
			}
			Collections.sort(subjects, new Comparator<Integer>() {
				@Override
				public int compare(Integer o1, Integer o2) {
					int i = sortedMap.get(o2) - sortedMap.get(o1);
					return i != 0 ? i : o1.compareTo(o2);
				}
			});

			TreeNode current = root;
			for (int i = 0; i < subjects.size(); i++) {
				Integer subject = subjects.get(i);
				TreeNode next = current.findChild(subject);
				if (next == null) {
					TreeNode newNode = new TreeNode(current, subject, 1);
					current.addChild(newNode);
					current = newNode;
					List<TreeNode> thisIndex = index.get(subject);
					if (thisIndex == null) {
						thisIndex = new ArrayList<TreeNode>();
						index.put(subject, thisIndex);
					}
					thisIndex.add(newNode);
				} else {
					next.incrementCount(1);
					current = next;
				}
			}
		}
		return root;
	}

	public Map<Integer, Integer> scanAndSort(List<List<Integer>> transactions) {
		Map<Integer, Integer> map = new HashMap<Integer, Integer>();
		// 空的就不扫了
		if (transactions.size() == 0) {
			return map;
		}
		for (List<Integer> basket : transactions) {
			for (Integer subject : basket) {
				Integer count = map.get(subject);
				if (count == null) {
					map.put(subject, 1);
				} else {
					map.put(subject, count + 1);
				}
			}
		}
		Iterator<Entry<Integer, Integer>> ite = map.entrySet().iterator();
		while (ite.hasNext()) {
			Entry<Integer, Integer> entry = ite.next();
			if (entry.getValue() < ABSOLUTE_SUPPORT) {
				ite.remove();
			}
		}
		return map;
	}

	public static String listToString(List list, String reg) {
		// 默认用,符号
		if (reg == null || "".equals(reg)) {
			reg = ",";
		}
		StringBuffer sb = new StringBuffer();
		if (null == list || list.size() == 0) {
			return null;
		}
		for (Iterator iter = list.iterator(); iter.hasNext();) {
			sb.append(iter.next()).append(reg);
		}
		int length = sb.length();
		if (length > 0) {
			sb = sb.delete(length - 1, length);
		}
		return sb.toString();
	}
}

ChineseUTF-8排序算法:

import java.text.Collator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

public class SortChineseKeywords implements Comparator<String> {
	Collator cmp = Collator.getInstance(java.util.Locale.CHINA);

	@Override
	public int compare(String o1, String o2) {
		if (cmp.compare(o1, o2) > 0) {
			return 1;
		} else if (cmp.compare(o1, o2) < 0) {
			return -1;
		}
		return 0;
	}
}

如需改进,非常感谢!

你可能感兴趣的:(Mahout,FP-Growth,压缩数据关联结果集)