CART

一、为什么有CART回归树

  以前学过全局回归,顾名思义,就是指全部数据符合某种曲线。比如线性回归,多项式拟合(泰勒)等等。可是这些数学规律多强,硬硬地将全部数据逼近一些特殊的曲线。生活中的数据可是千变万化。那么,局部回归是一种合理地选择。在斯坦福大学NG的公开课中,他也提到局部回归的好处。其中,CART回归树就是局部回归的一种。

二、CART回归树的算法流程

CART_第1张图片

  注意到,(1)中两步优化,即选择最优切分变量和切分点。(i)如果给定x的切分点。那么可以马上求得中括号内的最优。(ii)对于切分点怎么确定,这里是用遍历的方法。

三、CART分类树

  实际上,CART分类树的生成树和ID3方法类似,只是这里用基尼指数代替了信息增益,定义

CART_第2张图片

CART_第3张图片

CART_第4张图片

CART_第5张图片

CART_第6张图片

四、CART剪枝算法流程

CART_第7张图片

 例子参考:http://www.cnblogs.com/zhangchaoyang/articles/2709922.html

比如:

当分类回归树划分得太细时,会对噪声数据产生过拟合作用。因此我们要通过剪枝来解决。剪枝又分为前剪枝和后剪枝:前剪枝是指在构造树的过程中就知道哪些节点可以剪掉,于是干脆不对这些节点进行分裂,在N皇后问题和背包问题中用的都是前剪枝,上面的χ2方法也可以认为是一种前剪枝;后剪枝是指构造出完整的决策树之后再来考查哪些子树可以剪掉。

在分类回归树中可以使用的后剪枝方法有多种,比如:代价复杂性剪枝、最小误差剪枝、悲观误差剪枝等等。这里我们只介绍代价复杂性剪枝法。

对于分类回归树中的每一个非叶子节点计算它的表面误差率增益值α。

是子树中包含的叶子节点个数;

是节点t的误差代价,如果该节点被剪枝;

r(t)是节点t的误差率;

p(t)是节点t上的数据占所有数据的比例。

是子树Tt的误差代价,如果该节点不被剪枝。它等于子树Tt上所有叶子节点的误差代价之和。

比如有个非叶子节点t4如图所示:CART_第8张图片

 

已知所有的数据总共有60条,则节点t4的节点误差代价为:

子树误差代价为:

以t4为根节点的子树上叶子节点有3个,最终:

找到α值最小的非叶子节点,令其左右孩子为NULL。当多个非叶子节点的α值同时达到最小时,取最大的进行剪枝。

 

#include
#include
#include
#include<string>
#include
#include
#include<set>
#include
#include
#include
#include
 
using namespace std;
 
//置信水平取0.95时的卡方表
const double CHI[18]={0.004,0.103,0.352,0.711,1.145,1.635,2.167,2.733,3.325,3.94,4.575,5.226,5.892,6.571,7.261,7.962};
/*根据多维数组计算卡方值*/
template
double cal_chi(Comparable **arr,int row,int col){
    vector rowsum(row);
    vector colsum(col);
    Comparable totalsum=static_cast(0);
    //cout<<"observation"<
    for(int i=0;ii){
        for(int j=0;jj){
            //cout<
            totalsum+=arr[i][j];
            rowsum[i]+=arr[i][j];
            colsum[j]+=arr[i][j];
        }
        //cout<
    }
    double rect=0.0;
    //cout<<"exception"<
    for(int i=0;ii){
        for(int j=0;jj){
            double excep=1.0*rowsum[i]*colsum[j]/totalsum;
            //cout<
            if(excep!=0)
                rect+=pow(arr[i][j]-excep,2.0)/excep;
        }
        //cout<
    }
    return rect;
}
 
class MyTriple{
public:
    double first;
    int second;
    int third;
    MyTriple(){
        first=0.0;
        second=0;
        third=0;
    }
    MyTriple(double f,int s,int t):first(f),second(s),third(t){}
    bool operator< (const MyTriple &obj) const{
        int cmp=this->first-obj.first;
        if(cmp>0)
            return false;
        else if(cmp<0)
            return true;
        else{
            cmp=obj.second-this->second;
            if(cmp<0)
                return true;
            else
                return false;
        }
    }
};
 
typedef map<string,int> MAP_REST_COUNT;
typedef map<string,MAP_REST_COUNT> MAP_ATTR_REST;
typedef vector VEC_STATI;
 
