数据挖掘笔记-分类-决策树-ID3和C4.5

之前一直做的都是J2EE,最近开始接触数据挖掘,特做笔记记录一下。第一次写东西,写的不好,望大家谅解。先上一些基础概念,大致了解下决策树这个东西:

决策树(decision tree)是一个树结构(可以是二叉树或非二叉树)。其每个非叶节点表示一个特征属性上的测试,每个分支代表这个特征属性在某个值域上的输出,而每个叶节点存放一个类别。使用决策树进行决策的过程就是从根节点开始,测试待分类项中相应的特征属性,并按照其值选择输出分支,直到到达叶子节点,将叶子节点存放的类别作为决策结果。
• 决策树的构造:
 决策树的构造就是通过属性选择来确定拓扑结构。
 构造决策树的关键步骤是分裂属性。所谓分裂属性就是在某个节点处按照某一特征属性的不同划分构造不同的分支,其目标是让各个分裂子集尽可能地“纯”。尽可能“纯”就是尽量让一个分裂子集中待分类项属于同一类别。分裂属性分为三种不同的情况:
 1、属性是离散值且不要求生成二叉决策树。此时用属性的每一个划分作为一个分支。
 2、属性是离散值且要求生成二叉决策树。此时使用属性划分的一个子集进行测试,按照“属于此子集”和“不属于此子集”分成两个分支。

 3、属性是连续值。此时确定一个值作为分裂点split_point,按照>split_point<=split_point生成两个分支。

决策树算法很多,先说下ID3和C4.5算法。

ID3
ID3算法使用信息增益(IG)或者熵(Entropy)值来确定使用哪个属性进行判定,作为最佳属性。

计算熵的公式为:

信息增益的公式如下:

缺点: 

1:ID3算法在选择根节点和内部节点中的分支属性时,采用信息增益作为评价标准。信息增益的缺点是倾向于选择取值较多的属性,在有些情况下这类属性可能不会提供太多有价值的信息。 例如ID字段等。

2:ID3算法只能对描述属性为离散型属性的数据集构造决策树 

 
C4.5
C4.5算法采用信息增益率作为选择分支属性的标准,并克服了ID3算法中信息增益选择属性时偏向选择取值多的属性的不足,并能够完成对连续属性离散化是处理;还能够对不完整数据进行处理。C4.5算法属于基于信息论Information Theory的方法,以信息论为基础,以信息熵和信息增益度为衡量标准,从而实现对数据的归纳和分类。 
改进:

1:可以处理连续数值型属性,对于离散值,C4.5和ID3的处理方法相同,对于某个属性的值连续时,假设这这个节点上的数据集合样本为total,C4.5算法进行如下处理: 

  • 将样本数据该属性A上的具体数值按照升序排列,得到属性序列值:{A1,A2,A3,...,Atotal}
  • 在上一步生成的序列值中生成total-1个分割点。第i个分割点的取值为Ai和Ai+1的均值,每个分割点都将属性序列划分为两个子集。
  • 计算每个分割点的信息增益(Information Gain),得到total-1个信息增益。
  • 对分裂点的信息增益进行修正:减去log2(N-1)/|D|,其中N为可能的分裂点个数,D为数据集合大小。
  • 选择修正后的信息增益值最大的分类点作为该属性的最佳分类点
  • 计算最佳分裂点的信息增益率(Gain Ratio)作为该属性的Gain Ratio
  • 选择Gain Ratio最大的属性作为分类属性。

2:用信息增益率(Information Gain Ratio)来选择属性 ,克服了用信息增益来选择属性时偏向选择值多的属性的不足。信息增益率定义为: 

其中Gain(S,A)和ID3算法中的信息增益计算相同,而SplitInfo(S,A)代表了按照属性A分裂样本集合S的广度和均匀性。

其中Si表示根据属性A分割S而成的样本子集。

3:后剪枝策略 

Decision Tree很容易产生Overfitting,剪枝能够避免树高度无限制增长,避免过度拟合数据。剪枝算法那是相当复杂

4:缺失值处理 

对于某些采样数据,可能会缺少属性值。在这种情况下,处理缺少属性值的通常做法是赋予该属性的常见值,或者属性均值。另外一种比较好的方法是为该属性的每个可能值赋予一个概率,即将该属性以概率形式赋值。例如给定Boolean属性B,已知采样数据有12个B=0和88个B=1实例,那么在赋值过程中,B属性的缺失值被赋为B(0)=0.12、B(1)=0.88;所以属性B的缺失值以12%概率被分到False的分支,以88%概率被分到True的分支。这种处理的目的是计算信息增益,使得这种属性值缺失的样本也能处理。


