之前一直做的都是J2EE,最近开始接触数据挖掘,特做笔记记录一下。第一次写东西,写的不好,望大家谅解。先上一些基础概念,大致了解下决策树这个东西:
• 3、属性是连续值。此时确定一个值作为分裂点split_point,按照>split_point和<=split_point生成两个分支。
决策树算法很多,先说下ID3和C4.5算法。
缺点:
1:ID3算法在选择根节点和内部节点中的分支属性时,采用信息增益作为评价标准。信息增益的缺点是倾向于选择取值较多的属性,在有些情况下这类属性可能不会提供太多有价值的信息。 例如ID字段等。
2:ID3算法只能对描述属性为离散型属性的数据集构造决策树
1:可以处理连续数值型属性,对于离散值,C4.5和ID3的处理方法相同,对于某个属性的值连续时,假设这这个节点上的数据集合样本为total,C4.5算法进行如下处理:
2:用信息增益率(Information Gain Ratio)来选择属性 ,克服了用信息增益来选择属性时偏向选择值多的属性的不足。信息增益率定义为:
其中Gain(S,A)和ID3算法中的信息增益计算相同,而SplitInfo(S,A)代表了按照属性A分裂样本集合S的广度和均匀性。
其中Si表示根据属性A分割S而成的样本子集。
3:后剪枝策略
Decision Tree很容易产生Overfitting,剪枝能够避免树高度无限制增长,避免过度拟合数据。剪枝算法那是相当复杂
4:缺失值处理
对于某些采样数据,可能会缺少属性值。在这种情况下,处理缺少属性值的通常做法是赋予该属性的常见值,或者属性均值。另外一种比较好的方法是为该属性的每个可能值赋予一个概率,即将该属性以概率形式赋值。例如给定Boolean属性B,已知采样数据有12个B=0和88个B=1实例,那么在赋值过程中,B属性的缺失值被赋为B(0)=0.12、B(1)=0.88;所以属性B的缺失值以12%概率被分到False的分支,以88%概率被分到True的分支。这种处理的目的是计算信息增益,使得这种属性值缺失的样本也能处理。
1:算法低效,在构造树的过程中,需要对数据集进行多次的顺序扫描和排序,因而导致算法的低效,尤其是在大量特征属性的数据集中。
2:内存受限,适合于能够驻留于内存的数据集,当训练集大得无法在内存容纳时程序无法运行。
下面用Java代码来实现这两个算法:
抽象类定义一个构建决策树的整体过程,子类只负责具体算法的实现。
@Override public Object build(Data data) { //获取数据集的分类分割集合 Map<Object, List<Instance>> splits = data.getSplits(); //如果只有一个样本,将该样本所属分类作为新样本的分类 if (splits.size() == 1) { return splits.keySet().iterator().next(); } String[] attributes = data.getAttributes(); // 如果没有供决策的属性,则将样本集中具有最多样本的分类作为新样本的分类,即投票选举出最多个数分类 if (attributes.length == 0) { return obtainMaxCategory(splits); } // 选取最优最佳属性信息,交由子类去实现各个算法 BestAttribute bestAttribute = chooseBestAttribute(data); // 决策树根结点,分支属性为选取的分割属性 int bestAttrIndex = bestAttribute.getIndex(); if (bestAttrIndex == -1) { return obtainMaxCategory(splits); } TreeNode treeNode = new TreeNode(attributes[bestAttrIndex]); // 已用过的测试属性不应再次被选为分割属性 String[] subAttributes = new String[attributes.length - 1]; for (int i = 0, j = 0; i < attributes.length; i++) { if (i != bestAttrIndex) { subAttributes[j++] = attributes[i]; } } // 根据分支属性生成分支分裂信息 Map<Object, Map<Object, List<Instance>>> subSplits = bestAttribute.getSplits(); for (Entry<Object, Map<Object, List<Instance>>> entry : subSplits.entrySet()) { Object attrValue = entry.getKey(); Data subData = new Data(subAttributes, entry.getValue()); Object child = build(subData); treeNode.setChild(attrValue, child); } return treeNode; }
public abstract BestAttribute chooseBestAttribute(Data data)
ID3算法的具体实现
@Override public BestAttribute chooseBestAttribute(Data data) { Map<Object, List<Instance>> splits = data.getSplits(); String[] attributes = data.getAttributes(); int optIndex = -1; // 最优属性下标 double minValue = Double.MAX_VALUE; // 最小信息量或说是期望 Map<Object, Map<Object, List<Instance>>> optSplits = null; // 最优分支方案 // 对每一个属性,计算将其作为测试属性的情况下在各分支确定新样本的分类需要的信息量之和,选取最小为最优 for (int attrIndex = 0; attrIndex < attributes.length; attrIndex++) { int allCount = 0; // 统计样本总数的计数器 //按当前属性构建分裂信息:属性值->(分类->样本列表) Map<Object, Map<Object, List<Instance>>> curSplits = new HashMap<Object, Map<Object, List<Instance>>>(); for (Entry<Object, List<Instance>> entry : splits.entrySet()) { Object category = entry.getKey(); List<Instance> instances = entry.getValue(); for (Instance instance : instances) { Object attrValue = instance.getAttribute(attributes[attrIndex]); Map<Object, List<Instance>> split = curSplits.get(attrValue); if (split == null) { split = new HashMap<Object, List<Instance>>(); curSplits.put(attrValue, split); } List<Instance> splitInstances = split.get(category); if (splitInstances == null) { splitInstances = new LinkedList<Instance>(); split.put(category, splitInstances); } splitInstances.add(instance); } allCount += instances.size(); } // 计算将当前属性作为测试属性的情况下在各分支确定新样本的分类需要的信息量之和 double curValue = 0.0; // 计数器:累加各分支 for (Map<Object, List<Instance>> curSplit : curSplits.values()) { double perSplitCount = 0; for (List<Instance> list : curSplit.values()) perSplitCount += list.size(); // 累计当前分支样本数 double perSplitValue = 0.0; // 计数器:当前分支 for (List<Instance> list : curSplit.values()) { double p = list.size() / perSplitCount; perSplitValue -= p * (Math.log(p) / Math.log(2)); } curValue += (perSplitCount / allCount) * perSplitValue; } // 选取最小为最优 if (minValue > curValue) { optIndex = attrIndex; minValue = curValue; optSplits = curSplits; } } return new BestAttribute(optIndex, minValue, optSplits); }
C4.5算法具体实现
@Override public BestAttribute chooseBestAttribute(Data data) { Map<Object, List<Instance>> splits = data.getSplits(); String[] attributes = data.getAttributes(); int optIndex = -1; // 最优属性下标 double maxGainRatio = 0.0; // 最大增益率 Map<Object, Map<Object, List<Instance>>> optSplits = null; // 最优分支方案 // 对每一个属性,计算将其作为测试属性的情况下在各分支确定新样本的分类需要的信息量之和,选取最优 for (int attrIndex = 0; attrIndex < attributes.length; attrIndex++) { int allCount = 0; // 统计样本总数的计数器 // 按当前属性构建Map:属性值->(分类->样本列表) Map<Object, Map<Object, List<Instance>>> curSplits = new HashMap<Object, Map<Object, List<Instance>>>(); for (Entry<Object, List<Instance>> entry : splits.entrySet()) { Object category = entry.getKey(); List<Instance> instances = entry.getValue(); for (Instance instance : instances) { Object attrValue = instance.getAttribute(attributes[attrIndex]); Map<Object, List<Instance>> split = curSplits.get(attrValue); if (split == null) { split = new HashMap<Object, List<Instance>>(); curSplits.put(attrValue, split); } List<Instance> splitInstances = split.get(category); if (splitInstances == null) { splitInstances = new LinkedList<Instance>(); split.put(category, splitInstances); } splitInstances.add(instance); } allCount += instances.size(); } // 计算将当前属性作为测试属性的情况下在各分支确定新样本的分类需要的信息量之和 double curValue = 0.0; // 计数器:累加各分支 double splitInfo = 0.0; //分裂信息 for (Map<Object, List<Instance>> curSplit : curSplits.values()) { double perSplitCount = 0; for (List<Instance> list : curSplit.values()) perSplitCount += list.size(); // 累计当前分支样本数 double perSplitValue = 0.0; // 计数器:当前分支 for (List<Instance> list : curSplit.values()) { double p = list.size() / perSplitCount; perSplitValue -= p * (Math.log(p) / Math.log(2)); } double dj = (perSplitCount / allCount); curValue += dj * perSplitValue; splitInfo -= dj * (Math.log(dj) / Math.log(2)); } double gainRatio = curValue / splitInfo; if (maxGainRatio <= gainRatio) { optIndex = attrIndex; maxGainRatio = gainRatio; optSplits = curSplits; } } return new BestAttribute(optIndex, maxGainRatio, optSplits); }测试数据可以参照以下格式处理:
1 no age:youth income:high student:no credit:fair 2 no age:youth income:high student:no credit:excellent 3 yes age:middle_aged income:high student:no credit:fair 4 yes age:senior income:medium student:no credit:fair 5 yes age:senior income:low student:yes credit:fair 6 no age:senior income:low student:yes credit:excellent 7 yes age:middle_aged income:low student:yes credit:excellent 8 no age:youth income:medium student:no credit:fair 9 yes age:youth income:low student:yes credit:fair 10 yes age:senior income:medium student:yes credit:fair 11 yes age:youth income:medium student:yes credit:excellent 12 yes age:middle_aged income:medium student:no credit:excellent 13 yes age:middle_aged income:high student:yes credit:fair 14 no age:senior income:medium student:no credit:excellent 15 no age:senior income:high student:no credit:fair测试程序:
Builder treeBuilder = new DecisionTreeID3Builder(); //Builder treeBuilder = new DecisionTreeC45Builder(); String trainFilePath = "d:\\trainset.txt"; String testFilePath = "d:\\trainset.txt"; Data data = DataLoader.load(trainFilePath); TreeNode treeNode = (TreeNode) treeBuilder.build(data); Data testData = DataLoader.load(testFilePath); Object[] results = (Object[]) treeNode.classify(testData); ShowUtils.print(results);