const int ATTR_NUM=8;       //自变量的维度
vector<string> X(ATTR_NUM);
int rest_number;        //因变量的种类数,即类别数
vectorstring,int> > classes;      //把类别、对应的记录数存放在一个数组中
int total_record_number;        //总的记录数
vectorstring> > inputData;      //原始输入数据
 
class node{
public:
    node* parent;       //父节点
    node* leftchild;        //左孩子节点
    node* rightchild;       //右孩子节点
    string cond;        //分枝条件
    string decision;        //在该节点上作出的类别判定
    double precision;       //判定的正确率
    int record_number;      //该节点上涵盖的记录个数
    int size;       //子树包含的叶子节点的数目
    int index;      //层次遍历树,给节点标上序号
    double alpha;   //表面误差率的增加量
    node(){
        parent=NULL;
        leftchild=NULL;
        rightchild=NULL;
        precision=0.0;
        record_number=0;
        size=1;
        index=0;
        alpha=1.0;
    }
    node(node* p){
        parent=p;
        leftchild=NULL;
        rightchild=NULL;
        precision=0.0;
        record_number=0;
        size=1;
        index=0;
        alpha=1.0;
    }
    node(node* p,string c,string d):cond(c),decision(d){
        parent=p;
        leftchild=NULL;
        rightchild=NULL;
        precision=0.0;
        record_number=0;
        size=1;
        index=0;
        alpha=1.0;
    }
    void printInfo(){
        cout<<"index:"<"\tdecisoin:"<"\tprecision:"<"\tcondition:"<"\tsize:"<<size;
        if(parent!=NULL)
            cout<<"\tparent index:"<index;
        if(leftchild!=NULL)
            cout<<"\tleftchild:"<index<<"\trightchild:"<index;
        cout<<endl;
    }
    void printTree(){
        printInfo();
        if(leftchild!=NULL)
            leftchild->printTree();
        if(rightchild!=NULL)
            rightchild->printTree(); 
    }
};
 
int readInput(string filename){
    ifstream ifs(filename.c_str());
    if(!ifs){
        cerr<<"open inputfile failed!"<<endl;
        return -1;
    }
    map<string,int> catg;
    string line;
    getline(ifs,line);
    string item;
    istringstream strstm(line);
    strstm>>item;
    for(int i=0;ii){
        strstm>>item;
        X[i]=item;
    }
    while(getline(ifs,line)){
        vector<string> conts(ATTR_NUM+2);
        istringstream strstm(line);
        //strstm.str(line);
        for(int i=0;ii){
            strstm>>item;
            conts[i]=item;
            if(i==conts.size()-1)
                catg[item]++;
        }
        inputData.push_back(conts);
    }
    total_record_number=inputData.size();
    ifs.close();
    map<string,int>::const_iterator itr=catg.begin();
    while(itr!=catg.end()){
        classes.push_back(make_pair(itr->first,itr->second));
        itr++;
    }
    rest_number=classes.size();
    return 0;
}
 
/*根据inputData作出一个统计stati*/
void statistic(vectorstring> > &inputData,VEC_STATI &stati){
    for(int i=1;i1;++i){
        MAP_ATTR_REST attr_rest;
        for(int j=0;jj){
            string attr_value=inputData[j][i];
            string rest=inputData[j][ATTR_NUM+1];
            MAP_ATTR_REST::iterator itr=attr_rest.find(attr_value);
            if(itr==attr_rest.end()){
                MAP_REST_COUNT rest_count;
                rest_count[rest]=1;
                attr_rest[attr_value]=rest_count;
            }
            else{
                MAP_REST_COUNT::iterator iter=(itr->second).find(rest);
                if(iter==(itr->second).end()){
                    (itr->second).insert(make_pair(rest,1));
                }
                else{
                    iter->second+=1;
                }
            }
        }
        stati.push_back(attr_rest);
    }
}
 
/*依据某条件作出分枝时,inputData被分成两部分*/
void splitInput(vectorstring> > &inputData,int fitIndex,string cond,vectorstring> > &LinputData,vectorstring> > &RinputData){
    for(int i=0;ii){
        if(inputData[i][fitIndex+1]==cond)
            LinputData.push_back(inputData[i]);
        else
            RinputData.push_back(inputData[i]);
    }
}
 
void printStati(VEC_STATI &stati){
    for(int i=0;i){
        MAP_ATTR_REST::const_iterator itr=stati[i].begin();
        while(itr!=stati[i].end()){
            cout<first;
            MAP_REST_COUNT::const_iterator iter=(itr->second).begin();
            while(iter!=(itr->second).end()){
                cout<<"\t"<first<<"\t"<second;
                iter++;
            }
            itr++;
            cout<<endl;
        }
        cout<<endl;
    }
}
 
