决策树系列(五)——CART

CART,又名分类回归树,是在ID3的基础上进行优化的决策树,学习CART记住以下几个关键点:

(1)CART既能是分类树,又能是分类树;

(2)当CART是分类树时,采用GINI值作为节点分裂的依据;当CART是回归树时,采用样本的最小方差作为节点分裂的依据;

(3)CART是一棵二叉树。

接下来将以一个实际的例子对CART进行介绍:

                                                                    表1 原始数据表

看电视时间

婚姻情况

职业

年龄

3

未婚

学生

12

4

未婚

学生

18

2

已婚

老师

26

5

已婚

上班族

47

2.5

已婚

上班族

36

3.5

未婚

老师

29

4

已婚

学生

21

从以下的思路理解CART

分类树?回归树?

      分类树的作用是通过一个对象的特征来预测该对象所属的类别,而回归树的目的是根据一个对象的信息预测该对象的属性,并以数值表示。

      CART既能是分类树,又能是决策树,如上表所示,如果我们想预测一个人是否已婚,那么构建的CART将是分类树;如果想预测一个人的年龄,那么构建的将是回归树。

分类树和回归树是怎么做决策的?假设我们构建了两棵决策树分别预测用户是否已婚和实际的年龄,如图1和图2所示:

决策树系列(五)——CART_第1张图片

                                      图1 预测婚姻情况决策树                                               图2 预测年龄的决策树

       图1表示一棵分类树,其叶子节点的输出结果为一个实际的类别,在这个例子里是婚姻的情况(已婚或者未婚),选择叶子节点中数量占比最大的类别作为输出的类别;

       图2是一棵回归树,预测用户的实际年龄,是一个具体的输出值。怎样得到这个输出值?一般情况下选择使用中值、平均值或者众数进行表示,图2使用节点年龄数据的平均值作为输出值。

CART如何选择分裂的属性?

      分裂的目的是为了能够让数据变纯,使决策树输出的结果更接近真实值。那么CART是如何评价节点的纯度呢?如果是分类树,CART采用GINI值衡量节点纯度;如果是回归树,采用样本方差衡量节点纯度。节点越不纯,节点分类或者预测的效果就越差。

GINI值的计算公式:

                               

      节点越不纯,GINI值越大。以二分类为例,如果节点的所有数据只有一个类别,则 ,如果两类数量相同,则

回归方差计算公式:

                                                                               

      方差越大,表示该节点的数据越分散,预测的效果就越差。如果一个节点的所有数据都相同,那么方差就为0,此时可以很肯定得认为该节点的输出值;如果节点的数据相差很大,那么输出的值有很大的可能与实际值相差较大。

      因此,无论是分类树还是回归树,CART都要选择使子节点的GINI值或者回归方差最小的属性作为分裂的方案。即最小化(分类树):

                               

或者(回归树):

                                                                                                     

CART如何分裂成一棵二叉树?

     节点的分裂分为两种情况,连续型的数据和离散型的数据。

     CART对连续型属性的处理与C4.5差不多,通过最小化分裂后的GINI值或者样本方差寻找最优分割点,将节点一分为二,在这里不再叙述,详细请看C4.5

     对于离散型属性,理论上有多少个离散值就应该分裂成多少个节点。但CART是一棵二叉树,每一次分裂只会产生两个节点,怎么办呢?很简单,只要将其中一个离散值独立作为一个节点,其他的离散值生成另外一个节点即可。这种分裂方案有多少个离散值就有多少种划分的方法,举一个简单的例子:如果某离散属性一个有三个离散值X,Y,Z,则该属性的分裂方法有{X}、{Y,Z},{Y}、{X,Z},{Z}、{X,Y},分别计算每种划分方法的基尼值或者样本方差确定最优的方法。

     以属性“职业”为例,一共有三个离散值,“学生”、“老师”、“上班族”。该属性有三种划分的方案,分别为{“学生”}、{“老师”、“上班族”},{“老师”}、{“学生”、“上班族”},{“上班族”}、{“学生”、“老师”},分别计算三种划分方案的子节点GINI值或者样本方差,选择最优的划分方法,如下图所示:

第一种划分方法:{“学生”}、{“老师”、“上班族”}

决策树系列(五)——CART_第2张图片

预测是否已婚(分类):

                    

预测年龄(回归):

            

 

第二种划分方法:{“老师”}、{“学生”、“上班族”}

 决策树系列(五)——CART_第3张图片

预测是否已婚(分类):

                    

预测年龄(回归):

            

第三种划分方法:{“上班族”}、{“学生”、“老师”}

决策树系列(五)——CART_第4张图片

 预测是否已婚(分类):

                    

预测年龄(回归):

            

综上,如果想预测是否已婚,则选择{“上班族”}、{“学生”、“老师”}的划分方法,如果想预测年龄,则选择{“老师”}、{“学生”、“上班族”}的划分方法。

 

如何剪枝?

      CART采用CCP(代价复杂度)剪枝方法。代价复杂度选择节点表面误差率增益值最小的非叶子节点,删除该非叶子节点的左右子节点,若有多个非叶子节点的表面误差率增益值相同小,则选择非叶子节点中子节点数最多的非叶子节点进行剪枝。

可描述如下:

令决策树的非叶子节点为

a)计算所有非叶子节点的表面误差率增益值 

b)选择表面误差率增益值最小的非叶子节点(若多个非叶子节点具有相同小的表面误差率增益值,选择节点数最多的非叶子节点)。

c)对进行剪枝

表面误差率增益值的计算公式:

                               

其中:

表示叶子节点的误差代价, 为节点的错误率, 为节点数据量的占比;

表示子树的误差代价,为子节点i的错误率, 表示节点i的数据节点占比;

表示子树节点个数。

算例:

下图是其中一颗子树,设决策树的总数据量为40。

决策树系列(五)——CART_第5张图片

该子树的表面误差率增益值可以计算如下:

 决策树系列(五)——CART_第6张图片

求出该子树的表面错误覆盖率为 ,只要求出其他子树的表面误差率增益值就可以对决策树进行剪枝。

 

程序实际以及源代码

