CART分类器

cart 算法采用二分递归回归技术,将当前的样本集分为两个子样本集,使得生成得每个非叶子节点都有两个分支。所以,算法生成得决策树是简洁得二叉树。

分类树得两个基本思想:第一个是将训练样本进行递归地划分自变量空间进行建树得想法,第二个想法是用验证数据进行剪枝。

cart进行属性分类得是用gini指标

如果我们用k,k=1,2,3……C表示类,其中C是类别集Result的因变量数目,一个节点A的GINI不纯度定义为:

其中,Pk表示观测点中属于k类得概率,当Gini(A)=0时所有样本属于同一类,当所有类在节点中以相同的概率出现时,Gini(A)最大化,此时值为(C-1)C/2。

对于分类回归树,A如果它不满足“T都属于同一类别or T中只剩下一个样本”,则此节点为非叶节点,所以尝试根据样本的每一个属性及可能的属性值,对样本的进行二元划分,假设分类后A分为B和C,其中B占A中样本的比例为p,C为q(显然p+q=1)。则杂质改变量:Gini(A) -p*Gini(B)-q*Gini(C),每次划分该值应为非负,只有这样划分才有意义,对每个属性值尝试划分的目的就是找到杂质gai变量最大的一个划分,该属性值划分子树即为最优分支。

作业源码:

#include <stdio.h>

#include <stdlib.h>

#include <string.h>

#include <math.h>

#define EPS 0.000001

/*

 1.为了节约时间,直接在c4.5的基础上将其转化为cart算法。

 2.使用gini指标评判杂质量。

 3.非连续变量选择分割使用异化程度最小的对应属性值划分,即x和非x

 4.连续变量使用和c4.5划分方法一样得划分,但是是使用gini指标进行划分

 */

typedef struct Tuple

{

	int i;

    int g;

    double h;

    int c;

}tuple;

typedef struct TNode{

    double  gap;

    int attri;

    int reachValue;

    struct TNode *child[50];

    int kind;

}node;

tuple trainData[100];

double cal_entropy(tuple *data,int len);

double choose_best_gap(tuple *data,int len);

double cal_grainRatio(tuple *data,int len);

double cal_grainRatio2(tuple *data,int len,double gap);

double cal_splitInfo(tuple *data,int len);

int check_attribute(tuple *data,int len);

int choose_attribute(tuple *data,int len);

node *build_tree(tuple *data,int len,double reachValue,double gap);

void print_blank(int depth);

void traverse(node *no,int depth);

void test_data(node *root,tuple *data);

int cmp(const void *a, const void *b)

{

    tuple *a1=(tuple *)a;

    tuple *b1=(tuple *)b;

    return a1->h-b1->h>0?1:-1;

}

void copy_tuple(tuple *source,tuple *destination)

{

    destination->c=source->c;

    destination->g=source->g;

    destination->h=source->h;

	destination->i=source->i;

}

double cal_gini(tuple *data,int len)

{

    int i,j;

    double result=0.0;

    int cnt;

    for(i=0;i<3;i++)//有三类

    {

        cnt=0;

        for(j=0;j<len;j++)

        {

            if(data[j].c==i)

            {

                cnt++;

            }

        }

        result+=(cnt*1.0/len)*(cnt*1.0/len);

    }

	//printf("in cal_gini: %lf\n",result);

    return 1-result;

}

double cal_gender(tuple *data,int len)//计算性别分类的gini差值

{

    int i,j;

    double preGini=cal_gini(data,len);//计算分类前得gini数

    tuple subData[100];

    int subLen;

    double result=0.0;

    for(i=0;i<2;i++)

    {

        subLen=0;//统计某个性别得个数

        for(j=0;j<len;j++)

        {

            if(data[j].g==i)//属于某个性别

            {

                copy_tuple(&data[j],&subData[subLen++]);//存入数组当中

            }

        }

        result=result+subLen*1.0/len*cal_gini(subData,subLen);

    }

    return preGini-result;

}

double cal_height(tuple *data, int len,int *at)//计算性别分类的gini差值