void split(node *root,vectorstring> > &inputData,vectorstring,int> > classes){
    //root->printInfo();
    root->record_number=inputData.size();
    VEC_STATI stati;
    statistic(inputData,stati);
    //printStati(stati);
    //for(int i=0;i//  cout<//cout<
    /*找到最大化GINI指标的划分*/
    double minGain=1.0;     //最小的GINI增益
    int fitIndex=-1;
    string fitCond;
    vectorstring,int> > fitleftclasses;
    vectorstring,int> > fitrightclasses;
    int fitleftnumber;
    int fitrightnumber;
    for(int i=0;i//扫描每一个自变量
        MAP_ATTR_REST::const_iterator itr=stati[i].begin();
        while(itr!=stati[i].end()){         //扫描自变量上的每一个取值
            string condition=itr->first;     //判定的条件,即到达左孩子的条件
            //cout<<"cond 为"<
            vectorstring,int> > leftclasses(classes);     //左孩子节点上类别、及对应的数目
            vectorstring,int> > rightclasses(classes);    //右孩子节点上类别、及对应的数目
            int leftnumber=0;       //左孩子节点上包含的类别数目
            int rightnumber=0;      //右孩子节点上包含的类别数目
            for(int j=0;j//更新类别对应的数目
                string rest=leftclasses[j].first;
                MAP_REST_COUNT::const_iterator iter2;
                iter2=(itr->second).find(rest);
                if(iter2==(itr->second).end()){      //没找到
                    leftclasses[j].second=0;
                    rightnumber+=rightclasses[j].second;
                }
                else{       //找到
                    leftclasses[j].second=iter2->second;
                    leftnumber+=leftclasses[j].second;
                    rightclasses[j].second-=(iter2->second);
                    rightnumber+=rightclasses[j].second;
                }
            }
            /**if(leftnumber==0 || rightnumber==0){
                cout<<"左右有一边为空"<*/
            double gain1=1.0;           //计算GINI增益
            double gain2=1.0;
            if(leftnumber==0)
                gain1=0.0;
            else
                for(int j=0;jj)        
                    gain1-=pow(1.0*leftclasses[j].second/leftnumber,2.0);
            if(rightnumber==0)
                gain2=0.0;
            else
                for(int j=0;jj)
                    gain2-=pow(1.0*rightclasses[j].second/rightnumber,2.0);
            double gain=1.0*leftnumber/(leftnumber+rightnumber)*gain1+1.0*rightnumber/(leftnumber+rightnumber)*gain2;
            //cout<<"GINI增益:"<
            if(gain<minGain){
                //cout<<"GINI增益:"<
                fitIndex=i;
                fitCond=condition;
                fitleftclasses=leftclasses;
                fitrightclasses=rightclasses;
                fitleftnumber=leftnumber;
                fitrightnumber=rightnumber;
                minGain=gain;
            }
            itr++;
        }
    }
 
    /*计算卡方值,看有没有必要进行分裂*/
    //cout<<"按"<
    int **arr=new int*[2];
    for(int i=0;i<2;i++)
        arr[i]=new int[rest_number];
    for(int i=0;i){
        arr[0][i]=fitleftclasses[i].second;
        arr[1][i]=fitrightclasses[i].second;
    }
    double chi=cal_chi(arr,2,rest_number);
    //cout<<"chi="<if(chi2]){      //独立,没必要再分裂了
        delete []arr[0];    delete []arr[1];    delete []arr;
        return;             //不需要分裂函数就返回
    }
    delete []arr[0];    delete []arr[1];    delete []arr;
     
    /*分裂*/
    root->cond=X[fitIndex]+"="+fitCond;      //root的分枝条件
    //cout<<"分类条件:"<cond<
    node *travel=root;      //root及其祖先节点的size都要加1
    while(travel!=NULL){
        (travel->size)++;
        travel=travel->parent;
    }
     
    node *LChild=new node(root);        //创建左右孩子
    node *RChild=new node(root);
    root->leftchild=LChild;
    root->rightchild=RChild;
    int maxLcount=0;
    int maxRcount=0;
    string Ldicision,Rdicision;
    for(int i=0;i//统计哪种类别出现的最多,从而作出类别判定
        if(fitleftclasses[i].second>maxLcount){
            maxLcount=fitleftclasses[i].second;
            Ldicision=fitleftclasses[i].first;
        }
        if(fitrightclasses[i].second>maxRcount){
            maxRcount=fitrightclasses[i].second;
            Rdicision=fitrightclasses[i].first;
        }
    }
    LChild->decision=Ldicision;
    RChild->decision=Rdicision;
    LChild->precision=1.0*maxLcount/fitleftnumber;
    RChild->precision=1.0*maxRcount/fitrightnumber;
     
    /*递归对左右孩子进行分裂*/
    vectorstring> > LinputData,RinputData;
    splitInput(inputData,fitIndex,fitCond,LinputData,RinputData);
    //cout<<"左边inputData行数:"<//cout<<"右边inputData行数:"<
    split(LChild,LinputData,fitleftclasses);
    split(RChild,RinputData,fitrightclasses);
}
 