流程图:

决策树系列(五)——CART_第7张图片

(1)数据处理

         对原始的数据进行数字化处理,并以二维数据的形式存储,每一行表示一条记录,前n-1列表示属性,最后一列表示分类的标签。

         如表1的数据可以转化为表2:

                                                                           表2 初始化后的数据

看电视时间

婚姻情况

职业

年龄

3

未婚

学生

12

4

未婚

学生

18

2

已婚

老师

26

5

已婚

上班族

47

2.5

已婚

上班族

36

3.5

未婚

老师

29

4

已婚

学生

21

        

      其中,对于“婚姻情况”属性,数字{1,2}分别表示{未婚,已婚 };对于“职业”属性{1,2,3, }分别表示{学生、老师、上班族};

代码如下所示:

         static double[][] allData;                              //存储进行训练的数据

    static List<String>[] featureValues;                    //离散属性对应的离散值

featureValues是链表数组,数组的长度为属性的个数,数组的每个元素为该属性的离散值链表。

(2)两个类:节点类和分裂信息

a)节点类Node

      该类表示一个节点,属性包括节点选择的分裂属性、节点的输出类、孩子节点、深度等。注意,与ID3中相比,新增了两个属性:leafWrong和leafNode_Count分别表示叶子节点的总分类误差和叶子节点的个数,主要是为了方便剪枝。

 1 class Node
 2 {
 3     /// <summary>
 4     /// 每一个节点的分裂值
 5     /// </summary>
 6     public List<String> features { get; set; }
 7     /// <summary>
 8     /// 分裂属性的类型{离散、连续}
 9     /// </summary>
10     public String feature_Type { get; set; }
11     /// <summary>
12     /// 分裂属性的下标
13     /// </summary>
14     public String SplitFeature { get; set; }
15     //List<int> nums = new List<int>();                       //行序号
16     /// <summary>
17     /// 每一个类对应的数目
18     /// </summary>
19     public double[] ClassCount { get; set; }
20     //int[] isUsed = new int[0];                              //属性的使用情况 1:已用 2:未用
21     /// <summary>
22     /// 孩子节点
23     /// </summary>
24     public List<Node> childNodes { get; set; }
25     Node Parent = null;
26     /// <summary>
27     /// 该节点占比最大的类别
28     /// </summary>
29     public String finalResult { get; set; }
30     /// <summary>
31     /// 树的深度
32     /// </summary>
33     public int deep { get; set; }
34     /// <summary>
35     /// 最大的类下标
36     /// </summary>
37     public int result { get; set; }
38     /// <summary>
39     /// 子节点误差
40     /// </summary>
41     public int leafWrong { get; set; }
42     /// <summary>
43     /// 子节点数目
44     /// </summary>
45     public int leafNode_Count { get; set; }
46     /// <summary>
47     /// 数据量
48     /// </summary>
49     public int rowCount { get; set; }
50 
51     public void setClassCount(double[] count)
52     {
53         this.ClassCount = count;
54         double max = ClassCount[0];
55         int result = 0;
56         for (int i = 1; i < ClassCount.Length; i++)
57         {
58             if (max < ClassCount[i])
59             {
60                 max = ClassCount[i];
61                 result = i;
62             }
63         }
64         this.result = result;
65     }
66     public double getErrorCount()
67     {
68         return rowCount - ClassCount[result];
69     }
70 }
树的节点

b)分裂信息类,该类存储节点进行分裂的信息,包括各个子节点的行坐标、子节点各个类的数目、该节点分裂的属性、属性的类型等。

 1     class SplitInfo
 2     {
 3         /// <summary>
 4         /// 分裂的属性下标
 5         /// </summary>
 6         public int splitIndex { get; set; }
 7         /// <summary>
 8         /// 数据类型
 9         /// </summary>
10         public int type { get; set; }
11         /// <summary>
12         /// 分裂属性的取值
13         /// </summary>
14         public List<String> features { get; set; }
15         /// <summary>
16         /// 各个节点的行坐标链表
17         /// </summary>
18         public List<int>[] temp { get; set; }
19         /// <summary>
20         /// 每个节点各类的数目
21         /// </summary>
22         public double[][] class_Count { get; set; }
23     }
分裂信息

主方法findBestSplit(Node node,List<int> nums,int[] isUsed),该方法对节点进行分裂

其中:

node表示即将进行分裂的节点;

nums表示节点数据对一个的行坐标列表;

isUsed表示到该节点位置所有属性的使用情况;

findBestSplit的这个方法主要有以下几个组成部分:

1)节点分裂停止的判定