缺点: 

1:算法低效,在构造树的过程中,需要对数据集进行多次的顺序扫描和排序,因而导致算法的低效,尤其是在大量特征属性的数据集中。

2:内存受限,适合于能够驻留于内存的数据集,当训练集大得无法在内存容纳时程序无法运行。

 

下面用Java代码来实现这两个算法:

抽象类定义一个构建决策树的整体过程,子类只负责具体算法的实现。

@Override
public Object build(Data data) {
	//获取数据集的分类分割集合
	Map<Object, List<Instance>> splits = data.getSplits();
	//如果只有一个样本,将该样本所属分类作为新样本的分类
	if (splits.size() == 1) {
		return splits.keySet().iterator().next();
	}
	String[] attributes = data.getAttributes();
	// 如果没有供决策的属性,则将样本集中具有最多样本的分类作为新样本的分类,即投票选举出最多个数分类
	if (attributes.length == 0) {
		return obtainMaxCategory(splits);
	}
	// 选取最优最佳属性信息,交由子类去实现各个算法
	BestAttribute bestAttribute = chooseBestAttribute(data);
	// 决策树根结点,分支属性为选取的分割属性
	int bestAttrIndex = bestAttribute.getIndex();
	if (bestAttrIndex == -1) {
		return obtainMaxCategory(splits);
	}
	TreeNode treeNode = new TreeNode(attributes[bestAttrIndex]);
	// 已用过的测试属性不应再次被选为分割属性
	String[] subAttributes = new String[attributes.length - 1];
	for (int i = 0, j = 0; i < attributes.length; i++) {
		if (i != bestAttrIndex) {
			subAttributes[j++] = attributes[i];
		}
	}
	// 根据分支属性生成分支分裂信息
	Map<Object, Map<Object, List<Instance>>> subSplits = bestAttribute.getSplits();
	for (Entry<Object, Map<Object, List<Instance>>> entry : subSplits.entrySet()) {
		Object attrValue = entry.getKey();
		Data subData = new Data(subAttributes, entry.getValue());
		Object child = build(subData);
		treeNode.setChild(attrValue, child);
	}
	return treeNode;
}
public abstract BestAttribute chooseBestAttribute(Data data)

ID3算法的具体实现

@Override
public BestAttribute chooseBestAttribute(Data data) {
	Map<Object, List<Instance>> splits = data.getSplits();
	String[] attributes = data.getAttributes();
	int optIndex = -1; // 最优属性下标
	double minValue = Double.MAX_VALUE; // 最小信息量或说是期望
	Map<Object, Map<Object, List<Instance>>> optSplits = null; // 最优分支方案
	// 对每一个属性,计算将其作为测试属性的情况下在各分支确定新样本的分类需要的信息量之和,选取最小为最优
	for (int attrIndex = 0; attrIndex < attributes.length; attrIndex++) {
		int allCount = 0; // 统计样本总数的计数器
		//按当前属性构建分裂信息:属性值->(分类->样本列表)
		Map<Object, Map<Object, List<Instance>>> curSplits = 
				new HashMap<Object, Map<Object, List<Instance>>>();
		for (Entry<Object, List<Instance>> entry : splits.entrySet()) {
			Object category = entry.getKey();
			List<Instance> instances = entry.getValue();
			for (Instance instance : instances) {
				Object attrValue = instance.getAttribute(attributes[attrIndex]);
				Map<Object, List<Instance>> split = curSplits.get(attrValue);
				if (split == null) {
					split = new HashMap<Object, List<Instance>>();
					curSplits.put(attrValue, split);
				}
				List<Instance> splitInstances = split.get(category);
				if (splitInstances == null) {
					splitInstances = new LinkedList<Instance>();
					split.put(category, splitInstances);
				}
				splitInstances.add(instance);
			}
			allCount += instances.size();
		}
		// 计算将当前属性作为测试属性的情况下在各分支确定新样本的分类需要的信息量之和
		double curValue = 0.0; // 计数器:累加各分支
		for (Map<Object, List<Instance>> curSplit : curSplits.values()) {
			double perSplitCount = 0;
			for (List<Instance> list : curSplit.values())
				perSplitCount += list.size();
			// 累计当前分支样本数
			double perSplitValue = 0.0; // 计数器:当前分支
			for (List<Instance> list : curSplit.values()) {
				double p = list.size() / perSplitCount;
				perSplitValue -= p * (Math.log(p) / Math.log(2));
			}
			curValue += (perSplitCount / allCount) * perSplitValue;
		}
		// 选取最小为最优
		if (minValue > curValue) {
			optIndex = attrIndex;
			minValue = curValue;
			optSplits = curSplits;
		}
	}
	return new BestAttribute(optIndex, minValue, optSplits);
}

