预备知识:决策树
初识ID3
回顾决策树的基本知识,其构建过程主要有下述三个重要的问题:
(1)数据是怎么分裂的
(2)如何选择分类的属性
(3)什么时候停止分裂
从上述三个问题出发,以实际的例子对ID3算法进行阐述。
例:通过当天的天气、温度、湿度和季节预测明天的天气
表1 原始数据
当天天气 |
温度 |
湿度 |
季节 |
明天天气 |
晴 |
25 |
50 |
春天 |
晴 |
阴 |
21 |
48 |
春天 |
阴 |
阴 |
18 |
70 |
春天 |
雨 |
晴 |
28 |
41 |
夏天 |
晴 |
雨 |
8 |
65 |
冬天 |
阴 |
晴 |
18 |
43 |
夏天 |
晴 |
阴 |
24 |
56 |
秋天 |
晴 |
雨 |
18 |
76 |
秋天 |
阴 |
雨 |
31 |
61 |
夏天 |
晴 |
阴 |
6 |
43 |
冬天 |
雨 |
晴 |
15 |
55 |
秋天 |
阴 |
雨 |
4 |
58 |
冬天 |
雨 |
1.数据分割
对于离散型数据,直接按照离散数据的取值进行分裂,每一个取值对应一个子节点,以“当前天气”为例对数据进行分割,如图1所示。
对于连续型数据,ID3原本是没有处理能力的,只有通过离散化将连续性数据转化成离散型数据再进行处理。
连续数据离散化是另外一个课题,本文不深入阐述,这里直接采用等距离数据划分的李算话方法。该方法先对数据进行排序,然后将连续型数据划分为多个区间,并使每一个区间的数据量基本相同,以温度为例对数据进行分割,如图2所示。
2. 选择最优分裂属性
ID3采用信息增益作为选择最优的分裂属性的方法,选择熵作为衡量节点纯度的标准,信息增益的计算公式如下:
其中, 表示父节点的熵; 表示节点i的熵,熵越大,节点的信息量越多,越不纯; 表示子节点i的数据量与父节点数据量之比。 越大,表示分裂后的熵越小,子节点变得越纯,分类的效果越好,因此选择 最大的属性作为分裂属性。
对上述的例子的跟节点进行分裂,分别计算每一个属性的信息增益,选择信息增益最大的属性进行分裂。
天气属性:(数据分割如上图1所示)
温度:(数据分割如上图2所示)
湿度:
季节:
由于最大,所以选择属性“季节”作为根节点的分裂属性。
3.停止分裂的条件
停止分裂的条件已经在决策树中阐述,这里不再进行阐述。
(1)最小节点数
当节点的数据量小于一个指定的数量时,不继续分裂。两个原因:一是数据量较少时,再做分裂容易强化噪声数据的作用;二是降低树生长的复杂性。提前结束分裂一定程度上有利于降低过拟合的影响。
(2)熵或者基尼值小于阀值。
由上述可知,熵和基尼值的大小表示数据的复杂程度,当熵或者基尼值过小时,表示数据的纯度比较大,如果熵或者基尼值小于一定程度时,节点停止分裂。
(3)决策树的深度达到指定的条件
节点的深度可以理解为节点与决策树跟节点的距离,如根节点的子节点的深度为1,因为这些节点与跟节点的距离为1,子节点的深度要比父节点的深度大1。决策树的深度是所有叶子节点的最大深度,当深度到达指定的上限大小时,停止分裂。
(4)所有特征已经使用完毕,不能继续进行分裂。
被动式停止分裂的条件,当已经没有可分的属性时,直接将当前节点设置为叶子节点。
程序设计及源代码(C#版本)
(1)数据处理
用二维数组存储原始的数据,每一行表示一条记录,前n-1列表示数据的属性,第n列表示分类的标签。
static double[][] allData;
为了方便后面的处理,对离散属性进行数字化处理,将离散值表示成数字,并用一个链表数组进行存储,数组的第一个元素表示属性1的离散值。
static List<String>[] featureValues;
那么经过处理后的表1数据可以转化为如表2所示的数据:
表2 初始化后的数据
当天天气 |
温度 |
湿度 |
季节 |
明天天气 |
1 |
25 |
50 |
1 |
1 |
2 |
21 |
48 |
1 |
2 |
2 |
18 |
70 |
1 |
3 |
1 |
28 |
41 |
2 |
1 |
3 |
8 |
65 |
3 |
2 |
1 |
18 |
43 |
2 |
1 |
2 |
24 |
56 |
4 |
1 |
3 |
18 |
76 |
4 |
2 |
3 |
31 |
61 |
2 |
1 |
2 |
6 |
43 |
3 |
3 |
1 |
15 |
55 |
4 |
2 |
3 |
4 |
58 |
3 |
3 |
其中,对于当天天气属性,数字{1,2,3}分别表示{晴,阴,雨};对于季节属性{1,2,3,4}分别表示{春天、夏天、冬天、秋天};对于明天天气{1,2,3}分别表示{晴、阴、雨}。
(2)两个类:节点类和分裂信息
a)节点类Node
该类存储了节点的信息,包括节点的数据量、节点选择的分裂属性、节点输出类、子节点的个数、子节点的分类误差等。
1 class Node 2 { 3 /// <summary> 4 /// 各个子节点的取值 5 /// </summary> 6 public List<String> features { get; set; } 7 /// <summary> 8 /// 分裂属性的类型 9 /// </summary> 10 public String feature_Type { get; set; } 11 /// <summary> 12 /// 分裂的属性 13 /// </summary> 14 public String SplitFeature { get; set; } 15 /// <summary> 16 /// 节点对应各个分类的数目 17 /// </summary> 18 public double[] ClassCount { get; set; } 19 /// <summary> 20 /// 各个孩子节点 21 /// </summary> 22 public List<Node> childNodes { get; set; } 23 /// <summary> 24 /// 父亲节点(未用到) 25 /// </summary> 26 public Node Parent { get; set; } 27 /// <summary> 28 /// 占比最大的类别 29 /// </summary> 30 public String finalResult { get; set; } 31 /// <summary> 32 /// 数的深度 33 /// </summary> 34 public int deep { get; set; } 35 /// <summary> 36 /// 该节点占比最大的类标号 37 /// </summary> 38 public int result { get; set; } 39 /// <summary> 40 /// 节点的数量 41 /// </summary> 42 public int rowCount{ get; set; } 43 44 45 public void setClassCount(double[] count) 46 { 47 this.ClassCount = count; 48 double max = ClassCount[0]; 49 int result = 0; 50 for (int i = 1; i < ClassCount.Length; i++) 51 { 52 if (max < ClassCount[i]) 53 { 54 max = ClassCount[i]; 55 result = i; 56 } 57 } 58 //wrong = Convert.ToInt32(nums.Count - ClassCount[result]); 59 this.result = result; 60 } 61 }
b)分裂信息类SplitInfo
该类存储节点进行分裂的信息,包括各个子节点的行坐标、子节点各个类的数目、该节点分裂的属性、属性的类型等。
1 class SplitInfo 2 { 3 /// <summary> 4 /// 分裂的列下标 5 /// </summary> 6 public int splitIndex { get; set; } 7 /// <summary> 8 /// 数据类型 9 /// </summary> 10 public int type { get; set; } 11 /// <summary> 12 /// 分裂属性的取值 13 /// </summary> 14 public List<String> features { get; set; } 15 /// <summary> 16 /// 各个节点的行坐标链表 17 /// </summary> 18 public List<int>[] temp { get; set; } 19 /// <summary> 20 /// 每个节点各类的数目 21 /// </summary> 22 public double[][] class_Count { get; set; } 23 }
(3)节点分裂方法findBestSplit(Node node,List<int> nums,int[] isUsed),该方法对节点进行分裂,返回值Node
其中:
node表示即将进行分裂的节点;
nums表示节点数据对应的行坐标列表;
isUsed表示到该节点位置所有属性的使用情况(1:表示该属性不能再次使用,0:表示该属性可以使用);
findBestSplit主要有以下几个组成部分:
1)节点分裂停止的判定
判断节点是否需要继续分裂,分裂判断条件如上文所述。源代码如下:
1 public static Object[] ifEnd(Node node, double entropy,int[] isUsed) 2 { 3 try 4 { 5 double[] count = node.ClassCount; 6 int rowCount = node.rowCount; 7 int maxResult = 0; 8 double maxRate = 0; 9 #region 数达到某一深度 10 int deep = node.deep; 11 if (deep >= maxDeep) 12 { 13 maxResult = node.result + 1; 14 node.feature_Type=("result"); 15 node.features=(new List<String>() { maxResult + "" }); 16 return new Object[] { true, node }; 17 } 18 #endregion 19 #region 纯度(其实跟后面的有点重了,记得要修改) 20 //maxResult = 1; 21 //for (int i = 1; i < count.Length; i++) 22 //{ 23 // if (count[i] / rowCount >= 0.95) 24 // { 25 // node.setFeatureType("result"); 26 // node.setFeatures(new List<String> { "" + (i + 1) }); 27 // return new Object[] { true, node }; 28 // } 29 //} 30 //node.setLeafWrong(rowCount - Convert.ToInt32(count[maxResult - 1])); 31 #endregion 32 #region 熵为0 33 if (entropy == 0) 34 { 35 maxRate = count[0] / rowCount; 36 maxResult = 1; 37 for (int i = 1; i < count.Length; i++) 38 { 39 if (count[i] / rowCount >= maxRate) 40 { 41 maxRate = count[i] / rowCount; 42 maxResult = i + 1; 43 } 44 } 45 node.feature_Type=("result"); 46 node.features=(new List<String> { maxResult + "" }); 47 return new Object[] { true, node }; 48 } 49 #endregion 50 #region 属性已经分完 51 //int[] isUsed = node.; 52 bool flag = true; 53 for (int i = 0; i < isUsed.Length - 1; i++) 54 { 55 if (isUsed[i] == 0) 56 { 57 flag = false; 58 break; 59 } 60 } 61 if (flag) 62 { 63 maxRate = count[0] / rowCount; 64 maxResult = 1; 65 for (int i = 1; i < count.Length; i++) 66 { 67 if (count[i] / rowCount >= maxRate) 68 { 69 maxRate = count[i] / rowCount; 70 maxResult = i + 1; 71 } 72 } 73 node.feature_Type=("result"); 74 node.features=(new List<String> { "" + (maxResult) }); 75 return new Object[] { true, node }; 76 } 77 #endregion 78 #region 数据量少于100 79 if (rowCount < Limit_Node) 80 { 81 maxRate = count[0] / rowCount; 82 maxResult = 1; 83 for (int i = 1; i < count.Length; i++) 84 { 85 if (count[i] / rowCount >= maxRate) 86 { 87 maxRate = count[i] / rowCount; 88 maxResult = i + 1; 89 } 90 } 91 node.feature_Type=("result"); 92 node.features=(new List<String> { "" + (maxResult) }); 93 return new Object[] { true, node }; 94 } 95 #endregion 96 return new Object[] { false, node }; 97 } 98 catch (Exception e) 99 { 100 return new Object[] { false, node }; 101 } 102 }
2)寻找最优的分裂属性
寻找最优的分裂属性需要计算每一个分裂属性分裂后的信息增益,计算公式上文已给出,其中熵的计算代码如下:
1 public static double CalEntropy(double[] counts, int countAll) 2 { 3 try 4 { 5 double allShang = 0; 6 for (int i = 0; i < counts.Length; i++) 7 { 8 if (counts[i] == 0) 9 { 10 continue; 11 } 12 double rate = counts[i] / countAll; 13 allShang = allShang + rate * Math.Log(rate, 2); 14 } 15 return -allShang; 16 } 17 catch (Exception e) 18 { 19 return 0; 20 } 21 }
3)进行分裂,同时子节点也执行相同的分类步骤
其实就是递归的过程,对每一个子节点执行findBestSplit方法进行分裂。
全部源代码:
1 #region ID3核心算法 2 /// <summary> 3 /// 测试 4 /// </summary> 5 /// <param name="node"></param> 6 /// <param name="data"></param> 7 public static String findResult(Node node, String[] data) 8 { 9 List<String> featrues = node.features; 10 String type = node.feature_Type; 11 if (type == "result") 12 { 13 return featrues[0]; 14 } 15 int split = Convert.ToInt32(node.SplitFeature); 16 List<Node> childNodes = node.childNodes; 17 double[] resultCount = node.ClassCount; 18 if (type == "连续") 19 { 20 21 22 for (int i = 0; i < featrues.Count; i++) 23 { 24 double value = Convert.ToDouble(featrues[i]); 25 if (Convert.ToDouble(data[split]) <= value) 26 { 27 return findResult(childNodes[i], data); 28 } 29 } 30 return findResult(childNodes[featrues.Count], data); 31 } 32 else 33 { 34 for (int i = 0; i < featrues.Count; i++) 35 { 36 if (data[split] == featrues[i]) 37 { 38 return findResult(childNodes[i], data); 39 } 40 if (i == featrues.Count - 1) 41 { 42 double count = resultCount[0]; 43 int maxInt = 0; 44 for (int j = 1; j < resultCount.Length; j++) 45 { 46 if (count < resultCount[j]) 47 { 48 count = resultCount[j]; 49 maxInt = j; 50 } 51 } 52 return findResult(childNodes[0], data); 53 } 54 } 55 } 56 return null; 57 } 58 /// <summary> 59 /// 判断是否还需要分裂 60 /// </summary> 61 /// <param name="node"></param> 62 /// <returns></returns> 63 public static Object[] ifEnd(Node node, double entropy,int[] isUsed) 64 { 65 try 66 { 67 double[] count = node.ClassCount; 68 int rowCount = node.rowCount; 69 int maxResult = 0; 70 double maxRate = 0; 71 #region 数达到某一深度 72 int deep = node.deep; 73 if (deep >= maxDeep) 74 { 75 maxResult = node.result + 1; 76 node.feature_Type=("result"); 77 node.features=(new List<String>() { maxResult + "" }); 78 return new Object[] { true, node }; 79 } 80 #endregion 81 #region 纯度(其实跟后面的有点重了,记得要修改) 82 //maxResult = 1; 83 //for (int i = 1; i < count.Length; i++) 84 //{ 85 // if (count[i] / rowCount >= 0.95) 86 // { 87 // node.setFeatureType("result"); 88 // node.setFeatures(new List<String> { "" + (i + 1) }); 89 // return new Object[] { true, node }; 90 // } 91 //} 92 //node.setLeafWrong(rowCount - Convert.ToInt32(count[maxResult - 1])); 93 #endregion 94 #region 熵为0 95 if (entropy == 0) 96 { 97 maxRate = count[0] / rowCount; 98 maxResult = 1; 99 for (int i = 1; i < count.Length; i++) 100 { 101 if (count[i] / rowCount >= maxRate) 102 { 103 maxRate = count[i] / rowCount; 104 maxResult = i + 1; 105 } 106 } 107 node.feature_Type=("result"); 108 node.features=(new List<String> { maxResult + "" }); 109 return new Object[] { true, node }; 110 } 111 #endregion 112 #region 属性已经分完 113 //int[] isUsed = node.; 114 bool flag = true; 115 for (int i = 0; i < isUsed.Length - 1; i++) 116 { 117 if (isUsed[i] == 0) 118 { 119 flag = false; 120 break; 121 } 122 } 123 if (flag) 124 { 125 maxRate = count[0] / rowCount; 126 maxResult = 1; 127 for (int i = 1; i < count.Length; i++) 128 { 129 if (count[i] / rowCount >= maxRate) 130 { 131 maxRate = count[i] / rowCount; 132 maxResult = i + 1; 133 } 134 } 135 node.feature_Type=("result"); 136 node.features=(new List<String> { "" + (maxResult) }); 137 return new Object[] { true, node }; 138 } 139 #endregion 140 #region 数据量少于100 141 if (rowCount < Limit_Node) 142 { 143 maxRate = count[0] / rowCount; 144 maxResult = 1; 145 for (int i = 1; i < count.Length; i++) 146 { 147 if (count[i] / rowCount >= maxRate) 148 { 149 maxRate = count[i] / rowCount; 150 maxResult = i + 1; 151 } 152 } 153 node.feature_Type=("result"); 154 node.features=(new List<String> { "" + (maxResult) }); 155 return new Object[] { true, node }; 156 } 157 #endregion 158 return new Object[] { false, node }; 159 } 160 catch (Exception e) 161 { 162 return new Object[] { false, node }; 163 } 164 } 165 #region 排序算法 166 public static void InsertSort(double[] values, List<int> arr, int StartIndex, int endIndex) 167 { 168 for (int i = StartIndex + 1; i <= endIndex; i++) 169 { 170 int key = arr[i]; 171 double init = values[i]; 172 int j = i - 1; 173 while (j >= StartIndex && values[j] > init) 174 { 175 arr[j + 1] = arr[j]; 176 values[j + 1] = values[j]; 177 j--; 178 } 179 arr[j + 1] = key; 180 values[j + 1] = init; 181 } 182 } 183 static int SelectPivotMedianOfThree(double[] values, List<int> arr, int low, int high) 184 { 185 int mid = low + ((high - low) >> 1);//计算数组中间的元素的下标 186 187 //使用三数取中法选择枢轴 188 if (values[mid] > values[high])//目标: arr[mid] <= arr[high] 189 { 190 swap(values, arr, mid, high); 191 } 192 if (values[low] > values[high])//目标: arr[low] <= arr[high] 193 { 194 swap(values, arr, low, high); 195 } 196 if (values[mid] > values[low]) //目标: arr[low] >= arr[mid] 197 { 198 swap(values, arr, mid, low); 199 } 200 //此时,arr[mid] <= arr[low] <= arr[high] 201 return low; 202 //low的位置上保存这三个位置中间的值 203 //分割时可以直接使用low位置的元素作为枢轴,而不用改变分割函数了 204 } 205 static void swap(double[] values, List<int> arr, int t1, int t2) 206 { 207 double temp = values[t1]; 208 values[t1] = values[t2]; 209 values[t2] = temp; 210 int key = arr[t1]; 211 arr[t1] = arr[t2]; 212 arr[t2] = key; 213 } 214 static void QSort(double[] values, List<int> arr, int low, int high) 215 { 216 int first = low; 217 int last = high; 218 219 int left = low; 220 int right = high; 221 222 int leftLen = 0; 223 int rightLen = 0; 224 225 if (high - low + 1 < 10) 226 { 227 InsertSort(values, arr, low, high); 228 return; 229 } 230 231 //一次分割 232 int key = SelectPivotMedianOfThree(values, arr, low, high);//使用三数取中法选择枢轴 233 double inti = values[key]; 234 int currentKey = arr[key]; 235 236 while (low < high) 237 { 238 while (high > low && values[high] >= inti) 239 { 240 if (values[high] == inti)//处理相等元素 241 { 242 swap(values, arr, right, high); 243 right--; 244 rightLen++; 245 } 246 high--; 247 } 248 arr[low] = arr[high]; 249 values[low] = values[high]; 250 while (high > low && values[low] <= inti) 251 { 252 if (values[low] == inti) 253 { 254 swap(values, arr, left, low); 255 left++; 256 leftLen++; 257 } 258 low++; 259 } 260 arr[high] = arr[low]; 261 values[high] = values[low]; 262 } 263 arr[low] = currentKey; 264 values[low] = values[key]; 265 //一次快排结束 266 //把与枢轴key相同的元素移到枢轴最终位置周围 267 int i = low - 1; 268 int j = first; 269 while (j < left && values[i] != inti) 270 { 271 swap(values, arr, i, j); 272 i--; 273 j++; 274 } 275 i = low + 1; 276 j = last; 277 while (j > right && values[i] != inti) 278 { 279 swap(values, arr, i, j); 280 i++; 281 j--; 282 } 283 QSort(values, arr, first, low - 1 - leftLen); 284 QSort(values, arr, low + 1 + rightLen, last); 285 } 286 #endregion 287 /// <summary> 288 /// 寻找最佳的分裂点 289 /// </summary> 290 /// <param name="num"></param> 291 /// <param name="node"></param> 292 public static Node findBestSplit(Node node, int lastCol,List<int> nums,int[] isUsed) 293 { 294 try 295 { 296 //判断是否继续分裂 297 double totalShang = CalEntropy(node.ClassCount, nums.Count); 298 Object[] check = ifEnd(node, totalShang, isUsed); 299 if ((bool)check[0]) 300 { 301 node = (Node)check[1]; 302 return node; 303 } 304 #region 变量声明 305 SplitInfo info = new SplitInfo(); 306 //int[] isUsed = node.getUsed(); //连续变量or离散变量 307 //List<int> nums = node.getNum(); //样本的标号 308 int RowCount = nums.Count; //样本总数 309 double jubuMax = 0; //局部最大熵 310 #endregion 311 for (int i = 0; i < isUsed.Length - 1; i++) 312 { 313 if (isUsed[i] == 1) 314 { 315 continue; 316 } 317 #region 离散变量 318 if (type[i] == 0) 319 { 320 int[] allFeatureCount = new int[0]; //所有类别的数量 321 double[][] allCount = new double[allNum[i]][]; 322 for (int j = 0; j < allCount.Length; j++) 323 { 324 allCount[j] = new double[classCount]; 325 } 326 int[] countAllFeature = new int[allNum[i]]; 327 List<int>[] temp = new List<int>[allNum[i]]; 328 for (int j = 0; j < temp.Length; j++) 329 { 330 temp[j] = new List<int>(); 331 } 332 for (int j = 0; j < nums.Count; j++) 333 { 334 int index = Convert.ToInt32(allData[nums[j]][i]); 335 temp[index - 1].Add(nums[j]); 336 countAllFeature[index - 1]++; 337 allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++; 338 } 339 double allShang = 0; 340 for (int j = 0; j < allCount.Length; j++) 341 { 342 allShang = allShang + CalEntropy(allCount[j], countAllFeature[j]) * countAllFeature[j] / RowCount; 343 } 344 allShang = (totalShang - allShang); 345 if (allShang > jubuMax) 346 { 347 info.features=new List<String>(); 348 info.type=0; 349 info.temp=(temp); 350 info.splitIndex=(i); 351 info.class_Count=(allCount); 352 jubuMax = allShang; 353 allFeatureCount = countAllFeature; 354 } 355 } 356 #endregion 357 #region 连续变量 358 else 359 { 360 double[] leftCount = new double[classCount]; //做节点各个类别的数量 361 double[] rightCount = new double[classCount]; //右节点各个类别的数量 362 double[] values = new double[nums.Count]; 363 List<String> List_Feature = new List<string>(); 364 for (int j = 0; j < values.Length; j++) 365 { 366 values[j] = allData[nums[j]][i]; 367 } 368 QSort(values, nums, 0, nums.Count - 1); 369 int eachNum = nums.Count / 5; 370 double lianxuMax = 0; //连续型属性的最大熵 371 int index = 1; 372 double[][] counts = new double[5][]; 373 List<int>[] temp = new List<int>[5]; 374 for (int j = 0; j < 5; j++) 375 { 376 counts[j] = new double[classCount]; 377 temp[j] = new List<int>(); 378 } 379 for (int j = 0; j < nums.Count - 1; j++) 380 { 381 if (j >= index * eachNum&&index<5) 382 { 383 List_Feature.Add(allData[nums[j]][i]+""); 384 lianxuMax += eachNum*CalEntropy(counts[index - 1], eachNum)/RowCount; 385 index++; 386 } 387 temp[index-1].Add(nums[j]); 388 counts[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1])-1]++; 389 } 390 lianxuMax += ((eachNum + nums.Count % 5)*CalEntropy(counts[index - 1], eachNum + nums.Count % 5) / RowCount); 391 lianxuMax = totalShang - lianxuMax; 392 if (lianxuMax > jubuMax) 393 { 394 info.splitIndex=(i); 395 info.features=(List_Feature); 396 info.type=(1); 397 jubuMax = lianxuMax; 398 info.temp=(temp); 399 info.class_Count=(counts); 400 } 401 } 402 #endregion 403 } 404 #region 如何找不到最佳的分裂属性,则设为叶节点 405 if (info.splitIndex == -1) 406 { 407 double[] finalCount = node.ClassCount; 408 double max = finalCount[0]; 409 int result = 1; 410 for (int i = 1; i < finalCount.Length; i++) 411 { 412 if (finalCount[i] > max) 413 { 414 max = finalCount[i]; 415 result = (i + 1); 416 } 417 } 418 node.feature_Type=("result"); 419 node.features=(new List<String> { "" + result }); 420 return node; 421 } 422 #endregion 423 int deep = node.deep; 424 #region 分裂 425 node.SplitFeature=("" + info.splitIndex); 426 427 List<Node> childNode = new List<Node>(); 428 int[] used = new int[isUsed.Length]; 429 for (int i = 0; i < used.Length; i++) 430 { 431 used[i] = isUsed[i]; 432 } 433 if (info.type == 0) 434 { 435 used[info.splitIndex] = 1; 436 node.feature_Type=("离散"); 437 } 438 else 439 { 440 used[info.splitIndex] = 0; 441 node.feature_Type=("连续"); 442 } 443 int sumLeaf = 0; 444 int sumWrong = 0; 445 List<int>[] rowIndex = info.temp; 446 List<String> features = info.features; 447 for (int j = 0; j < rowIndex.Length; j++) 448 { 449 if (rowIndex[j].Count == 0) 450 { 451 continue; 452 } 453 if (info.type == 0) 454 features.Add(""+(j+1)); 455 Node node1 = new Node(); 456 //node1.setNum(info.getTemp()[j]); 457 node1.setClassCount(info.class_Count[j]); 458 //node1.setUsed(used); 459 node1.deep=(deep + 1); 460 node1.rowCount = info.temp[j].Count; 461 node1 = findBestSplit(node1, info.splitIndex,info.temp[j], used); 462 childNode.Add(node1); 463 } 464 node.features=(features); 465 node.childNodes=(childNode); 466 467 #endregion 468 return node; 469 } 470 catch (Exception e) 471 { 472 Console.WriteLine(e.StackTrace); 473 return node; 474 } 475 } 476 /// <summary> 477 /// 计算熵 478 /// </summary> 479 /// <param name="counts"></param> 480 /// <param name="countAll"></param> 481 /// <returns></returns> 482 public static double CalEntropy(double[] counts, int countAll) 483 { 484 try 485 { 486 double allShang = 0; 487 for (int i = 0; i < counts.Length; i++) 488 { 489 if (counts[i] == 0) 490 { 491 continue; 492 } 493 double rate = counts[i] / countAll; 494 allShang = allShang + rate * Math.Log(rate, 2); 495 } 496 return -allShang; 497 } 498 catch (Exception e) 499 { 500 return 0; 501 } 502 } 503 #endregion
(注:上述代码只是ID3的核心代码,数据预处理的代码并没有给出,只要将预处理后的数据输入到主方法findBestSplit中,就可以得到最终的结果)
总结
ID3是基本的决策树构建算法,作为决策树经典的构建算法,其具有结构简单、清晰易懂的特点。虽然ID3比较灵活方便,但是有以下几个缺点:
(1)采用信息增益进行分裂,分裂的精确度可能没有采用信息增益率进行分裂高
(2)不能处理连续型数据,只能通过离散化将连续性数据转化为离散型数据
(3)不能处理缺省值
(4)没有对决策树进行剪枝处理,很可能会出现过拟合的问题