/*计算子树的误差代价*/
double calR2(node *root){
    if(root->leftchild==NULL)
        return (1-root->precision)*root->record_number/total_record_number;
    else
        return calR2(root->leftchild)+calR2(root->rightchild);
}
 
/*层次遍历树,给节点标上序号。同时计算alpha*/
void index(node *root,priority_queue &pq){
    int i=1;
    queue que;
    que.push(root);
    while(!que.empty()){
        node* n=que.front();
        que.pop();
        n->index=i++;
        if(n->leftchild!=NULL){
            que.push(n->leftchild);
            que.push(n->rightchild);
            //计算表面误差率的增量
            double r1=(1-n->precision)*n->record_number/total_record_number;      //节点的误差代价
            double r2=calR2(n);
            n->alpha=(r1-r2)/(n->size-1);
            pq.push(MyTriple(n->alpha,n->size,n->index));
        }
    }
}
 
/*剪枝*/
void prune(node *root,priority_queue &pq){
    MyTriple triple=pq.top();
    int i=triple.third;
    queue que;
    que.push(root);
    while(!que.empty()){
        node* n=que.front();
        que.pop();
        if(n->index==i){
            cout<<"将要剪掉"<"的左右子树"<<endl;
            n->leftchild=NULL;
            n->rightchild=NULL;
            int s=n->size-1;
            node *trav=n;
            while(trav!=NULL){
                trav->size-=s;
                trav=trav->parent;
            }
            break;
        }
        else if(n->leftchild!=NULL){
            que.push(n->leftchild);
            que.push(n->rightchild);
        }
    }
}
 
void test(string filename,node *root){
    ifstream ifs(filename.c_str());
    if(!ifs){
        cerr<<"open inputfile failed!"<<endl;
        return;
    }
    string line;
    getline(ifs,line);
    string item;
    istringstream strstm(line);     //跳过第一行
    map<string,string> independent;       //自变量,即分类的依据
    while(getline(ifs,line)){
        istringstream strstm(line);
        //strstm.str(line);
        strstm>>item;
        cout<"\t";
        for(int i=0;ii){
            strstm>>item;
            independent[X[i]]=item;
        }
        node *trav=root;
        while(trav!=NULL){
            if(trav->leftchild==NULL){
                cout<<(trav->decision)<<"\t置信度:"<<(trav->precision)<<endl;;
                break;
            }
            string cond=trav->cond;
            string::size_type pos=cond.find("=");
            string pre=cond.substr(0,pos);
            string post=cond.substr(pos+1);
            if(independent[pre]==post)
                trav=trav->leftchild;
            else
                trav=trav->rightchild;
        }
    }
    ifs.close();
}
 
int main(){
    string inputFile="animal";
    readInput(inputFile);
    VEC_STATI stati;        //最原始的统计
    statistic(inputData,stati);
     
//  for(int i=0;i//      cout<//  cout<
    node *root=new node();
    split(root,inputData,classes);      //分裂根节点
    priority_queue pq;
    index(root,pq);
    root->printTree();
    cout<<"剪枝前使用该决策树最多进行"<size-1<<"次条件判断"<<endl;
    /**
    //检验一个是不是表面误差增量最小的被剪掉了
    while(!pq.empty()){
        MyTriple triple=pq.top();
        pq.pop();
        cout<*/
    test(inputFile,root);
     
    prune(root,pq);
    cout<<"剪枝后使用该决策树最多进行"<size-1<<"次条件判断"<<endl;
    test(inputFile,root);
    return 0;
}

 

参考文献:

http://blog.csdn.net/google19890102/article/details/32329823

转载于:https://www.cnblogs.com/Wanggcong/p/4670138.html

你可能感兴趣的:(CART)