C4.5算法具体实现

@Override
public BestAttribute chooseBestAttribute(Data data) {
	Map<Object, List<Instance>> splits = data.getSplits();
	String[] attributes = data.getAttributes();
	int optIndex = -1; // 最优属性下标
	double maxGainRatio = 0.0; // 最大增益率
	Map<Object, Map<Object, List<Instance>>> optSplits = null; // 最优分支方案
	// 对每一个属性,计算将其作为测试属性的情况下在各分支确定新样本的分类需要的信息量之和,选取最优
	for (int attrIndex = 0; attrIndex < attributes.length; attrIndex++) {
		int allCount = 0; // 统计样本总数的计数器
		// 按当前属性构建Map:属性值->(分类->样本列表)
		Map<Object, Map<Object, List<Instance>>> curSplits = 
				new HashMap<Object, Map<Object, List<Instance>>>();
		for (Entry<Object, List<Instance>> entry : splits.entrySet()) {
			Object category = entry.getKey();
			List<Instance> instances = entry.getValue();
			for (Instance instance : instances) {
				Object attrValue = instance.getAttribute(attributes[attrIndex]);
				Map<Object, List<Instance>> split = curSplits.get(attrValue);
				if (split == null) {
					split = new HashMap<Object, List<Instance>>();
					curSplits.put(attrValue, split);
				}
				List<Instance> splitInstances = split.get(category);
				if (splitInstances == null) {
					splitInstances = new LinkedList<Instance>();
					split.put(category, splitInstances);
				}
				splitInstances.add(instance);
			}
			allCount += instances.size();
		}
		// 计算将当前属性作为测试属性的情况下在各分支确定新样本的分类需要的信息量之和
		double curValue = 0.0; // 计数器:累加各分支
		double splitInfo = 0.0; //分裂信息
		for (Map<Object, List<Instance>> curSplit : curSplits.values()) {
			double perSplitCount = 0;
			for (List<Instance> list : curSplit.values())
				perSplitCount += list.size();
			// 累计当前分支样本数
			double perSplitValue = 0.0; // 计数器:当前分支
			for (List<Instance> list : curSplit.values()) {
				double p = list.size() / perSplitCount;
				perSplitValue -= p * (Math.log(p) / Math.log(2));
			}
			double dj = (perSplitCount / allCount);
			curValue += dj * perSplitValue;
			splitInfo -= dj * (Math.log(dj) / Math.log(2));
		}
		double gainRatio = curValue / splitInfo;
		if (maxGainRatio <= gainRatio) {
			optIndex = attrIndex;
			maxGainRatio = gainRatio;
			optSplits = curSplits;
		}
	}
	return new BestAttribute(optIndex, maxGainRatio, optSplits);
}
测试数据可以参照以下格式处理:
1 no age:youth income:high student:no credit:fair  
2 no age:youth income:high student:no credit:excellent  
3 yes age:middle_aged income:high student:no credit:fair  
4 yes age:senior income:medium student:no credit:fair  
5 yes age:senior income:low student:yes credit:fair  
6 no age:senior income:low student:yes credit:excellent  
7 yes age:middle_aged income:low student:yes credit:excellent
8 no age:youth income:medium student:no credit:fair  
9 yes age:youth income:low student:yes credit:fair  
10 yes age:senior income:medium student:yes credit:fair  
11 yes age:youth income:medium student:yes credit:excellent  
12 yes age:middle_aged income:medium student:no credit:excellent  
13 yes age:middle_aged income:high student:yes credit:fair  
14 no age:senior income:medium student:no credit:excellent 
15 no age:senior income:high student:no credit:fair 
测试程序:
Builder treeBuilder = new DecisionTreeID3Builder();
//Builder treeBuilder = new DecisionTreeC45Builder();
String trainFilePath = "d:\\trainset.txt";
String testFilePath = "d:\\trainset.txt";
Data data = DataLoader.load(trainFilePath);
TreeNode treeNode = (TreeNode) treeBuilder.build(data);
Data testData = DataLoader.load(testFilePath);
Object[] results = (Object[]) treeNode.classify(testData);
ShowUtils.print(results);

代码托管:https://github.com/fighting-one-piece/repository-datamining.git



你可能感兴趣的:(数据挖掘,分类,决策树,id3,C4.5)