节点分裂条件如上文所述,源代码如下:

  1         public static bool ifEnd(Node node, double shang,int[] isUsed)
  2         {
  3             try
  4             {
  5                 double[] count = node.ClassCount;
  6                 int rowCount = node.rowCount;
  7                 int maxResult = 0;
  8                 double maxRate = 0;
  9                 #region 数达到某一深度
 10                 int deep = node.deep;
 11                 if (deep >= 10)
 12                 {
 13                     maxResult = node.result + 1;
 14                     node.feature_Type="result";
 15                     node.features=new List<String>() { maxResult + "" 
 16 
 17 };
 18                     node.leafWrong=rowCount - Convert.ToInt32(count[maxResult-1]);
 19                     node.leafNode_Count=1;
 20                     return true;
 21                 }
 22                 #endregion
 23                 #region 纯度(其实跟后面的有点重了,记得要修改)
 24                 //maxResult = 1;
 25                 //for (int i = 1; i < count.Length; i++)
 26                 //{
 27                 //    if (count[i] / rowCount >= 0.95)
 28                 //    {
 29                 //        node.feature_Type="result";
 30                 //        node.features=new List<String> { "" + (i + 
 31 
 32 1) };
 33                 //        node.leafNode_Count=1;
 34                 //        node.leafWrong=rowCount - Convert.ToInt32
 35 
 36 (count[i]);
 37                 //        return true;
 38                 //    }
 39                 //}
 40                 #endregion
 41                 #region 熵为0
 42                 if (shang == 0)
 43                 {
 44                     maxRate = count[0] / rowCount;
 45                     maxResult = 1;
 46                     for (int i = 1; i < count.Length; i++)
 47                     {
 48                         if (count[i] / rowCount >= maxRate)
 49                         {
 50                             maxRate = count[i] / rowCount;
 51                             maxResult = i + 1;
 52                         }
 53                     }
 54                     node.feature_Type="result";
 55                     node.features=new List<String> { maxResult + "" 
 56 
 57 };
 58                     node.leafWrong=rowCount - Convert.ToInt32(count
 59 
 60 [maxResult - 1]);
 61                     node.leafNode_Count=1;
 62                     return true;
 63                 }
 64                 #endregion
 65                 #region 属性已经分完
 66                 //int[] isUsed = node.getUsed();
 67                 bool flag = true;
 68                 for (int i = 0; i < isUsed.Length - 1; i++)
 69                 {
 70                     if (isUsed[i] == 0)
 71                     {
 72                         flag = false;
 73                         break;
 74                     }
 75                 }
 76                 if (flag)
 77                 {
 78                     maxRate = count[0] / rowCount;
 79                     maxResult = 1;
 80                     for (int i = 1; i < count.Length; i++)
 81                     {
 82                         if (count[i] / rowCount >= maxRate)
 83                         {
 84                             maxRate = count[i] / rowCount;
 85                             maxResult = i + 1;
 86                         }
 87                     }
 88                     node.feature_Type=("result");
 89                     node.features=(new List<String> { "" + 
 90 
 91 (maxResult) });
 92                     node.leafWrong=(rowCount - Convert.ToInt32(count
 93 
 94 [maxResult - 1]));
 95                     node.leafNode_Count=(1);
 96                     return true;
 97                 }
 98                 #endregion
 99                 #region 几点数少于100
100                 if (rowCount < Limit_Node)
101                 {
102                     maxRate = count[0] / rowCount;
103                     maxResult = 1;
104                     for (int i = 1; i < count.Length; i++)
105                     {
106                         if (count[i] / rowCount >= maxRate)
107                         {
108                             maxRate = count[i] / rowCount;
109                             maxResult = i + 1;
110                         }
111                     }
112                     node.feature_Type="result";
113                     node.features=new List<String> { "" + (maxResult) 
114 
115 };
116                     node.leafWrong=rowCount - Convert.ToInt32(count
117 
118 [maxResult - 1]);
119                     node.leafNode_Count=1;
120                     return true;
121                 }
122                 #endregion
123                 return false;
124             }
125             catch (Exception e)
126             {
127                 return false;
128             }
129         }
停止分裂的条件

2)寻找最优的分裂属性

寻找最优的分裂属性需要计算每一个分裂属性分裂后的GINI值或者样本方差,计算公式上文已给出,其中GINI值的计算代码如下:

1         public static double getGini(double[] counts, int countAll)
2         {
3             double Gini = 1;
4             for (int i = 0; i < counts.Length; i++)
5             {
6                 Gini = Gini - Math.Pow(counts[i] / countAll, 2);
7             }
8             return Gini;
9         }
GINI值计算

3)进行分裂,同时对子节点进行迭代处理

其实就是递归的过程,对每一个子节点执行findBestSplit方法进行分裂。