{

    int i,j;

    double preGini=cal_gini(data,len);

	//printf("preGini: %lf\n",preGini);

	//getchar();

    tuple small[100],big[100];

    int smallLen,bigLen;

    double maxv=-1;

    for(i=0;i<len;i++)//寻找最大得gini差值得各个测试单元

    {

		smallLen=0;bigLen=0;

        for(j=0;j<len;j++)

        {

            if(data[j].h<=data[i].h)

            {

                copy_tuple(&data[j],&small[smallLen++]);

            }

            else

            {

                copy_tuple(&data[j],&big[bigLen++]);

            }

        }

		//printf("i: %d\n",i);

		//printf("smallLen: %d\n",smallLen);

		//printf("bigLen: %d\n",bigLen);

        double smallGini=cal_gini(small,smallLen);

		//printf("smallGini: %lf\n",smallGini);

        double bigGini=cal_gini(big,bigLen);

		//printf("bigGini: %lf\n",bigGini);

        double temp=preGini-(smallLen*1.0/len*smallGini+bigLen*1.0/len*bigGini);

		//printf("temp: %lf\n",temp);

        if(temp>maxv)

        {

            maxv=temp;

            *at=i;

			//printf("at: %d data[at]: %lf\n",*at,data[*at].h);

        }

    }

	//printf("maxv: %lf\n",maxv);

    return maxv;

}

int main()

{

    FILE *fp;

    fp=fopen("./data.txt", "r");

    if(fp==NULL)

    {

        printf("can not open the file: data.txt\n");

        return 0;

    }

    char name[50];

    double height;

    char gender[10];

    char kind[10];

    int i=0;

    while(fscanf(fp, "%s",name)!=EOF)

    {

		trainData[i].i=i;

        fscanf(fp,"%s",gender);

        if(!strcmp(gender, "M"))

        {

            trainData[i].g=0;

        }

        else trainData[i].g=1;

        fscanf(fp,"%lf",&height);

        trainData[i].h=height;

        fscanf(fp,"%s",kind);

        if(!strcmp(kind, "Short"))

        {

            trainData[i].c=0;

        }

        else if(!strcmp(kind,"Medium"))

        {

            trainData[i].c=1;

        }

        else{

            trainData[i].c=2;

        }

        i++;

    }

    int rows=i;

	node *root=build_tree(trainData,rows,-1,-1);

	 traverse(root,0);printf("\n");

	 fp=fopen("./testData.txt", "r");

	     if(fp==NULL)

	     {

	         printf("can not open the file!\n");

	         return 0;

	     }

	     tuple testData;

	     fscanf(fp, "%s",name);

	     fscanf(fp,"%s",gender);

	     if(!strcmp(gender, "M"))

	     {

	         testData.g=0;

	     }

	     else  testData.g=1;

	     fscanf(fp,"%lf",&height);

	      testData.h=height;

	   //  printf("testData: gender: %d\theight: %lf\n",testData.g,testData.h);

		 fclose(fp);

		 fp=NULL;

		 test_data(root,&testData);

}

void test_data(node *root,tuple *data)

{

	/*

     1.检查节点得属性值

     2.如果是身高则检查gap得值如果<=就往左,否则就往右

     3.如果是性别就判断reachValue的值

     */

    if(root->attri==-1)

    {

        printf("the test data belongs to:");

        switch (root->kind) {

            case 0: printf("Short\n");break;

            case 1: printf("Medium\n");break;

            case 2: printf("Tall\n");break;

            default:break;

        }

		return;

    }

	if(root->attri==0)

    {

        if(data->g==0)

        {

            test_data(root->child[0],data);

        }

        else

        {

            test_data(root->child[1], data);

        }

    }

    else

    {

		//printf("gap: %lf\n",root->gap);

        if(data->h<=root->gap)

        {

            test_data(root->child[0], data);

        }

        else{

            test_data(root->child[1], data);

        }

    }

}



void print_blank(int depth)

{

    int i;

    for(i=0;i<depth;i++)

    {

        printf("\t");

    }

}

void traverse(node *no,int depth)

{

    if(no==NULL)return;

    int i;

	printf("-------------------\n");

	print_blank(depth);

    printf("attri: %d\n",no->attri);print_blank(depth);

    printf("gap: %lf\n",no->gap);print_blank(depth);

    printf("kind: %d\n",no->kind);print_blank(depth);

    printf("reachValue: %d\n",no->reachValue);print_blank(depth);

	printf("-------------------\n");print_blank(depth);

    for(i=0;no->child[i]!=NULL;i++)

    {

        traverse(no->child[i], depth+1);

    }

}

