The CART (Classification And Regression Tree) algorithm is a decision-tree method for classification and, as its name says, regression.
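It uses binary recursive partitioning: at each step the current sample set is split into two subsets, with the split chosen by minimizing the Gini index, so every non-leaf node of the resulting tree has exactly two branches. The decision tree produced by CART is therefore a compact binary tree.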
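To make the split criterion concrete, here is a minimal C++ sketch of two quantities. The first is the Gini index of a node, $G = 1 - \sum_k p_k^2$. The second is the size-weighted Gini of a binary split, which is the quantity minimized when a split is chosen. Both work on per-class record counts. The names gini and split_gini are illustrative and are not part of the original code.

```cpp
#include <cassert>
#include <vector>

// Gini index of a node from its per-class record counts: 1 - sum_k p_k^2.
double gini(const std::vector<int>& counts) {
    int n = 0;
    for (int c : counts) n += c;
    assert(n > 0);  // an empty node has no defined impurity
    double g = 1.0;
    for (int c : counts) {
        double p = static_cast<double>(c) / n;
        g -= p * p;
    }
    return g;
}

// Weighted Gini of a binary split: |L|/|D| * G(L) + |R|/|D| * G(R).
double split_gini(const std::vector<int>& left, const std::vector<int>& right) {
    int nl = 0, nr = 0;
    for (int c : left) nl += c;
    for (int c : right) nr += c;
    double n = nl + nr;
    return nl / n * gini(left) + nr / n * gini(right);
}
```

For example, split_gini({3, 1}, {1, 3}) evaluates to 0.375, and a split that separates the classes perfectly scores 0.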
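If the target variable is discrete, the result is a classification tree: a tree-structured model that assigns the data to discrete classes.
If the target variable is continuous, the result is a regression tree.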
Because a CART tree is binary, it does not fragment the data as much as a multiway tree does.
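Building a CART tree involves two steps:
(1) Grow the tree by recursively partitioning the space of the independent variables using the training samples.
(2) Prune the tree using validation data.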
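For a discrete variable X, try the different ways of grouping its values into a left branch and a right branch, evaluate the tree each grouping produces, and keep the best grouping. With only two values this is easy: the two values define the split directly. With more than two values it is more involved. For example, if the variable age takes the values "youth", "middle-aged", and "senior", there are three candidate groupings: {youth, middle-aged} vs {senior}, {youth, senior} vs {middle-aged}, and {middle-aged, senior} vs {youth}. The grouping that best separates the target is kept. In general, a variable with n values can be split in (2^n - 2)/2 = 2^(n-1) - 1 ways. Because CART splits are binary, when the target has more than two classes CART may also consider merging the target classes into two superclasses; this process is called twoing.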
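The enumeration of value groupings can be sketched by treating each grouping as a bitmask over the values 0..n-1. Keeping the last value always on the right means each unordered pair of groups is produced exactly once, which reproduces the (2^n - 2)/2 count. The integer encoding of the values (for example 0 = youth, 1 = middle-aged, 2 = senior) and the function name bipartitions are assumptions made for illustration.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Every way to split the values {0, ..., n-1} of a discrete attribute into two
// non-empty groups. Bit v of 'mask' puts value v in the left group; value n-1 is
// always kept on the right, so each unordered grouping appears exactly once.
// Result size: (2^n - 2) / 2 = 2^(n-1) - 1.
std::vector<std::pair<std::vector<int>, std::vector<int>>> bipartitions(int n) {
    assert(n >= 2 && n <= 30);
    std::vector<std::pair<std::vector<int>, std::vector<int>>> result;
    for (unsigned mask = 1; mask < (1u << (n - 1)); ++mask) {
        std::vector<int> left, right;
        for (int v = 0; v < n; ++v) {
            if (v < n - 1 && ((mask >> v) & 1u)) left.push_back(v);
            else right.push_back(v);
        }
        result.push_back({left, right});
    }
    return result;
}
```

bipartitions(3) returns the three groupings of the age example: {0} | {1, 2}, {1} | {0, 2} and {0, 1} | {2}.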
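For a continuous variable, first sort its values. Then take the midpoint of each pair of adjacent values as a candidate split point that divides the node into a left and a right branch. Scanning all candidates yields the best split point. Records whose value is less than the split point go to the left subtree, and the others go to the right subtree.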
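Here is a sketch of the split-point scan for one continuous feature, assuming binary class labels 0/1. Instead of re-splitting the records for every candidate, it sorts once and then moves records to the left side one at a time while updating the class counts, so the whole scan costs O(n log n). Equal adjacent values are skipped, because no midpoint can separate them. The names Split and best_threshold are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

struct Split {
    double threshold;      // x < threshold goes left, x >= threshold goes right
    double weighted_gini;  // a value > 1 means no valid split was found
};

// Gini index of a node with c0 records of class 0 and c1 records of class 1.
static double gini2(int c0, int c1) {
    double n = c0 + c1, p0 = c0 / n, p1 = c1 / n;
    return 1.0 - p0 * p0 - p1 * p1;
}

// Best split point of one continuous feature for binary labels (0/1).
Split best_threshold(const std::vector<double>& x, const std::vector<int>& y) {
    assert(x.size() == y.size() && x.size() >= 2);
    const int n = static_cast<int>(x.size());
    std::vector<int> idx(n);
    std::iota(idx.begin(), idx.end(), 0);
    std::sort(idx.begin(), idx.end(), [&](int a, int b) { return x[a] < x[b]; });

    int total1 = 0;
    for (int label : y) total1 += label;
    const int total0 = n - total1;

    Split best{0.0, 2.0};
    int left0 = 0, left1 = 0;  // class counts of the records already on the left side
    for (int k = 0; k + 1 < n; ++k) {
        if (y[idx[k]] == 1) ++left1; else ++left0;
        double a = x[idx[k]], b = x[idx[k + 1]];
        if (a == b) continue;  // no threshold separates equal values
        int nl = k + 1, nr = n - nl;
        double g = double(nl) / n * gini2(left0, left1) +
                   double(nr) / n * gini2(total0 - left0, total1 - left1);
        if (g < best.weighted_gini) best = {0.5 * (a + b), g};
    }
    return best;
}
```

For x = {1, 2, 3, 4} with labels {0, 0, 1, 1}, the scan picks the midpoint 2.5 with weighted Gini 0.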
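Can an attribute chosen for this split be chosen again further down the tree? Consider a discrete variable XD with only two values. After splitting on XD, every record in the left subset has the same XD value, and so does every record in the right subset. XD therefore carries no more information below this node and no longer needs to be considered. With three, four, or more values, the same reasoning applies to each branch separately. A child whose side of the split contains a single value of XD can drop XD. A child whose side contains two or more values can still split on it. A continuous variable XC, once discretized by its candidate split points, behaves like a discrete variable with n values, so the same analysis applies. Unless all of its remaining values are equal, XC must still be considered after it has been used for a split.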
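For regression trees, the impurity measure is the variance (the sum of squared residuals) rather than the Gini index. The idea is to minimize the variance within each group, which correspondingly maximizes the variance between the groups. This makes the two groups, i.e., the left and right branches of the split, differ as much as possible.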
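For regression, the same scan works with the sum of squared residuals in place of the Gini index. Minimizing the within-branch sum of squares is equivalent to maximizing the between-branch sum of squares, because the two always add up to the node's total sum of squares. The sketch below keeps running sums of y and y², so each candidate split point is evaluated in constant time. The names RegSplit and best_regression_threshold are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

struct RegSplit {
    bool found;        // false if all x values are equal
    double threshold;  // x < threshold goes left, x >= threshold goes right
    double sse;        // total within-branch sum of squared residuals
};

// Sum of squared residuals around the mean, from running sums of y and y^2.
static double sum_sq_residuals(double sum, double sum_sq, int n) {
    return sum_sq - sum * sum / n;
}

RegSplit best_regression_threshold(const std::vector<double>& x, const std::vector<double>& y) {
    assert(x.size() == y.size() && x.size() >= 2);
    const int n = static_cast<int>(x.size());
    std::vector<int> idx(n);
    std::iota(idx.begin(), idx.end(), 0);
    std::sort(idx.begin(), idx.end(), [&](int a, int b) { return x[a] < x[b]; });

    double total = 0.0, total_sq = 0.0;
    for (double v : y) { total += v; total_sq += v * v; }

    RegSplit best{false, 0.0, 0.0};
    double left = 0.0, left_sq = 0.0;  // running sums over the records moved to the left
    for (int k = 0; k + 1 < n; ++k) {
        double v = y[idx[k]];
        left += v;
        left_sq += v * v;
        double a = x[idx[k]], b = x[idx[k + 1]];
        if (a == b) continue;  // equal x values cannot be separated
        int nl = k + 1, nr = n - nl;
        double s = sum_sq_residuals(left, left_sq, nl) +
                   sum_sq_residuals(total - left, total_sq - left_sq, nr);
        if (!best.found || s < best.sse) best = {true, 0.5 * (a + b), s};
    }
    return best;
}
```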
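Using these impurity measures, evaluate every candidate split point or value grouping of each variable to find that variable's best one. Then compare the best splits across all variables to obtain the best variable together with its best split point or grouping.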
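Growing the tree is a recursive process that continues until one of these stopping conditions is met:
(1) The node is pure, i.e., all of its records have the same target value.
(2) The depth of the tree reaches a preset maximum.
(3) The largest achievable decrease in impurity is smaller than a preset value.
(4) The number of records in the node is below a preset minimum node size.
(5) All records in the node have identical predictor values.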
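Intuitively, splitting can stop when all records in a node belong to the same class. That is only a special case. More generally, one computes the χ² statistic to measure how strongly the split condition and the class are associated. A very small χ² means that the split condition and the class are independent, so splitting on that condition is pointless and the node stops splitting. Here the "split condition" is the one selected by the Gini criterion, i.e., the split with the smallest weighted Gini index (the largest Gini gain).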
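Under stopping condition (3), a branch stops growing once the largest impurity decrease falls below the preset value, and the tree model is complete when every branch has stopped. You can also skip this condition and grow the tree to its maximum size, because CART has a pruning step.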
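Every node is given a class label. The labeling rule uses misclassification costs and class prior probabilities, which are parameters that must be set in advance. This matters for unbalanced data. If the training set has, say, 100 positive samples and only 1 negative sample, the data is unbalanced, and a node cannot simply be labeled by majority vote. Labeling nodes with costs and priors makes the method reasonably robust to such unbalanced data.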
Note that every tree node, leaf or not, must be given a label, because pruning may later turn a non-leaf node into a leaf.
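Once the tree is fully grown, it is pruned, and this step is very important. Pruning keeps the decision tree from overfitting the training samples. On new data, an overfitted tree usually has a higher error rate than a simplified one.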
This part is based on http://blog.csdn.net/u010159842/article/details/46458973
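Cost-Complexity Pruning (CCP)
The CCP method consists of two steps:
1: Starting from the original decision tree $T_0$, generate a sequence of subtrees $\{T_0, T_1, T_2, \dots, T_n\}$, where $T_{i+1}$ is obtained from $T_i$ and $T_n$ consists of the root node alone.
2: From this sequence, select the best decision tree based on an estimate of each tree's true error.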
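For every non-leaf node t of the tree, compute its surface error rate gain α:
$$\alpha = \frac{R(t) - R(T_t)}{|N_{T_t}| - 1}$$
where
$|N_{T_t}|$ is the number of leaves in the subtree $T_t$ rooted at t;
$R(t) = r(t)\,p(t)$ is the error cost of node t if it is pruned (i.e., turned into a leaf);
r(t) is the error rate of node t;
p(t) is the fraction of all records that fall into node t;
$R(T_t)$ is the error cost of the subtree $T_t$ if t is not pruned; it equals the sum of the error costs of all leaves of $T_t$.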
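For example, consider the non-leaf node t4 shown in the figure:
[Figure: non-leaf node t4 and its subtree, which has 3 leaves]
There are 60 records in total, so the error cost of node t4 is
$$R(t_4) = r(t_4)\,p(t_4) = \frac{e(t_4)}{n(t_4)} \cdot \frac{n(t_4)}{60} = \frac{e(t_4)}{60},$$
where $n(t_4)$ is the number of records in t4 and $e(t_4)$ is how many of them t4's label misclassifies. The error cost of the subtree is the sum over its leaves:
$$R(T_{t_4}) = \sum_{i \in \text{leaves}(T_{t_4})} R(i) = \frac{1}{60} \sum_{i \in \text{leaves}(T_{t_4})} e(i).$$
The subtree rooted at t4 has 3 leaves, so finally
$$\alpha(t_4) = \frac{R(t_4) - R(T_{t_4})}{3 - 1}.$$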
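Because every cost shares the denominator N (here 60), the α computation can be written as a small function of error counts. The counts in the example call are made up for illustration and are not taken from the figure. The name ccp_alpha is hypothetical.

```cpp
#include <cassert>
#include <vector>

// Cost-complexity alpha of a non-leaf node t:
//   alpha = (R(t) - R(T_t)) / (|N_Tt| - 1),
// where each error cost is (misclassified records) / (all records).
double ccp_alpha(int node_errors, const std::vector<int>& leaf_errors, int total_records) {
    assert(leaf_errors.size() >= 2 && total_records > 0);
    int subtree_errors = 0;
    for (int e : leaf_errors) subtree_errors += e;
    double R_t = double(node_errors) / total_records;      // R(t) = r(t) * p(t)
    double R_Tt = double(subtree_errors) / total_records;  // R(T_t): sum of leaf costs
    return (R_t - R_Tt) / (leaf_errors.size() - 1);
}
```

Suppose, hypothetically, that t4 misclassifies 7 records and its three leaves misclassify 2, 0 and 3. Then ccp_alpha(7, {2, 0, 3}, 60) gives (7/60 - 5/60)/2 = 1/60.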
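Find the non-leaf node with the smallest α and set its left and right children to NULL, turning it into a leaf. When several non-leaf nodes share the smallest α, prune the one with the largest $|N_{T_t}|$ (the one whose subtree has the most leaves).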
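Step 1 of CCP, the weakest-link loop, can be sketched on a toy tree. As a simplification, each node stores only the number of training records it would misclassify as a leaf; the C++ code below derives this from label counts instead. α values are compared as exact fractions by cross-multiplying, so ties are detected reliably and the "more leaves wins" tie-break can actually apply. Node, prune_sequence and the helper functions are hypothetical names.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Toy tree node: 'errors' is how many training records node t would
// misclassify if it were turned into a leaf.
struct Node {
    int errors;
    std::unique_ptr<Node> left, right;
    explicit Node(int e) : errors(e) {}
    bool is_leaf() const { return !left; }
};

// Leaf count and summed leaf errors of the subtree rooted at t.
static void subtree_stats(const Node* t, int& leaves, int& leaf_errors) {
    if (t->is_leaf()) { ++leaves; leaf_errors += t->errors; return; }
    subtree_stats(t->left.get(), leaves, leaf_errors);
    subtree_stats(t->right.get(), leaves, leaf_errors);
}

static void internal_nodes(Node* t, std::vector<Node*>& out) {
    if (t->is_leaf()) return;
    out.push_back(t);
    internal_nodes(t->left.get(), out);
    internal_nodes(t->right.get(), out);
}

// Weakest-link pruning: repeatedly turn the non-leaf node with the smallest
// alpha into a leaf (ties go to the node with more leaves) until only the root
// is left. 'total' is the number of training records. Returns each step's alpha.
std::vector<double> prune_sequence(Node* root, int total) {
    assert(total > 0);
    std::vector<double> alphas;
    while (!root->is_leaf()) {
        std::vector<Node*> nodes;
        internal_nodes(root, nodes);
        Node* cut = nullptr;
        int best_num = 0, best_den = 1, best_leaves = 0;
        for (Node* t : nodes) {
            int leaves = 0, leaf_errors = 0;
            subtree_stats(t, leaves, leaf_errors);
            // alpha * total = num / den with den >= 1: compare by cross-multiplying
            int num = t->errors - leaf_errors, den = leaves - 1;
            long long lhs = 1LL * num * best_den, rhs = 1LL * best_num * den;
            if (!cut || lhs < rhs || (lhs == rhs && leaves > best_leaves)) {
                cut = t;
                best_num = num;
                best_den = den;
                best_leaves = leaves;
            }
        }
        alphas.push_back(double(best_num) / best_den / total);
        cut->left.reset();  // frees the whole subtree below 'cut'
        cut->right.reset();
    }
    return alphas;
}
```

Consider a tree trained on 40 records. Its root would make 10 errors as a leaf. The root's left child (4 errors) has two leaves with 1 error each, and its right child is a leaf with 3 errors. The left child is pruned first with α = 2/40, then the root with α = 3/40.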
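The selection in step 2 can be illustrated with a figure:
[Figure: error estimates of the subtrees in the pruning sequence, with a horizontal line drawn between 29 and 30]
The trees represented by the points below the horizontal line between 29 and 30 all satisfy the one-standard-error (1-SE) condition
$$E(T_i) \le E_{\min} + \mathrm{SE}, \qquad \mathrm{SE} = \sqrt{\frac{E_{\min}\,(N - E_{\min})}{N}},$$
where $E(T_i)$ is the number of validation errors of subtree $T_i$, $E_{\min}$ is the smallest of these, and N is the number of validation records. Among these trees, the one the arrow points to has the fewest leaves, so it is chosen as the best tree.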
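Step 2 under the 1-SE rule can be sketched as follows. Find the minimum validation error and compute its standard error $\sqrt{E_{\min}(N - E_{\min})/N}$, which is the binomial standard deviation of an error count. Then return the smallest tree whose error stays within that bound. This mirrors the selection at the end of CL_trim in the C++ code below. The name one_se_choice is hypothetical, and the error counts in the example are invented.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// errors[i] = misclassified validation records of subtree T_i, where T_0 is the
// fully grown tree and the trees get smaller as i grows. Returns the index of the
// smallest tree whose error is within one standard error of the minimum.
size_t one_se_choice(const std::vector<int>& errors, int n_validation) {
    assert(!errors.empty() && n_validation > 0);
    int min_err = errors[0];
    for (int e : errors)
        if (e < min_err) min_err = e;
    double se = std::sqrt(double(min_err) * (n_validation - min_err) / n_validation);
    for (size_t i = errors.size(); i-- > 0;)  // scan from the smallest tree
        if (errors[i] <= min_err + se) return i;
    return 0;  // not reached: the minimum itself always qualifies
}
```

With errors {30, 25, 26, 29, 40} on 100 validation records, the minimum is 25 and SE = √18.75 ≈ 4.33. The smallest tree with at most 29.33 errors is T_3, so it is chosen.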
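Some samples may lack values for certain attributes. A common way to handle a missing value is to fill in the attribute's most common value or its mean. A better approach is to assign a probability to each possible value, i.e., to fill in the attribute probabilistically. For example, consider a Boolean attribute B where the samples contain 12 instances with B=0 and 88 with B=1. A missing value of B is then assigned B(0)=0.12 and B(1)=0.88, so the sample goes to the False branch with probability 12% and to the True branch with probability 88%. The purpose is to let samples with missing values still take part in evaluating splits (for example, in computing the information gain).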
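Here is a minimal sketch of the probabilistic assignment. A record of weight w whose value is missing is sent down every branch v with weight w times the observed frequency of value v. Those fractional weights then stand in for integer counts when the impurity is computed. The name missing_value_weights is hypothetical.

```cpp
#include <cassert>
#include <vector>

// counts[v] = number of records whose attribute value v is known.
// Returns the weight that a record of weight w with a missing value
// receives on each branch v: w * counts[v] / (sum of counts).
std::vector<double> missing_value_weights(const std::vector<int>& counts, double w) {
    int known = 0;
    for (int c : counts) known += c;
    assert(known > 0);
    std::vector<double> weights;
    for (int c : counts) weights.push_back(w * c / known);
    return weights;
}
```

missing_value_weights({12, 88}, 1.0) returns weights of 0.12 and 0.88, matching the Boolean example.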
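To make a prediction, a record is routed down the tree to a leaf, and the prediction is:
(1) Classification tree: the class with the highest probability in that leaf.
(2) Regression tree: the mean or the median of the target values in that leaf.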
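The three leaf rules can be sketched as below. The sketch assumes that a leaf stores either its class counts or the training targets that reached it; the function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Classification leaf: index of the class with the largest count, i.e. the
// highest estimated probability (the first such class wins on ties).
int leaf_class(const std::vector<int>& class_counts) {
    assert(!class_counts.empty());
    return static_cast<int>(std::max_element(class_counts.begin(), class_counts.end()) -
                            class_counts.begin());
}

// Regression leaf: mean of the target values that reached the leaf.
double leaf_mean(const std::vector<double>& y) {
    assert(!y.empty());
    double sum = 0.0;
    for (double v : y) sum += v;
    return sum / y.size();
}

// Regression leaf, less sensitive to outliers: median of the target values
// (average of the two middle values for an even count). y is taken by value
// so that sorting does not modify the caller's data.
double leaf_median(std::vector<double> y) {
    assert(!y.empty());
    std::sort(y.begin(), y.end());
    const size_t m = y.size() / 2;
    return y.size() % 2 == 1 ? y[m] : 0.5 * (y[m - 1] + y[m]);
}
```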
The following C++ implementation covers the classification part: Gini-based splits, then cost-complexity pruning with the 1-SE rule. It is tested on the UCI Adult dataset and expects adult.data in the working directory:

```cpp
// cart.cpp : CART classification tree (Gini splits, cost-complexity pruning)
// trained and tested on the UCI Adult dataset (adult.data).
/*******************************************/
/************author Marshall****************/
/**********date 2015.10.3*******************/
/**************version 1.0******************/
/************copyright reserved*************/
/*******************************************/
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>
using namespace std;

class cart
{
private:
    vector<int> nums_of_value_each_discreteAttri;  // number of values of each discrete attribute
    int num_of_continuousAttri;
    int ContinuousAttriNums;
    int labelNums;                 // how many kinds of label
    unsigned int CL_max_height;
    //double miniumginigain;       // not needed, the tree is pruned afterwards

    // one record of the dataset
    class Record
    {
    public:
        vector<int> discrete_attri;       // each discrete attribute takes the values 0,1,2,...
        vector<double> continuous_attti;
        int label;                        // 0,1,2...
    };

    // one tree node
    struct CartNode
    {
        vector<int> remianDiscreteAttriID;        // discrete attributes still usable below this node
        int selectedAttriID;
        vector<int> selectedDiscreteAttriValues;  // values sent to the left child
        bool isSelectedAttriIDDiscrete;
        double continuousAttriPartitionValue;     // x < value goes left, x >= value goes right
        int label;                 // label predicted for a record that stops at this node
        int height;                // current node's height
        vector<int> labelcount;    // number of records of each label held by this node
        double alpha;              // for non-leaf nodes, used when pruning
        int record_number;         // number of records covered by this node
        CartNode *lnode, *rnode;
        CartNode()
            : selectedAttriID(-1), isSelectedAttriIDDiscrete(true), continuousAttriPartitionValue(0),
              label(-1), height(0), alpha(0), record_number(0), lnode(NULL), rnode(NULL) {}
    };
    CartNode* root;

    // Gini index, for classification; also fills node's label counts if node != NULL
    double calGiniIndex(const vector<int>& subdatasetbyID, const vector<Record>* dataset, CartNode* node = NULL);
    // regression tree: declared but not implemented
    double calSquaredresiduals();
    void RE_split_dataset();
    void RE_trim();
    void build_tree_regression();
    void CL_split_dataset();
    void CL_trim(const vector<Record>* validationdataset);
    // the label shared by all records of the subset, or -1 if they differ
    int allthesame(const vector<int>& subdatasetbyID, const vector<Record>* dataset);
    // binary partitions of a discrete attribute's values:
    // 3 values give 3 partitions, 4 values give 7, 5 values give 15
    vector<pair<vector<int>, vector<int> > > make_two_heap(const int kk);
    pair<vector<int>, vector<int> > split_dataset(const int& selectedDiscreteAttriID, const vector<int>& selected,
        const vector<int>& subdatasetbyID, const vector<Record>* dataset);
    pair<vector<int>, vector<int> > split_dataset(const int& selectedContiuousAttriID, const double partition,
        const vector<int>& subdatasetbyID, const vector<Record>* dataset);
    CartNode* copytree(CartNode* src);   // deep copy of a tree
    void cal_alpha(CartNode* node);
    vector<CartNode*> getLeaf(CartNode* node);
    void destroyTree(CartNode* node);
    int labelNode(CartNode* node);
    void create_root();
    void build_tree_classify(vector<int>& subdatasetbyID, CartNode* node, const vector<Record>* dataset);

public:
    cart()
        : num_of_continuousAttri(0), ContinuousAttriNums(0), labelNums(0), CL_max_height(0),
          root(NULL), traindataset(NULL), validatedataset(NULL), testdataset(NULL) {}
    cart(const cart&) = delete;             // the class owns raw pointers
    cart& operator=(const cart&) = delete;
    ~cart()
    {
        destroyTree(root);
        delete traindataset;                // deleting NULL is a no-op
        delete validatedataset;
        delete testdataset;
    }
    void load_adult_dataset();
    int CART_classify(const Record& rd, CartNode* treeroot = NULL);
    void CART_regression();                 // not implemented
    void CART_trian(const vector<Record>* dataset, const vector<Record>* validationdataset);
    void CART_trian() { CART_trian(traindataset, validatedataset); }
    void set_paras();
    void test(CartNode* node);
    void test();
    vector<Record>* traindataset;
    vector<Record>* validatedataset;
    vector<Record>* testdataset;
};

void cart::test(CartNode* node)
{
    int errorNum = 0;
    for (size_t j = 0; j < testdataset->size(); j++)
        errorNum += CART_classify((*testdataset)[j], node) == (*testdataset)[j].label ? 0 : 1;
    cout << "Error rate on the test set: " << double(errorNum) / testdataset->size() << endl;
}

void cart::test() { test(root); }

void cart::set_paras() { CL_max_height = 6; }

void cart::CART_trian(const vector<Record>* dataset, const vector<Record>* validationdataset)
{
    create_root();
    set_paras();
    vector<int> subset;
    for (size_t i = 0; i < dataset->size(); i++)
        subset.push_back(int(i));
    build_tree_classify(subset, root, dataset);
    CL_trim(validationdataset);
}

void cart::destroyTree(CartNode* treeroot)
{
    if (treeroot == NULL) return;
    vector<CartNode*> que(1, treeroot);
    while (!que.empty())
    {
        CartNode* node = que.back();
        que.pop_back();
        if (node->lnode != NULL)   // queue the children before freeing the parent
        {
            assert(node->rnode != NULL);
            que.push_back(node->lnode);
            que.push_back(node->rnode);
        }
        delete node;
    }
}

// deep copy of a tree; the recursion depth is bounded by CL_max_height
cart::CartNode* cart::copytree(CartNode* src)
{
    assert(src != NULL);
    CartNode* dst = new CartNode(*src);   // copies every field...
    dst->lnode = dst->rnode = NULL;       // ...except the children, which are copied below
    if (src->lnode != NULL)
    {
        dst->lnode = copytree(src->lnode);
        dst->rnode = copytree(src->rnode);
    }
    return dst;
}

int cart::CART_classify(const Record& rd, CartNode* treeroot)
{
    CartNode* node = (treeroot == NULL) ? root : treeroot;
    while (node->lnode != NULL)
    {
        assert(node->rnode != NULL);
        if (node->isSelectedAttriIDDiscrete)
        {
            const vector<int>& left = node->selectedDiscreteAttriValues;
            bool goLeft = find(left.begin(), left.end(), rd.discrete_attri[node->selectedAttriID]) != left.end();
            node = goLeft ? node->lnode : node->rnode;
        }
        else
        {
            bool goRight = rd.continuous_attti[node->selectedAttriID] >= node->continuousAttriPartitionValue;
            node = goRight ? node->rnode : node->lnode;
        }
    }
    return node->label;
}

// Cost-complexity pruning: build the sequence T0, T1, ..., Tn (Tn = root only) by
// repeatedly cutting the non-leaf node with the smallest alpha, then choose one
// tree with the 1-SE rule on the validation set.
void cart::CL_trim(const vector<Record>* validationdataset)
{
    assert(!validationdataset->empty());
    vector<CartNode*> candidateBestTree(1, root);   // T0: the fully grown tree
    CartNode* current = root;
    while (current->lnode != NULL)
    {
        CartNode* next = copytree(current);          // T(i+1) is cut from a copy of T(i)
        vector<CartNode*> pool(1, next);
        double min_alpha = 1e300;
        CartNode* tobecut = NULL;
        while (!pool.empty())
        {
            CartNode* node = pool.back();
            pool.pop_back();
            if (node->lnode == NULL) continue;      // leaves have no alpha
            cal_alpha(node);
            if (node->alpha < min_alpha)
            {
                min_alpha = node->alpha;
                tobecut = node;
            }
            pool.push_back(node->rnode);
            pool.push_back(node->lnode);
        }
        assert(tobecut != NULL);
        // make tobecut a leaf by deleting everything below it
        destroyTree(tobecut->lnode);
        destroyTree(tobecut->rnode);
        tobecut->lnode = tobecut->rnode = NULL;
        candidateBestTree.push_back(next);
        current = next;
    }

    // validation errors of every candidate tree
    vector<int> candidateBestTreeErrorNums;
    for (size_t i = 0; i < candidateBestTree.size(); i++)
    {
        int errorNum = 0;
        for (size_t j = 0; j < validationdataset->size(); j++)
            errorNum += CART_classify((*validationdataset)[j], candidateBestTree[i]) == (*validationdataset)[j].label ? 0 : 1;
        candidateBestTreeErrorNums.push_back(errorNum);
    }
    size_t th = size_t(min_element(candidateBestTreeErrorNums.begin(), candidateBestTreeErrorNums.end())
                       - candidateBestTreeErrorNums.begin());
    int minError = candidateBestTreeErrorNums[th];
    test(candidateBestTree[th]);   // test error of the minimum-validation-error tree, for comparison

    // 1-SE rule: the smallest tree whose error is within one standard error of the minimum
    double N = double(validationdataset->size());
    double SE = sqrt(minError * (N - minError) / N);
    for (size_t i = candidateBestTree.size(); i-- > 0;)
    {
        if (candidateBestTreeErrorNums[i] <= minError + SE)
        {
            th = i;
            break;
        }
    }
    CartNode* besttree = candidateBestTree[th];
    for (size_t i = 0; i < candidateBestTree.size(); i++)
        if (i != th) destroyTree(candidateBestTree[i]);
    root = besttree;
    cout << "Error rate on the validation set after pruning: " << candidateBestTreeErrorNums[th] / N << endl;
}

// alpha(t) = (R(t) - R(T_t)) / (|N_Tt| - 1), where every error cost is
// (misclassified training records) / (all training records)
void cart::cal_alpha(CartNode* node)
{
    assert(node->lnode != NULL && node->rnode != NULL);
    double N = root->record_number;
    // R(t): errors if t were a leaf predicting its own label
    double Rt = (node->record_number - node->labelcount[node->label]) / N;
    // R(T_t): sum of the error costs of the leaves under t
    double RTt = 0;
    vector<CartNode*> leafpool = getLeaf(node);
    for (size_t i = 0; i < leafpool.size(); i++)
        RTt += (leafpool[i]->record_number - leafpool[i]->labelcount[leafpool[i]->label]) / N;
    node->alpha = (Rt - RTt) / (leafpool.size() - 1);
}

vector<cart::CartNode*> cart::getLeaf(CartNode* node)
{
    vector<CartNode*> leafpool, que(1, node);
    while (!que.empty())
    {
        CartNode* nn = que.back();
        que.pop_back();
        if (nn->lnode == NULL)
        {
            assert(nn->rnode == NULL);
            leafpool.push_back(nn);
        }
        else
        {
            que.push_back(nn->lnode);
            que.push_back(nn->rnode);
        }
    }
    return leafpool;
}

// records whose value is in 'selected' go to the first (left) subset
pair<vector<int>, vector<int> > cart::split_dataset(const int& selectedDiscreteAttriID, const vector<int>& selected,
    const vector<int>& subdatasetbyID, const vector<Record>* dataset)
{
    vector<int> aa, bb;
    for (size_t i = 0; i < subdatasetbyID.size(); i++)
    {
        int v = (*dataset)[subdatasetbyID[i]].discrete_attri[selectedDiscreteAttriID];
        if (find(selected.begin(), selected.end(), v) != selected.end())
            aa.push_back(subdatasetbyID[i]);
        else
            bb.push_back(subdatasetbyID[i]);
    }
    return make_pair(aa, bb);
}

// records with value < partition go to the first (left) subset
pair<vector<int>, vector<int> > cart::split_dataset(const int& selectedContiuousAttriID, const double partition,
    const vector<int>& subdatasetbyID, const vector<Record>* dataset)
{
    vector<int> aa, bb;
    for (size_t i = 0; i < subdatasetbyID.size(); i++)
    {
        if ((*dataset)[subdatasetbyID[i]].continuous_attti[selectedContiuousAttriID] >= partition)
            bb.push_back(subdatasetbyID[i]);
        else
            aa.push_back(subdatasetbyID[i]);
    }
    return make_pair(aa, bb);
}

// All binary partitions of the values {0..len-1} of discrete attribute kk. Bit v of
// mask puts value v in the first group; the last value always stays in the second
// group, so each partition appears once and there are 2^(len-1) - 1 of them.
vector<pair<vector<int>, vector<int> > > cart::make_two_heap(const int kk)
{
    vector<pair<vector<int>, vector<int> > > toret;
    int len = nums_of_value_each_discreteAttri[kk];
    for (unsigned mask = 1; mask < (1u << (len - 1)); mask++)
    {
        vector<int> aa, bb;
        for (int v = 0; v < len; v++)
        {
            if (v < len - 1 && ((mask >> v) & 1u)) aa.push_back(v);
            else bb.push_back(v);
        }
        toret.push_back(make_pair(aa, bb));
    }
    return toret;
}

void cart::create_root()
{
    if (root == NULL)
    {
        root = new CartNode;
        for (size_t i = 0; i < nums_of_value_each_discreteAttri.size(); i++)
            root->remianDiscreteAttriID.push_back(int(i));
        root->height = 1;
    }
}

int cart::allthesame(const vector<int>& subdatasetbyID, const vector<Record>* dataset)
{
    int label = (*dataset)[subdatasetbyID[0]].label;
    for (size_t i = 1; i < subdatasetbyID.size(); i++)
        if ((*dataset)[subdatasetbyID[i]].label != label) return -1;
    return label;
}

// grow a classification tree recursively
void cart::build_tree_classify(vector<int>& subdatasetbyID, CartNode* node, const vector<Record>* dataset)
{
    calGiniIndex(subdatasetbyID, dataset, node);   // fills node->labelcount and node->record_number
    int currentlabel = allthesame(subdatasetbyID, dataset);
    if (currentlabel >= 0)                         // pure node
    {
        node->label = currentlabel;
        return;
    }
    node->label = labelNode(node);   // every node gets a label, since pruning may make it a leaf
    if (node->height >= int(CL_max_height))
        return;

    double mingini = 1e300;
    int selected = -1;
    bool isSelectedDiscrete = true;
    double partitionpoint = 0;
    vector<int> selectedDiscreteAttriValues;
    pair<vector<int>, vector<int> > splited_subdataset;
    bool lnodeDecreaseDiscreteAttri = false;   // left child can no longer split on the selected attribute
    bool rnodeDecreaseDiscreteAttri = false;
    const double n = double(subdatasetbyID.size());

    // discrete attributes: try every binary partition of the values
    for (size_t i = 0; i < node->remianDiscreteAttriID.size(); i++)
    {
        int attr = node->remianDiscreteAttriID[i];
        vector<pair<vector<int>, vector<int> > > bipart = make_two_heap(attr);
        for (size_t j = 0; j < bipart.size(); j++)
        {
            pair<vector<int>, vector<int> > two = split_dataset(attr, bipart[j].first, subdatasetbyID, dataset);
            if (two.first.empty() || two.second.empty()) continue;
            // weighted Gini index of the split
            double gini = two.first.size() / n * calGiniIndex(two.first, dataset)
                        + two.second.size() / n * calGiniIndex(two.second, dataset);
            if (gini < mingini)
            {
                mingini = gini;
                selected = attr;
                isSelectedDiscrete = true;
                splited_subdataset = two;
                selectedDiscreteAttriValues = bipart[j].first;
                lnodeDecreaseDiscreteAttri = bipart[j].first.size() == 1;
                rnodeDecreaseDiscreteAttri = bipart[j].second.size() == 1;
            }
        }
    }

    // continuous attributes: candidate split points are midpoints of adjacent sorted values
    for (int i = 0; i < ContinuousAttriNums; i++)
    {
        sort(subdatasetbyID.begin(), subdatasetbyID.end(), [&](int a, int b) {
            return (*dataset)[a].continuous_attti[i] < (*dataset)[b].continuous_attti[i];
        });
        for (size_t j = 0; j + 1 < subdatasetbyID.size(); j++)
        {
            double lo = (*dataset)[subdatasetbyID[j]].continuous_attti[i];
            double hi = (*dataset)[subdatasetbyID[j + 1]].continuous_attti[i];
            if (lo == hi) continue;           // equal values cannot be separated
            double partition = 0.5 * (lo + hi);
            pair<vector<int>, vector<int> > two = split_dataset(i, partition, subdatasetbyID, dataset);
            if (two.first.empty() || two.second.empty()) continue;
            double gini = two.first.size() / n * calGiniIndex(two.first, dataset)
                        + two.second.size() / n * calGiniIndex(two.second, dataset);
            if (gini < mingini)
            {
                mingini = gini;
                selected = i;
                isSelectedDiscrete = false;
                partitionpoint = partition;
                splited_subdataset = two;
            }
        }
    }

    // no Gini-gain threshold: the tree is pruned afterwards
    if (selected < 0)   // no split separates the records
        return;
    CartNode* lchild = new CartNode;
    CartNode* rchild = new CartNode;
    node->lnode = lchild;
    node->rnode = rchild;
    lchild->height = rchild->height = node->height + 1;
    lchild->remianDiscreteAttriID = rchild->remianDiscreteAttriID = node->remianDiscreteAttriID;
    node->selectedAttriID = selected;
    if (isSelectedDiscrete)
    {
        vector<int>& l = lchild->remianDiscreteAttriID;
        vector<int>& r = rchild->remianDiscreteAttriID;
        if (lnodeDecreaseDiscreteAttri) l.erase(find(l.begin(), l.end(), selected));
        if (rnodeDecreaseDiscreteAttri) r.erase(find(r.begin(), r.end(), selected));
        node->selectedDiscreteAttriValues = selectedDiscreteAttriValues;
    }
    else
    {
        node->isSelectedAttriIDDiscrete = false;
        node->continuousAttriPartitionValue = partitionpoint;
    }
    build_tree_classify(splited_subdataset.first, lchild, dataset);
    build_tree_classify(splited_subdataset.second, rchild, dataset);
}

double cart::calGiniIndex(const vector<int>& subdatasetbyID, const vector<Record>* dataset, CartNode* node)
{
    assert(!subdatasetbyID.empty() && dataset != NULL);
    vector<int> count(labelNums, 0);
    for (size_t i = 0; i < subdatasetbyID.size(); i++)
        count[(*dataset)[subdatasetbyID[i]].label]++;
    if (node != NULL)
    {
        node->labelcount = count;
        node->record_number = int(subdatasetbyID.size());
    }
    double re = 1;                   // 1 - sum_k p_k^2
    for (int i = 0; i < labelNums; i++)
    {
        double p = double(count[i]) / subdatasetbyID.size();
        re -= p * p;
    }
    return re;
}

// label = argmax_i (share of class i in this node) / (share of class i in the
// whole training set), which favours minority classes on unbalanced data
int cart::labelNode(CartNode* node)
{
    int label = -1;
    double maxpro = 0;
    for (int i = 0; i < labelNums; i++)
    {
        if (root->labelcount[i] == 0) continue;
        double temppro = double(node->labelcount[i]) / node->record_number;
        temppro /= double(root->labelcount[i]) / root->record_number;
        if (temppro > maxpro)
        {
            maxpro = temppro;
            label = i;
        }
    }
    assert(label >= 0);
    return label;
}

// split str on sep, dropping empty fields
int split(const string& str, vector<string>& ret_, string sep = ",")
{
    if (str.empty()) return 0;
    string::size_type pos_begin = str.find_first_not_of(sep);
    while (pos_begin != string::npos)
    {
        string::size_type comma_pos = str.find(sep, pos_begin);
        string tmp;
        if (comma_pos != string::npos)
        {
            tmp = str.substr(pos_begin, comma_pos - pos_begin);
            pos_begin = comma_pos + sep.length();
        }
        else
        {
            tmp = str.substr(pos_begin);
            pos_begin = comma_pos;
        }
        if (!tmp.empty()) ret_.push_back(tmp);
    }
    return 0;
}

// Note: education, workclass, marital-status, occupation and native-country
// are not used because they have too many values.
// Lines 0-3499 go to training, 3500-4499 to validation, 4500-6999 to testing;
// lines with unknown values are dropped.
void cart::load_adult_dataset()
{
    ifstream infile("adult.data");
    if (!infile)
    {
        cerr << "cannot open adult.data" << endl;
        exit(1);
    }
    traindataset = new vector<Record>;
    validatedataset = new vector<Record>;
    testdataset = new vector<Record>;

    map<string, int> relationship;   // Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried
    relationship["Wife"] = 0; relationship["Own-child"] = 1; relationship["Husband"] = 2;
    relationship["Not-in-family"] = 3; relationship["Other-relative"] = 4; relationship["Unmarried"] = 5;
    map<string, int> race;           // White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black
    race["White"] = 0; race["Asian-Pac-Islander"] = 1; race["Amer-Indian-Eskimo"] = 2;
    race["Other"] = 3; race["Black"] = 4;
    map<string, int> sex;            // Female, Male
    sex["Female"] = 0; sex["Male"] = 1;
    map<string, int> label;
    label["<=50K"] = 0; label[">50K"] = 1;

    // columns: 0 age, 1 workclass, 2 fnlwgt, 3 education, 4 education-num, 5 marital-status,
    // 6 occupation, 7 relationship, 8 race, 9 sex, 10 capital-gain, 11 capital-loss,
    // 12 hours-per-week, 13 native-country, 14 income
    string temp;
    int count = 0;
    while (getline(infile, temp) && count < 7000)
    {
        if (!temp.empty() && temp.back() == '\r') temp.pop_back();   // Windows line endings
        vector<string> re;
        split(temp, re, ", ");
        if (re.size() == 15)
        {
            Record rd;
            rd.continuous_attti.resize(6);
            rd.discrete_attri.resize(3);
            bool desert = false;
            rd.continuous_attti[0] = atoi(re[0].c_str());    // age
            rd.continuous_attti[1] = atoi(re[2].c_str());    // fnlwgt
            rd.continuous_attti[2] = atoi(re[4].c_str());    // education-num
            rd.continuous_attti[3] = atoi(re[10].c_str());   // capital-gain
            rd.continuous_attti[4] = atoi(re[11].c_str());   // capital-loss
            rd.continuous_attti[5] = atoi(re[12].c_str());   // hours-per-week
            if (relationship.count(re[7])) rd.discrete_attri[0] = relationship[re[7]]; else desert = true;
            if (race.count(re[8])) rd.discrete_attri[1] = race[re[8]]; else desert = true;
            if (sex.count(re[9])) rd.discrete_attri[2] = sex[re[9]]; else desert = true;
            if (label.count(re[14])) rd.label = label[re[14]]; else desert = true;
            if (!desert)
            {
                if (count < 3500) traindataset->push_back(rd);
                else if (count < 4500) validatedataset->push_back(rd);
                else testdataset->push_back(rd);
            }
        }
        count++;
    }
    ContinuousAttriNums = 6;
    labelNums = 2;
    nums_of_value_each_discreteAttri.push_back(6);   // relationship
    nums_of_value_each_discreteAttri.push_back(5);   // race
    nums_of_value_each_discreteAttri.push_back(2);   // sex
}

int main()
{
    cart cart;
    cart.load_adult_dataset();
    cart.CART_trian();
    cart.test();
    return 0;
}
```