findBestSplit源代码:

  1         public static Node findBestSplit(Node node,List<int> nums,int[] isUsed)
  2         {
  3             try
  4             {
  5                 //判断是否继续分裂
  6                 double totalShang = getGini(node.ClassCount, node.rowCount);
  7                 if (ifEnd(node, totalShang, isUsed))
  8                 {
  9                     return node;
 10                 }
 11                 #region 变量声明
 12                 SplitInfo info = new SplitInfo();
 13                 info.initial();
 14                 int RowCount = nums.Count;                  //样本总数
 15                 double jubuMax = 1;                         //局部最大熵
 16                 int splitPoint = 0;                         //分裂的点
 17                 double splitValue = 0;                      //分裂的值
 18                 #endregion
 19                 for (int i = 0; i < isUsed.Length - 1; i++)
 20                 {
 21                     if (isUsed[i] == 1)
 22                     {
 23                         continue;
 24                     }
 25                     #region 离散变量
 26                     if (type[i] == 0)
 27                     {
 28                         double[][] allCount = new double[allNum[i]][];
 29                         for (int j = 0; j < allCount.Length; j++)
 30                         {
 31                             allCount[j] = new double[classCount];
 32                         }
 33                         int[] countAllFeature = new int[allNum[i]];
 34                         List<int>[] temp = new List<int>[allNum[i]];
 35                         double[] allClassCount = node.ClassCount;     //所有类别的数量
 36                         for (int j = 0; j < temp.Length; j++)
 37                         {
 38                             temp[j] = new List<int>();
 39                         }
 40                         for (int j = 0; j < nums.Count; j++)
 41                         {
 42                             int index = Convert.ToInt32(allData[nums[j]][i]);
 43                             temp[index - 1].Add(nums[j]);
 44                             countAllFeature[index - 1]++;
 45                             allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
 46                         }
 47                         double allShang = 1;
 48                         int choose = 0;
 49 
 50                         double[][] jubuCount = new double[2][];
 51                         for (int k = 0; k < allCount.Length; k++)
 52                         {
 53                             if (temp[k].Count == 0)
 54                                 continue;
 55                             double JubuShang = 0;
 56                             double[][] tempCount = new double[2][];
 57                             tempCount[0] = allCount[k];
 58                             tempCount[1] = new double[allCount[0].Length];
 59                             for (int j = 0; j < tempCount[1].Length; j++)
 60                             {
 61                                 tempCount[1][j] = allClassCount[j] - allCount[k][j];
 62                             }
 63                             JubuShang = JubuShang + getGini(tempCount[0], countAllFeature[k]) * countAllFeature[k] / RowCount;
 64                             int nodecount = RowCount - countAllFeature[k];
 65                             JubuShang = JubuShang + getGini(tempCount[1], nodecount) * nodecount / RowCount;
 66                             if (JubuShang < allShang)
 67                             {
 68                                 allShang = JubuShang;
 69                                 jubuCount = tempCount;
 70                                 choose = k;
 71                             }
 72                         }                        
 73                         if (allShang < jubuMax)
 74                         {
 75                             info.type = 0;
 76                             jubuMax = allShang;
 77                             info.class_Count = jubuCount;
 78                             info.temp[0] = temp[choose];
 79                             info.temp[1] = new List<int>();
 80                             info.features = new List<string>();
 81                             info.features.Add((choose + 1) + "");
 82                             info.features.Add("");
 83                             for (int j = 0; j < temp.Length; j++)
 84                             {
 85                                 if (j == choose)
 86                                     continue;
 87                                 for (int k = 0; k < temp[j].Count; k++)
 88                                 {
 89                                     info.temp[1].Add(temp[j][k]);
 90                                 }
 91                                 if (temp[j].Count != 0)
 92                                 {
 93                                     info.features[1] = info.features[1] + (j + 1) + ",";
 94                                 }
 95                             }
 96                             info.splitIndex = i;
 97                         }
 98                     }
 99                     #endregion
100                     #region 连续变量
101                     else
102                     {
103                         double[] leftCunt = new double[classCount];   
104 
105           //做节点各个类别的数量
106                         double[] rightCount = new double[classCount]; 
107 
108           //右节点各个类别的数量
109                         double[] count1 = new double[classCount];     
110 
111           //子集1的统计量
112                         double[] count2 = new double
113 
114 [node.ClassCount.Length];   //子集2的统计量
115                         for (int j = 0; j < node.ClassCount.Length; 
116 
117 j++)
118                         {
119                             count2[j] = node.ClassCount[j];
120                         }
121                         int all1 = 0;                                 
122 
123           //子集1的样本量
124                         int all2 = nums.Count;                        
125 
126           //子集2的样本量
127                         double lastValue = 0;                         
128 
129          //上一个记录的类别
130                         double currentValue = 0;                      
131 
132          //当前类别
133                         double lastPoint = 0;                         
134 
135           //上一个点的值
136                         double currentPoint = 0;                      
137 
138           //当前点的值
139                         double[] values = new double[nums.Count];
140                         for (int j = 0; j < values.Length; j++)
141                         {
142                             values[j] = allData[nums[j]][i];
143                         }
144                         QSort(values, nums, 0, nums.Count - 1);
145                         double lianxuMax = 1;                         
146 
147           //连续型属性的最大熵
148                         #region 寻找最佳的分割点
149                         for (int j = 0; j < nums.Count - 1; j++)
150                         {
151                             currentValue = allData[nums[j]][lieshu - 
152 
153 1];
154                             currentPoint = (allData[nums[j]][i]);
155                             if (j == 0)
156                             {
157                                 lastValue = currentValue;
158                                 lastPoint = currentPoint;
159                             }
160                             if (currentValue != lastValue && 
161 
162 currentPoint != lastPoint)
163                             {
164                                 double shang1 = getGini(count1, 
165 
166 all1);
167                                 double shang2 = getGini(count2, 
168 
169 all2);
170                                 double allShang = shang1 * all1 / 
171 
172 (all1 + all2) + shang2 * all2 / (all1 + all2);
173                                 //allShang = (totalShang - allShang);
174                                 if (lianxuMax > allShang)
175                                 {
176                                     lianxuMax = allShang;
177                                     for (int k = 0; k < 
178 
179 count1.Length; k++)
180                                     {
181                                         leftCunt[k] = count1[k];
182                                         rightCount[k] = count2[k];
183                                     }
184                                     splitPoint = j;
185                                     splitValue = (currentPoint + 
186 
187 lastPoint) / 2;
188                                 }
189                             }
190                             all1++;
191                             count1[Convert.ToInt32(currentValue) - 
192 
193 1]++;
194                             count2[Convert.ToInt32(currentValue) - 
195 
196 1]--;
197                             all2--;
198                             lastValue = currentValue;
199                             lastPoint = currentPoint;
200                         }
201                         #endregion
202                         #region 如果超过了局部值,重设
203                         if (lianxuMax < jubuMax)
204                         {
205                             info.type = 1;
206                             info.splitIndex = i;
207                             info.features=new List<string>()
208 
209 {splitValue+""};
210                             //finalPoint = splitPoint;
211                             jubuMax = lianxuMax;
212                             info.temp[0] = new List<int>();
213                             info.temp[1] = new List<int>();
214                             for (int k = 0; k < splitPoint; k++)
215                             {
216                                 info.temp[0].Add(nums[k]);
217                             }
218                             for (int k = splitPoint; k < nums.Count; 
219 
220 k++)
221                             {
222                                 info.temp[1].Add(nums[k]);
223                             }
224                             info.class_Count[0] = new double
225 
226 [leftCunt.Length];
227                             info.class_Count[1] = new double
228 
229 [leftCunt.Length];
230                             for (int k = 0; k < leftCunt.Length; k++)
231                             {
232                                 info.class_Count[0][k] = leftCunt[k];
233                                 info.class_Count[1][k] = rightCount
234 
235 [k];
236                             }
237                         }
238                         #endregion
239                     }
240                     #endregion
241                 }
242                 #region 没有寻找到最佳的分裂点,则设置为叶节点
243                 if (info.splitIndex == -1)
244                 {
245                     double[] finalCount = node.ClassCount;
246                     double max = finalCount[0];
247                     int result = 1;
248                     for (int i = 1; i < finalCount.Length; i++)
249                     {
250                         if (finalCount[i] > max)
251                         {
252                             max = finalCount[i];
253                             result = (i + 1);
254                         }
255                     }
256                     node.feature_Type="result";
257                     node.features=new List<String> { "" + result };
258                     return node;
259                 }
260                 #endregion
261                 #region 分裂
262                 int deep = node.deep;
263                 node.SplitFeature = ("" + info.splitIndex);
264                 List<Node> childNode = new List<Node>();
265                 int[][] used = new int[2][];
266                 used[0] = new int[isUsed.Length];
267                 used[1] = new int[isUsed.Length];
268                 for (int i = 0; i < isUsed.Length; i++)
269                 {
270                     used[0][i] = isUsed[i];
271                     used[1][i] = isUsed[i];
272                 }
273                 if (info.type == 0)
274                 {
275                     used[0][info.splitIndex] = 1;
276                     node.feature_Type = ("离散");
277                 }
278                 else
279                 {
280                     //used[info.splitIndex] = 0;
281                     node.feature_Type = ("连续");
282                 }
283                 List<int>[] rowIndex = info.temp;
284                 List<String> features = info.features;
285                 Node node1 = new Node();
286                 Node node2 = new Node();
287                 node1.setClassCount(info.class_Count[0]);
288                 node2.setClassCount(info.class_Count[1]);
289                 node1.rowCount = info.temp[0].Count;
290                 node2.rowCount = info.temp[1].Count;
291                 node1.deep = deep + 1;
292                 node2.deep = deep + 1;
293                 node1 = findBestSplit(node1, info.temp[0],used[0]);
294                 node2 = findBestSplit(node2, info.temp[1], used[1]);
295                 node.leafNode_Count = (node1.leafNode_Count
296 
297 +node2.leafNode_Count);
298                 node.leafWrong = (node1.leafWrong+node2.leafWrong);
299                 node.features = (features);
300                 childNode.Add(node1);
301                 childNode.Add(node2);
302                 node.childNodes = childNode;
303                 #endregion
304                 return node;
305             }
306             catch (Exception e)
307             {
308                 Console.WriteLine(e.StackTrace);
309                 return node;
310             }
311         }
节点选择属性和分裂