int choose_attribute(tuple *data,int len)//选择属性函数,返回代表属性的代号

{

    int i;

    /*

     1.如果是性别,就直接计算增益

     2.如果是身高就计算最高得增益值得gap

     3.性别和身高得增益进行比较的到最佳得分类属性

     */

    double genderGini=cal_gender(data, len);

    int heightChoice;

    double heightGini=cal_height(data,len,&heightChoice);

    if(genderGini<heightGini)

    {

        return 1;

    }

    else

    {

        return 0;

    }

	//printf("gGrainRatio: %lf\n",gGrainRatio);

    /*计算连续属性值的增益

     1.排序确定gap

     2.计算各个gap的信息增益率

     3.选定最大得信息增益率确定该属性的最大信息增益率

     */

}

node *build_tree(tuple *data,int len,double reachValue,double gap)

{

	//getchar();getchar();

    int i,j;

	/*for(i=0;i<len;i++)

	{

		printf("data i: %d g:%d h:%lf c:%d\n",data[i].i,data[i].g,data[i].h,data[i].c);

	}*/

    int kind=check_attribute(data, len);//检查所有得元组是否属于同一个类

	//printf("kind: %d\n",kind);

    if(kind!=0)//如果所有得元组都属于同一类则作为叶子节点返回

    {

	//	printf("leaves constructed completed!\n");

        node *newNode=(node *)malloc(sizeof(node));

        newNode->gap=-1;//如果是按照身高分类就用得到gap;

        newNode->attri=-1;

        newNode->reachValue=reachValue;

        newNode->kind=kind-1;

        for(i=0;i<50;i++)newNode->child[i]=NULL;//初始化所有的孩子节点

        return newNode;

    }

    //从元组中选择最优属性值进行分类

    int attribute=choose_attribute(data, len);

	//printf("choose: %d\n",attribute);

    //执行分类 深度优先构建树结构

    node *newNode=(node *)malloc(sizeof(node));

    newNode->reachValue=reachValue;

    newNode->attri=attribute;

    newNode->kind=-1;

	newNode->gap=gap;

    for(i=0;i<50;i++)newNode->child[i]=NULL;

    if(attribute==0)//选择性别进行构建

    {

        for(i=0;i<2;i++)

        {

            tuple subData[100];

            int sublen=0;

            for(j=0;j<len;j++)

            {

                if(data[j].g==i/*是男的或者女的*/)

                {

                    copy_tuple(&data[j],&subData[sublen++]);

                }

            }

			if(sublen==0)continue;

            newNode->child[i]=build_tree(subData,sublen,i,-1);//因为是用性别构建得,所以不用gap分区间取值

        }

    }

    else

    {

        //选择高度构建

        /*

         1.选择最优得分割值

         2.将元组分割成left和right两个部分

         */

        int index=0;

        double heightGini=cal_height(data,len,&index);

        double gap=data[index].h;//选择分割连续变量得值

		newNode->gap=gap;

		//printf("best gap: %lf\n",gap);

        tuple leftData[100],rightData[100];//分割完成后,放入左右两个数组里面

        int leftlen=0;//左右数组的长度

        int rightlen=0;

        for(i=0;i<len;i++)

        {

            if(data[i].h<=gap)

            {

                copy_tuple(&data[i],&leftData[leftlen++]);

            }

            else{

                copy_tuple(&data[i],&rightData[rightlen++]);

            }

        }

		if(leftlen!=0)

        newNode->child[0]=build_tree(leftData,leftlen,-1,gap);//使用身高构建子树,因此必须分区间进行

		if(rightlen!=0)

        newNode->child[1]=build_tree(rightData,rightlen,-1,gap);

    }

    return newNode;

}

int check_attribute(tuple *data,int len)//检查所有得元组是否都是一类

{

    /*

     1.扫描所有得元组,如果出现不适同一类得元组,则返回

     */

    int i;

    for(i=1;i<len;i++)

    {

        if(data[i].c!=data[i-1].c)return 0;

    }

    return data[0].c+1;

}

 

你可能感兴趣的:(ca)