(4)剪枝

代价复杂度剪枝方法(CCP):

 1         public static void getSeries(Node node)
 2         {
 3             Stack<Node> nodeStack = new Stack<Node>();
 4             if (node != null)
 5             {
 6                 nodeStack.Push(node);
 7             }
 8             if (node.feature_Type == "result")
 9                 return;
10             List<Node> childs = node.childNodes;
11             for (int i = 0; i < childs.Count; i++)
12             {
13                 getSeries(node);
14             }
15         }
CCP代价复杂度剪枝

CART全部核心代码:

  1         /// <summary>
  2         /// 判断是否还需要分裂
  3         /// </summary>
  4         /// <param name="node"></param>
  5         /// <returns></returns>
  6         public static bool ifEnd(Node node, double shang,int[] isUsed)
  7         {
  8             try
  9             {
 10                 double[] count = node.ClassCount;
 11                 int rowCount = node.rowCount;
 12                 int maxResult = 0;
 13                 double maxRate = 0;
 14                 #region 数达到某一深度
 15                 int deep = node.deep;
 16                 if (deep >= 10)
 17                 {
 18                     maxResult = node.result + 1;
 19                     node.feature_Type="result";
 20                     node.features=new List<String>() { maxResult + "" 
 21 
 22 };
 23                     node.leafWrong=rowCount - Convert.ToInt32(count[maxResult-1]);
 24                     node.leafNode_Count=1;
 25                     return true;
 26                 }
 27                 #endregion
 28                 #region 纯度(其实跟后面的有点重了,记得要修改)
 29                 //maxResult = 1;
 30                 //for (int i = 1; i < count.Length; i++)
 31                 //{
 32                 //    if (count[i] / rowCount >= 0.95)
 33                 //    {
 34                 //        node.feature_Type="result";
 35                 //        node.features=new List<String> { "" + (i + 
 36 
 37 1) };
 38                 //        node.leafNode_Count=1;
 39                 //        node.leafWrong=rowCount - Convert.ToInt32
 40 
 41 (count[i]);
 42                 //        return true;
 43                 //    }
 44                 //}
 45                 #endregion
 46                 #region 熵为0
 47                 if (shang == 0)
 48                 {
 49                     maxRate = count[0] / rowCount;
 50                     maxResult = 1;
 51                     for (int i = 1; i < count.Length; i++)
 52                     {
 53                         if (count[i] / rowCount >= maxRate)
 54                         {
 55                             maxRate = count[i] / rowCount;
 56                             maxResult = i + 1;
 57                         }
 58                     }
 59                     node.feature_Type="result";
 60                     node.features=new List<String> { maxResult + "" 
 61 
 62 };
 63                     node.leafWrong=rowCount - Convert.ToInt32(count
 64 
 65 [maxResult - 1]);
 66                     node.leafNode_Count=1;
 67                     return true;
 68                 }
 69                 #endregion
 70                 #region 属性已经分完
 71                 //int[] isUsed = node.getUsed();
 72                 bool flag = true;
 73                 for (int i = 0; i < isUsed.Length - 1; i++)
 74                 {
 75                     if (isUsed[i] == 0)
 76                     {
 77                         flag = false;
 78                         break;
 79                     }
 80                 }
 81                 if (flag)
 82                 {
 83                     maxRate = count[0] / rowCount;
 84                     maxResult = 1;
 85                     for (int i = 1; i < count.Length; i++)
 86                     {
 87                         if (count[i] / rowCount >= maxRate)
 88                         {
 89                             maxRate = count[i] / rowCount;
 90                             maxResult = i + 1;
 91                         }
 92                     }
 93                     node.feature_Type=("result");
 94                     node.features=(new List<String> { "" + 
 95 
 96 (maxResult) });
 97                     node.leafWrong=(rowCount - Convert.ToInt32(count
 98 
 99 [maxResult - 1]));
100                     node.leafNode_Count=(1);
101                     return true;
102                 }
103                 #endregion
104                 #region 几点数少于100
105                 if (rowCount < Limit_Node)
106                 {
107                     maxRate = count[0] / rowCount;
108                     maxResult = 1;
109                     for (int i = 1; i < count.Length; i++)
110                     {
111                         if (count[i] / rowCount >= maxRate)
112                         {
113                             maxRate = count[i] / rowCount;
114                             maxResult = i + 1;
115                         }
116                     }
117                     node.feature_Type="result";
118                     node.features=new List<String> { "" + (maxResult) 
119 
120 };
121                     node.leafWrong=rowCount - Convert.ToInt32(count
122 
123 [maxResult - 1]);
124                     node.leafNode_Count=1;
125                     return true;
126                 }
127                 #endregion
128                 return false;
129             }
130             catch (Exception e)
131             {
132                 return false;
133             }
134         }
135         #region 排序算法
136         public static void InsertSort(double[] values, List<int> arr, 
137 
138 int StartIndex, int endIndex)
139         {
140             for (int i = StartIndex + 1; i <= endIndex; i++)
141             {
142                 int key = arr[i];
143                 double init = values[i];
144                 int j = i - 1;
145                 while (j >= StartIndex && values[j] > init)
146                 {
147                     arr[j + 1] = arr[j];
148                     values[j + 1] = values[j];
149                     j--;
150                 }
151                 arr[j + 1] = key;
152                 values[j + 1] = init;
153             }
154         }
155         static int SelectPivotMedianOfThree(double[] values, List<int> arr, int low, int high)
156         {
157             int mid = low + ((high - low) >> 1);//计算数组中间的元素的下标  
158 
159             //使用三数取中法选择枢轴  
160             if (values[mid] > values[high])//目标: arr[mid] <= arr[high]  
161             {
162                 swap(values, arr, mid, high);
163             }
164             if (values[low] > values[high])//目标: arr[low] <= arr[high]  
165             {
166                 swap(values, arr, low, high);
167             }
168             if (values[mid] > values[low]) //目标: arr[low] >= arr[mid]  
169             {
170                 swap(values, arr, mid, low);
171             }
172             //此时,arr[mid] <= arr[low] <= arr[high]  
173             return low;
174             //low的位置上保存这三个位置中间的值  
175             //分割时可以直接使用low位置的元素作为枢轴,而不用改变分割函数了  
176         }
177         static void swap(double[] values, List<int> arr, int t1, int t2)
178         {
179             double temp = values[t1];
180             values[t1] = values[t2];
181             values[t2] = temp;
182             int key = arr[t1];
183             arr[t1] = arr[t2];
184             arr[t2] = key;
185         }
186         static void QSort(double[] values, List<int> arr, int low, int high)
187         {
188             int first = low;
189             int last = high;
190 
191             int left = low;
192             int right = high;
193 
194             int leftLen = 0;
195             int rightLen = 0;
196 
197             if (high - low + 1 < 10)
198             {
199                 InsertSort(values, arr, low, high);
200                 return;
201             }
202 
203             //一次分割 
204             int key = SelectPivotMedianOfThree(values, arr, low, 
205 
206 high);//使用三数取中法选择枢轴 
207             double inti = values[key];
208             int currentKey = arr[key];
209 
210             while (low < high)
211             {
212                 while (high > low && values[high] >= inti)
213                 {
214                     if (values[high] == inti)//处理相等元素  
215                     {
216                         swap(values, arr, right, high);
217                         right--;
218                         rightLen++;
219                     }
220                     high--;
221                 }
222                 arr[low] = arr[high];
223                 values[low] = values[high];
224                 while (high > low && values[low] <= inti)
225                 {
226                     if (values[low] == inti)
227                     {
228                         swap(values, arr, left, low);
229                         left++;
230                         leftLen++;
231                     }
232                     low++;
233                 }
234                 arr[high] = arr[low];
235                 values[high] = values[low];
236             }
237             arr[low] = currentKey;
238             values[low] = values[key];
239             //一次快排结束  
240             //把与枢轴key相同的元素移到枢轴最终位置周围  
241             int i = low - 1;
242             int j = first;
243             while (j < left && values[i] != inti)
244             {
245                 swap(values, arr, i, j);
246                 i--;
247                 j++;
248             }
249             i = low + 1;
250             j = last;
251             while (j > right && values[i] != inti)
252             {
253                 swap(values, arr, i, j);
254                 i++;
255                 j--;
256             }
257             QSort(values, arr, first, low - 1 - leftLen);
258             QSort(values, arr, low + 1 + rightLen, last);
259         }
260         #endregion
261         /// <summary>
262         /// 寻找最佳的分裂点
263         /// </summary>
264         /// <param name="num"></param>
265         /// <param name="node"></param>
266         public static Node findBestSplit(Node node,List<int> nums,int[] isUsed)
267         {
268             try
269             {
270                 //判断是否继续分裂
271                 double totalShang = getGini(node.ClassCount, node.rowCount);
272                 if (ifEnd(node, totalShang, isUsed))
273                 {
274                     return node;
275                 }
276                 #region 变量声明
277                 SplitInfo info = new SplitInfo();
278                 info.initial();
279                 int RowCount = nums.Count;                  //样本总数
280                 double jubuMax = 1;                         //局部最大熵
281                 int splitPoint = 0;                         //分裂的点
282                 double splitValue = 0;                      //分裂的值
283                 #endregion
284                 for (int i = 0; i < isUsed.Length - 1; i++)
285                 {
286                     if (isUsed[i] == 1)
287                     {
288                         continue;
289                     }
290                     #region 离散变量
291                     if (type[i] == 0)
292                     {
293                         double[][] allCount = new double[allNum[i]][];
294                         for (int j = 0; j < allCount.Length; j++)
295                         {
296                             allCount[j] = new double[classCount];
297                         }
298                         int[] countAllFeature = new int[allNum[i]];
299                         List<int>[] temp = new List<int>[allNum[i]];
300                         double[] allClassCount = node.ClassCount;     //所有类别的数量
301                         for (int j = 0; j < temp.Length; j++)
302                         {
303                             temp[j] = new List<int>();
304                         }
305                         for (int j = 0; j < nums.Count; j++)
306                         {
307                             int index = Convert.ToInt32(allData[nums[j]][i]);
308                             temp[index - 1].Add(nums[j]);
309                             countAllFeature[index - 1]++;
310                             allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
311                         }
312                         double allShang = 1;
313                         int choose = 0;
314 
315                         double[][] jubuCount = new double[2][];
316                         for (int k = 0; k < allCount.Length; k++)
317                         {
318                             if (temp[k].Count == 0)
319                                 continue;
320                             double JubuShang = 0;
321                             double[][] tempCount = new double[2][];
322                             tempCount[0] = allCount[k];
323                             tempCount[1] = new double[allCount[0].Length];
324                             for (int j = 0; j < tempCount[1].Length; j++)
325                             {
326                                 tempCount[1][j] = allClassCount[j] - allCount[k][j];
327                             }
328                             JubuShang = JubuShang + getGini(tempCount[0], countAllFeature[k]) * countAllFeature[k] / RowCount;
329                             int nodecount = RowCount - countAllFeature[k];
330                             JubuShang = JubuShang + getGini(tempCount[1], nodecount) * nodecount / RowCount;
331                             if (JubuShang < allShang)
332                             {
333                                 allShang = JubuShang;
334                                 jubuCount = tempCount;
335                                 choose = k;
336                             }
337                         }                        
338                         if (allShang < jubuMax)
339                         {
340                             info.type = 0;
341                             jubuMax = allShang;
342                             info.class_Count = jubuCount;
343                             info.temp[0] = temp[choose];
344                             info.temp[1] = new List<int>();
345                             info.features = new List<string>();
346                             info.features.Add((choose + 1) + "");
347                             info.features.Add("");
348                             for (int j = 0; j < temp.Length; j++)
349                             {
350                                 if (j == choose)
351                                     continue;
352                                 for (int k = 0; k < temp[j].Count; k++)
353                                 {
354                                     info.temp[1].Add(temp[j][k]);
355                                 }
356                                 if (temp[j].Count != 0)
357                                 {
358                                     info.features[1] = info.features[1] + (j + 1) + ",";
359                                 }
360                             }
361                             info.splitIndex = i;
362                         }
363                     }
364                     #endregion
365                     #region 连续变量
366                     else
367                     {
368                         double[] leftCunt = new double[classCount];   
369 
370           //做节点各个类别的数量
371                         double[] rightCount = new double[classCount]; 
372 
373           //右节点各个类别的数量
374                         double[] count1 = new double[classCount];     
375 
376           //子集1的统计量
377                         double[] count2 = new double
378 
379 [node.ClassCount.Length];   //子集2的统计量
380                         for (int j = 0; j < node.ClassCount.Length; 
381 
382 j++)
383                         {
384                             count2[j] = node.ClassCount[j];
385                         }
386                         int all1 = 0;                                 
387 
388           //子集1的样本量
389                         int all2 = nums.Count;                        
390 
391           //子集2的样本量
392                         double lastValue = 0;                         
393 
394          //上一个记录的类别
395                         double currentValue = 0;                      
396 
397          //当前类别
398                         double lastPoint = 0;                         
399 
400           //上一个点的值
401                         double currentPoint = 0;                      
402 
403           //当前点的值
404                         double[] values = new double[nums.Count];
405                         for (int j = 0; j < values.Length; j++)
406                         {
407                             values[j] = allData[nums[j]][i];
408                         }
409                         QSort(values, nums, 0, nums.Count - 1);
410                         double lianxuMax = 1;                         
411 
412           //连续型属性的最大熵
413                         #region 寻找最佳的分割点
414                         for (int j = 0; j < nums.Count - 1; j++)
415                         {
416                             currentValue = allData[nums[j]][lieshu - 
417 
418 1];
419                             currentPoint = (allData[nums[j]][i]);
420                             if (j == 0)
421                             {
422                                 lastValue = currentValue;
423                                 lastPoint = currentPoint;
424                             }
425                             if (currentValue != lastValue && 
426 
427 currentPoint != lastPoint)
428                             {
429                                 double shang1 = getGini(count1, 
430 
431 all1);
432                                 double shang2 = getGini(count2, 
433 
434 all2);
435                                 double allShang = shang1 * all1 / 
436 
437 (all1 + all2) + shang2 * all2 / (all1 + all2);
438                                 //allShang = (totalShang - allShang);
439                                 if (lianxuMax > allShang)
440                                 {
441                                     lianxuMax = allShang;
442                                     for (int k = 0; k < 
443 
444 count1.Length; k++)
445                                     {
446                                         leftCunt[k] = count1[k];
447                                         rightCount[k] = count2[k];
448                                     }
449                                     splitPoint = j;
450                                     splitValue = (currentPoint + 
451 
452 lastPoint) / 2;
453                                 }
454                             }
455                             all1++;
456                             count1[Convert.ToInt32(currentValue) - 
457 
458 1]++;
459                             count2[Convert.ToInt32(currentValue) - 
460 
461 1]--;
462                             all2--;
463                             lastValue = currentValue;
464                             lastPoint = currentPoint;
465                         }
466                         #endregion
467                         #region 如果超过了局部值,重设
468                         if (lianxuMax < jubuMax)
469                         {
470                             info.type = 1;
471                             info.splitIndex = i;
472                             info.features=new List<string>()
473 
474 {splitValue+""};
475                             //finalPoint = splitPoint;
476                             jubuMax = lianxuMax;
477                             info.temp[0] = new List<int>();
478                             info.temp[1] = new List<int>();
479                             for (int k = 0; k < splitPoint; k++)
480                             {
481                                 info.temp[0].Add(nums[k]);
482                             }
483                             for (int k = splitPoint; k < nums.Count; 
484 
485 k++)
486                             {
487                                 info.temp[1].Add(nums[k]);
488                             }
489                             info.class_Count[0] = new double
490 
491 [leftCunt.Length];
492                             info.class_Count[1] = new double
493 
494 [leftCunt.Length];
495                             for (int k = 0; k < leftCunt.Length; k++)
496                             {
497                                 info.class_Count[0][k] = leftCunt[k];
498                                 info.class_Count[1][k] = rightCount
499 
500 [k];
501                             }
502                         }
503                         #endregion
504                     }
505                     #endregion
506                 }
507                 #region 没有寻找到最佳的分裂点,则设置为叶节点
508                 if (info.splitIndex == -1)
509                 {
510                     double[] finalCount = node.ClassCount;
511                     double max = finalCount[0];
512                     int result = 1;
513                     for (int i = 1; i < finalCount.Length; i++)
514                     {
515                         if (finalCount[i] > max)
516                         {
517                             max = finalCount[i];
518                             result = (i + 1);
519                         }
520                     }
521                     node.feature_Type="result";
522                     node.features=new List<String> { "" + result };
523                     return node;
524                 }
525                 #endregion
526                 #region 分裂
527                 int deep = node.deep;
528                 node.SplitFeature = ("" + info.splitIndex);
529                 List<Node> childNode = new List<Node>();
530                 int[][] used = new int[2][];
531                 used[0] = new int[isUsed.Length];
532                 used[1] = new int[isUsed.Length];
533                 for (int i = 0; i < isUsed.Length; i++)
534                 {
535                     used[0][i] = isUsed[i];
536                     used[1][i] = isUsed[i];
537                 }
538                 if (info.type == 0)
539                 {
540                     used[0][info.splitIndex] = 1;
541                     node.feature_Type = ("离散");
542                 }
543                 else
544                 {
545                     //used[info.splitIndex] = 0;
546                     node.feature_Type = ("连续");
547                 }
548                 List<int>[] rowIndex = info.temp;
549                 List<String> features = info.features;
550                 Node node1 = new Node();
551                 Node node2 = new Node();
552                 node1.setClassCount(info.class_Count[0]);
553                 node2.setClassCount(info.class_Count[1]);
554                 node1.rowCount = info.temp[0].Count;
555                 node2.rowCount = info.temp[1].Count;
556                 node1.deep = deep + 1;
557                 node2.deep = deep + 1;
558                 node1 = findBestSplit(node1, info.temp[0],used[0]);
559                 node2 = findBestSplit(node2, info.temp[1], used[1]);
560                 node.leafNode_Count = (node1.leafNode_Count
561 
562 +node2.leafNode_Count);
563                 node.leafWrong = (node1.leafWrong+node2.leafWrong);
564                 node.features = (features);
565                 childNode.Add(node1);
566                 childNode.Add(node2);
567                 node.childNodes = childNode;
568                 #endregion
569                 return node;
570             }
571             catch (Exception e)
572             {
573                 Console.WriteLine(e.StackTrace);
574                 return node;
575             }
576         }
577         /// <summary>
578         /// GINI值
579         /// </summary>
580         /// <param name="counts"></param>
581         /// <param name="countAll"></param>
582         /// <returns></returns>
583         public static double getGini(double[] counts, int countAll)
584         {
585             double Gini = 1;
586             for (int i = 0; i < counts.Length; i++)
587             {
588                 Gini = Gini - Math.Pow(counts[i] / countAll, 2);
589             }
590             return Gini;
591         }
592         #region CCP剪枝
593         public static void getSeries(Node node)
594         {
595             Stack<Node> nodeStack = new Stack<Node>();
596             if (node != null)
597             {
598                 nodeStack.Push(node);
599             }
600             if (node.feature_Type == "result")
601                 return;
602             List<Node> childs = node.childNodes;
603             for (int i = 0; i < childs.Count; i++)
604             {
605                 getSeries(node);
606             }
607         }
608         /// <summary>
609         /// 遍历剪枝
610         /// </summary>
611         /// <param name="node"></param>
612         public static Node getNode1(Node node, Node nodeCut)
613         {
614             
615             //List<Node> childNodes = node.getChild();
616             //double min = 100000;
617             ////Node nodeCut = new Node();
618             //double temp = 0;
619             //for (int i = 0; i < childNodes.Count; i++)
620             //{
621             //    if (childNodes[i].getType() != "result")
622             //    {
623             //        //if (!cutTree(childNodes[i]))
624             //        temp = min;
625             //        min = cutTree(childNodes[i], min);
626             //        if (min < temp)
627             //            nodeCut = childNodes[i];
628             //        getNode1(childNodes[i], nodeCut);
629             //    }
630             //}
631             //node.setChildNode(childNodes);
632             return null;
633         }
634         /// <summary>
635         /// 对每一个节点剪枝
636         /// </summary>
637         public static double cutTree(Node node, double minA)
638         {
639             int rowCount = node.rowCount;
640             double leaf = node.getErrorCount();
641             double[] values = getError1(node, 0, 0);
642             double treeWrong = values[0];
643             double son = values[1];
644             double rate = (leaf - treeWrong) / (son - 1);
645             if (minA > rate)
646                 minA = rate;
647             //double var = Math.Sqrt(treeWrong * (1 - treeWrong / 
648 
649 rowCount));
650             //double panbie = treeWrong + var - leaf;
651             //if (panbie > 0)
652             //{
653             //    node.setFeatureType("result");
654             //    node.setChildNode(null);
655             //    int result = (node.getResult() + 1);
656             //    node.setFeatures(new List<String>() { "" + result 
657 
658 });
659             //    //return true;
660             //}
661             return minA;
662         }
663         /// <summary>
664         /// 获得子树的错误个数
665         /// </summary>
666         /// <param name="node"></param>
667         /// <returns></returns>
668         public static double[] getError1(Node node, double treeError, 
669 
670 double son)
671         {
672             if (node.feature_Type == "result")
673             {
674 
675                 double error = node.getErrorCount();
676                 son++;
677                 return new double[] { treeError + error, son };
678             }
679             List<Node> childNode = node.childNodes;
680             for (int i = 0; i < childNode.Count; i++)
681             {
682                 double[] values = getError1(childNode[i], treeError, 
683 
684 son);
685                 treeError = values[0];
686                 son = values[1];
687             }
688             return new double[] { treeError, son };
689         }
690         #endregion
CART核心代码

总结:

(1)CART是一棵二叉树,每一次分裂会产生两个子节点,对于连续性的数据,直接采用与C4.5相似的处理方法,对于离散型数据,选择最优的两种离散值组合方法。

(2)CART既能是分类数,又能是二叉树。如果是分类树,将选择能够最小化分裂后节点GINI值的分裂属性;如果是回归树,选择能够最小化两个节点样本方差的分裂属性。

(3)CART跟C4.5一样,需要进行剪枝,采用CCP(代价复杂度的剪枝方法)。

你可能感兴趣的:(决策树系列(五)——CART)