用C++来写一棵决策树

运行环境:
Windows 10,Dev-C++ 5.11
决策树基本依照 CART 生成算法实现,只有叶结点类别的赋值做了简化(“作弊”了一下)。
阅读本文之前,最好对决策树有基本的了解;下文不会介绍算法的具体步骤,建议先查阅相关资料。

训练数据集

x1,x2,x3,x4,y
3.6216,8.6661,-2.8073,-0.44699,0
4.5459,8.1674,-2.4586,-1.4621,0
3.866,-2.6383,1.9242,0.10645,0
3.4566,9.5228,-4.0112,-3.5944,1
0.32924,-4.4552,4.5718,-0.9888,1
4.3684,9.6718,-3.9606,-3.1625,1

将以上数据保存为 CSV 文件(第一行为表头),基于这些数据构造决策树

决策树为一棵二叉树

因此,需要定义一个结构体

/* Node of the binary CART decision tree. */
struct tree{
    int index;                          /* feature column the node splits on */
    double flag;                        /* class label assigned to this node */
    float score;                        /* Gini index of the chosen split */
    double value;                       /* threshold: rows with x[index] < value go left */
    double **left_array, **right_array; /* training rows sent to each side */
    int left_size, right_size;          /* number of rows in left_array / right_array */
    struct tree *left, *right;          /* left / right subtree (NULL when not built) */
};

构造决策树时需要对数据集进行拆分,为此定义一个结构体,记录一次候选拆分的结果

/* One candidate split evaluated while searching for the best threshold. */
struct gini{
    int index;                 /* feature column that is tested */
    double value;              /* threshold, taken from data[x][index] */
    float score;               /* weighted Gini index of this split */
    double **left, **right;    /* rows with x[index] < value / all remaining rows */
    int left_size, right_size; /* row counts of left / right */
    struct gini *next;         /* next candidate in the list */
};

数据集拆分时,需要一个单链表来记录满足条件的数组下标

/* Singly linked list cell that records one row index of the data set. */
struct node{
    int data;           /* row index into the data set */
    struct node *next;  /* following cell, NULL at the end of the list */
};

基尼指数的计算问题

$$\mathrm{Gini}(D,A)=\frac{|D_1|}{|D|}\mathrm{Gini}(D_1)+\frac{|D_2|}{|D|}\mathrm{Gini}(D_2)$$
$$\mathrm{Gini}(D_1)=1-\sum_{k=1}^{K}\left(\frac{|C_k|}{|D_1|}\right)^2=\sum_{k=1}^{K}\frac{|C_k|}{|D_1|}\left(1-\frac{|C_k|}{|D_1|}\right)$$
其中 $|C_k|$ 为集合 $D_1$ 中类别为 $k$ 的样本个数,$K$ 为类别数。由于 $\sum_k p_k = 1$,两种写法等价,下面的代码使用第二种形式 $\sum_k p_k(1-p_k)$。

float Gini(double **train_data, int* target_data, int row, int group_size, int col)
{
    // Weighted Gini index of one side of a split:
    //   |D1|/|D| * sum_k p_k * (1 - p_k),   p_k = |C_k| / |D1|
    // which equals |D1|/|D| * (1 - sum_k p_k^2).
    //train_data  -> the |D1| rows of this side; the class label is in column col-1
    //target_data -> class labels to count (callers in this file pass {0, 1})
    //row         -> |D|, the size used as the weight denominator
    //group_size  -> |D1|, number of rows in train_data
    //col         -> number of columns per row (features + label)
    // target_data is a pointer here, so sizeof() cannot recover its length
    // (the old sizeof(target_data)/sizeof(target_data[0]) gave 2 or 1 depending
    // on pointer width). The number of classes is therefore stated explicitly.
    const int target_classes = 2;
    float local_gini = 0;
    if(group_size <= 0 || row <= 0)
        return 0.0f;// empty side contributes nothing
    for(int i = 0; i < target_classes; i++){
        int count = 0;
        for(int j = 0; j < group_size; j++){
            if((int)train_data[j][col-1] == target_data[i])
                count++;
        }// count = |C_k|
        float probablity = (float)count / (float)group_size;
        local_gini += probablity * (1.0 - probablity);
    }// gini index of this side
    float ratio = (float)group_size / (float)row;
    return ratio * local_gini; // gini index * D1/D
}

根据基尼指数拆分数据集

先给出一段 Python 版本的伪代码,它比文字叙述更直观地说明了如何拆分:

def test_split(index, value, dataset):
    left, right = list(), list()
    for row in dataset:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right
float split_data(double **data, int index, double value, int row, int col, gini* linker)
{
    float gini;
    int target_data[2] = {0, 1};
    node *left_node, *right_node;
    left_node = (node *)malloc(sizeof(node));// store left index 
    right_node = (node *)malloc(sizeof(node));// store right index
    left_node->next = NULL;
    right_node->next = NULL; 
    int count_left_size=0, i, count_right_size=0, j, k;
    for(i=1; iif(data[i][index] < value)
            count_left_size++;
        else
            count_right_size++;
    }// compute left array and right array size
    printf("left size = %d\n", count_left_size);
    printf("right size = %d\n", count_right_size);
    double **left_array, **right_array;// store left array and right array
    left_array = (double **)malloc(count_left_size * sizeof(double));
    right_array = (double ** )malloc(count_right_size * sizeof(double));
    for(i=0; idouble *)malloc(col * sizeof(double));
    }
    for(i=0; idouble *)malloc(col * sizeof(double));
    }
    for(i=1; iif(data[i][index] < value){
            node *p;
            p = (node *)malloc(sizeof(node));
            p->data = i;
            p->next = left_node->next;
            left_node->next = p;
        }// record left index
        else{
            node *q;
            q = (node *)malloc(sizeof(node));
            q->data = i;
            q->next = right_node->next;
            right_node->next = q;
        }// record right index
    }
    left_node = left_node->next;
    right_node = right_node->next;
    i = 0;
    while(left_node){
        for(j=0; jdata][j];
        } 
        left_node = left_node->next;
        i++;
    }// evaluate left array depend on left_node
    i = 0;
    while(right_node){
        for(j=0; jdata][j];
        }
        right_node = right_node->next;
        i++;
    }// evaluate right array depend on right_node
    gini = Gini(left_array, target_data, row, count_left_size, col);
    gini += Gini(right_array, target_data, row, count_right_size, col);// compute complete gini
    printf("gini = %f\n", gini);
    linker->left = left_array;
    linker->right = right_array;
    linker->left_size = count_left_size;
    linker->right_size = count_right_size;
    return gini;
}

对集合中每个特征的每一个取值 $x$ 计算基尼指数 $\mathrm{Gini}(D,x)$,找出其中最小的基尼指数

需要申请一个链表 head,存储不同取值所对应的基尼指数及拆分后的数据集等信息;遍历这个链表,得到最小的基尼指数及其对应的数据集,并存储到结点 child 中

tree* get_least_gini(double** data, int row, int col)
{
    // Try every feature value data[i][j] (j < col-1; the last column is the
    // label) as a split threshold and return a new tree node describing the
    // split with the smallest Gini index.
    // The returned node owns that split's left/right row arrays; its flag,
    // left and right fields are zeroed for the caller (BuildTree) to fill.
    // Returns NULL only if the node itself cannot be allocated.
    float least_gini = 100.0;
    gini *candidates = NULL;// newest candidate first, NULL-terminated
    gini *best = NULL;
    tree *child = (tree *)calloc(1, sizeof(tree));
    if(child == NULL)
        return NULL;
    for(int i = 0; i < row; i++){
        for(int j = 0; j < col - 1; j++){
            gini *linker = (gini *)malloc(sizeof(gini));
            if(linker == NULL)
                continue;
            linker->index = j;
            linker->value = data[i][j];
            // split_data writes the partitioned arrays and their sizes into
            // linker; they are heap-allocated, so they outlive this call.
            linker->score = split_data(data, j, data[i][j], row, col, linker);
            linker->next = candidates;
            candidates = linker;
        }
    }
    // The original never initialised the list head's next pointer and
    // compensated by scanning only row*(col-1)-1 entries, which silently
    // dropped one candidate. The list is now NULL-terminated, so all count.
    for(gini *p = candidates; p != NULL; p = p->next){
        if(p->score < least_gini){
            least_gini = p->score;
            best = p;
        }
    }// find min gini
    if(best != NULL){
        child->left_array = best->left;
        child->right_array = best->right;
        child->index = best->index;
        child->score = best->score;
        child->value = best->value;
        child->left_size = best->left_size;
        child->right_size = best->right_size;
    }
    // Free every candidate; only the winner's row arrays are kept, because
    // the returned node now owns them.
    while(candidates != NULL){
        gini *next = candidates->next;
        if(candidates != best){
            for(int k = 0; k < candidates->left_size; k++)
                free(candidates->left[k]);
            free(candidates->left);
            for(int k = 0; k < candidates->right_size; k++)
                free(candidates->right[k]);
            free(candidates->right);
        }
        free(candidates);
        candidates = next;
    }
    return child;
}

递归地构造决策树

构造决策树时,在处理 root->flag 时做了一点简化(“作弊”):为了简单起见,直接取右子集第一行的类别作为结点的类别,这里还可以改进

tree* BuildTree(double**data, int row, int col, int depth, int width)
{
    // Recursively grow the tree on data (row rows, col columns, label in
    // column col-1). depth is the number of levels still allowed; a node is
    // only built when it has at least width rows. Returns the subtree root,
    // or NULL when a stopping condition is met.
    if(depth < 1)
        return NULL;
    if(row < width)
        return NULL;
    // get_least_gini allocates the node itself; the original malloc'd a node
    // here first and immediately leaked it by overwriting the pointer.
    tree *root = get_least_gini(data, row, col);
    if(root == NULL)
        return NULL;
    // Simplification by the author: the node's class is the label of the
    // first row of the right subset (the author observed the left subset
    // often becomes empty at depth > 2). Fall back to the left subset so an
    // empty right side is never dereferenced.
    if(root->right_size > 0)
        root->flag = root->right_array[0][col-1];
    else if(root->left_size > 0)
        root->flag = root->left_array[0][col-1];
    else
        root->flag = 0.0;// NOTE(review): no rows on either side; arbitrary default
    root->left = BuildTree(root->left_array, root->left_size, col, depth-1, width);
    root->right = BuildTree(root->right_array, root->right_size, col, depth-1, width);
    // The original had no return here: flowing off the end of a non-void
    // function is undefined behaviour.
    return root;
}

递归地进行预测

double predict(tree* root, double* test, int col)
{
    // Classify one sample by walking down from root. At each node, go to the
    // side selected by test[index] < value. Descend only if that side held
    // more than one training row AND its subtree was actually built;
    // otherwise return this node's class (flag).
    // The original recursed on size alone and dereferenced NULL whenever
    // BuildTree stopped at the depth/width limit.
    const int goes_left = test[root->index] < root->value;
    tree *next = goes_left ? root->left : root->right;
    const int side_size = goes_left ? root->left_size : root->right_size;
    if(side_size > 1 && next != NULL)
        return predict(next, test, col);
    return root->flag;
}

完整代码

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Singly linked list cell that records one row index of the data set. */
struct node{
    int data;           /* row index into the data set */
    struct node *next;  /* following cell, NULL at the end of the list */
};

/* One candidate split evaluated while searching for the best threshold. */
struct gini{
    int index;                 /* feature column that is tested */
    double value;              /* threshold, taken from data[x][index] */
    float score;               /* weighted Gini index of this split */
    double **left, **right;    /* rows with x[index] < value / all remaining rows */
    int left_size, right_size; /* row counts of left / right */
    struct gini *next;         /* next candidate in the list */
};

/* Node of the binary CART decision tree. */
struct tree{
    int index;                          /* feature column the node splits on */
    double flag;                        /* class label assigned to this node */
    float score;                        /* Gini index of the chosen split */
    double value;                       /* threshold: rows with x[index] < value go left */
    double **left_array, **right_array; /* training rows sent to each side */
    int left_size, right_size;          /* number of rows in left_array / right_array */
    struct tree *left, *right;          /* left / right subtree (NULL when not built) */
};

/* CSV input helpers */
int get_row(char *filename);
int get_col(char *filename);
void get_two_dimension(char* line, double** data, char *filename);
void print_two_dimension(double** data, int row, int col);

/* CART training */
float Gini(double** train_data, int* target_data, int total_size, int group_size, int col);
float split_data(double** data, int index, double value, int row, int col, gini* linker);
tree* get_least_gini(double** data, int row, int col);
tree* BuildTree(double**data, int row, int col, int depth, int width);

/* Inspection and evaluation */
void print_tree(tree* root, int depth);
double predict(tree* root, double* test, int col);
double test(tree* root, double** data, int row, int col);

int main(int argc, char **argv)
{
    // Train a decision tree on a CSV file whose first line is a header, print
    // the tree and report the accuracy on the training rows.
    // An optional first command-line argument overrides the default path.
    char default_filename[] = "C:\\Users\\q\\titanic\\csvtest.csv";
    char *filename = (argc > 1) ? argv[1] : default_filename;
    char line[1024];
    int row = get_row(filename);
    int col = get_col(filename);
    if(row <= 0 || col <= 0){
        fprintf(stderr, "cannot read %s\n", filename);
        return 1;
    }
    // data is an array of row pointers, so each element is sizeof(double *);
    // rows are zero-filled in case a line has fewer fields than the first.
    double **data = (double **)malloc(row * sizeof(double *));
    if(data == NULL)
        return 1;
    for(int i = 0; i < row; ++i){
        data[i] = (double *)calloc(col, sizeof(double));
        if(data[i] == NULL)
            return 1;
    }
    get_two_dimension(line, data, filename);
    printf("row = %d\n", row);
    printf("col = %d\n", col);
    int depth = 5, width = 2;
    tree *root = BuildTree(data, row, col, depth, width);
    print_tree(root, 0);
    if(root != NULL)
        test(root, data, row, col);
    for(int i = 0; i < row; ++i)
        free(data[i]);
    free(data);
    return 0;
}

void get_two_dimension(char* line, double** data, char *filename)
{
    // Parse the CSV file into data: one row per line read by fgets, fields
    // split on ',' and converted with atof (so a text header becomes zeros).
    // line is a caller-supplied scratch buffer of at least 1024 chars.
    // data must already have at least as many rows as get_row(filename) and
    // enough columns for every line; nothing is bounds-checked here.
    // If the file cannot be opened, an error is reported and data is untouched.
    FILE* stream = fopen(filename, "r");
    if(stream == NULL){
        perror(filename);
        return;
    }
    int i = 0;
    while(fgets(line, 1024, stream)){
        int j = 0;
        // The original also strdup'ed each line and freed it unused; that
        // dead (and non-ISO) allocation is dropped.
        for(char *tok = strtok(line, ","); tok && *tok; j++, tok = strtok(NULL, ",\n")){
            data[i][j] = atof(tok);
        }
        i++;
    }
    fclose(stream);
}

void print_two_dimension(double** data, int row, int col)
{
    // Print rows 1..row-1 (row 0 is the CSV header), one line per row,
    // each value formatted with %f and followed by a tab.
    for(int i = 1; i < row; i++){
        for(int j = 0; j < col; j++){
            printf("%f\t", data[i][j]);
        }
        printf("\n");
    }
}

int get_row(char *filename)
{
    // Number of fgets reads (i.e. lines, header included) in the file, or 0
    // if it cannot be opened. The 1024-byte buffer matches get_two_dimension,
    // so both functions agree on the row count even for very long lines.
    char line[1024];
    int count = 0;// was left uninitialised: undefined behaviour
    FILE* stream = fopen(filename, "r");
    if(stream == NULL)
        return 0;
    while(fgets(line, sizeof line, stream)){
        count++;
    }
    fclose(stream);
    return count;
}

int get_col(char *filename)
{
    // Number of comma-separated fields on the first line of the file.
    // Returns 0 if the file cannot be opened or is empty (the original ran
    // strtok on an uninitialised buffer in that case).
    char line[1024];
    int count = 0;
    FILE* stream = fopen(filename, "r");
    if(stream == NULL)
        return 0;
    if(fgets(line, sizeof line, stream) != NULL){
        for(char *token = strtok(line, ","); token != NULL; token = strtok(NULL, ",")){
            count++;
        }
    }
    fclose(stream);
    return count;
}

tree* get_least_gini(double** data, int row, int col)
{
    // Try every feature value data[i][j] (j < col-1; the last column is the
    // label) as a split threshold and return a new tree node describing the
    // split with the smallest Gini index.
    // The returned node owns that split's left/right row arrays; its flag,
    // left and right fields are zeroed for the caller (BuildTree) to fill.
    // Returns NULL only if the node itself cannot be allocated.
    float least_gini = 100.0;
    gini *candidates = NULL;// newest candidate first, NULL-terminated
    gini *best = NULL;
    tree *child = (tree *)calloc(1, sizeof(tree));
    if(child == NULL)
        return NULL;
    for(int i = 0; i < row; i++){
        for(int j = 0; j < col - 1; j++){
            gini *linker = (gini *)malloc(sizeof(gini));
            if(linker == NULL)
                continue;
            linker->index = j;
            linker->value = data[i][j];
            // split_data writes the partitioned arrays and their sizes into linker.
            linker->score = split_data(data, j, data[i][j], row, col, linker);
            linker->next = candidates;
            candidates = linker;
        }
    }
    // The original never initialised the list head's next pointer and
    // compensated by scanning only row*(col-1)-1 entries, which silently
    // dropped one candidate. The list is now NULL-terminated, so all count.
    for(gini *p = candidates; p != NULL; p = p->next){
        if(p->score < least_gini){
            least_gini = p->score;
            best = p;
        }
    }
    if(best != NULL){
        child->left_array = best->left;
        child->right_array = best->right;
        child->index = best->index;
        child->score = best->score;
        child->value = best->value;
        child->left_size = best->left_size;
        child->right_size = best->right_size;
    }
    // Free every candidate; only the winner's row arrays are kept, because
    // the returned node now owns them.
    while(candidates != NULL){
        gini *next = candidates->next;
        if(candidates != best){
            for(int k = 0; k < candidates->left_size; k++)
                free(candidates->left[k]);
            free(candidates->left);
            for(int k = 0; k < candidates->right_size; k++)
                free(candidates->right[k]);
            free(candidates->right);
        }
        free(candidates);
        candidates = next;
    }
    return child;
}

float split_data(double **data, int index, double value, int row, int col, gini* linker)
{
    float gini;
    int target_data[2] = {0, 1};
    node *left_node, *right_node;
    left_node = (node *)malloc(sizeof(node));
    right_node = (node *)malloc(sizeof(node));
    left_node->next = NULL;
    right_node->next = NULL; 
    int count_left_size=0, i, count_right_size=0, j, k;
    for(i=1; iif(data[i][index] < value)
            count_left_size++;
        else
            count_right_size++;
    }
    double **left_array, **right_array;
    left_array = (double **)malloc(count_left_size * sizeof(double));
    right_array = (double ** )malloc(count_right_size * sizeof(double));
    for(i=0; idouble *)malloc(col * sizeof(double));
    }
    for(i=0; idouble *)malloc(col * sizeof(double));
    }
    for(i=1; iif(data[i][index] < value){
            node *p;
            p = (node *)malloc(sizeof(node));
            p->data = i;
            p->next = left_node->next;
            left_node->next = p;
        }
        else{
            node *q;
            q = (node *)malloc(sizeof(node));
            q->data = i;
            q->next = right_node->next;
            right_node->next = q;
        }
    }
    left_node = left_node->next;
    right_node = right_node->next;
    i = 0;
    while(left_node){
        for(j=0; jdata][j];
        } 
        left_node = left_node->next;
        i++;
    }
    i = 0;
    while(right_node){
        for(j=0; jdata][j];
        }
        right_node = right_node->next;
        i++;
    }
    gini = Gini(left_array, target_data, row, count_left_size, col);
    gini += Gini(right_array, target_data, row, count_right_size, col);
    linker->left = left_array;
    linker->right = right_array;
    linker->left_size = count_left_size;
    linker->right_size = count_right_size;
    return gini;
}

float Gini(double **train_data, int* target_data, int row, int group_size, int col)
{
    // Weighted Gini index of one side of a split:
    //   |D1|/|D| * sum_k p_k * (1 - p_k),   p_k = |C_k| / |D1|
    // which equals |D1|/|D| * (1 - sum_k p_k^2).
    //train_data  -> the |D1| rows of this side; the class label is in column col-1
    //target_data -> class labels to count (callers in this file pass {0, 1})
    //row         -> |D|, the size used as the weight denominator
    //group_size  -> |D1|, number of rows in train_data
    //col         -> number of columns per row (features + label)
    // target_data is a pointer here, so sizeof() cannot recover its length
    // (the old sizeof(target_data)/sizeof(target_data[0]) gave 2 or 1 depending
    // on pointer width). The number of classes is therefore stated explicitly.
    const int target_classes = 2;
    float local_gini = 0;
    if(group_size <= 0 || row <= 0)
        return 0.0f;// empty side contributes nothing
    for(int i = 0; i < target_classes; i++){
        int count = 0;
        for(int j = 0; j < group_size; j++){
            if((int)train_data[j][col-1] == target_data[i])
                count++;
        }// count = |C_k|
        float probablity = (float)count / (float)group_size;
        local_gini += probablity * (1.0 - probablity);
    }// gini index of this side
    float ratio = (float)group_size / (float)row;
    return ratio * local_gini; // gini index * D1/D
}

tree* BuildTree(double**data, int row, int col, int depth, int width)
{
    // Recursively grow the tree on data (row rows, col columns, label in
    // column col-1). depth is the number of levels still allowed; a node is
    // only built when it has at least width rows. Returns the subtree root,
    // or NULL when a stopping condition is met.
    if(depth < 1)
        return NULL;
    if(row < width)
        return NULL;
    // get_least_gini allocates the node itself; the original malloc'd a node
    // here first and immediately leaked it by overwriting the pointer.
    tree *root = get_least_gini(data, row, col);
    if(root == NULL)
        return NULL;
    // Simplification by the author: the node's class is the label of the
    // first row of the right subset. Fall back to the left subset so an
    // empty right side is never dereferenced.
    if(root->right_size > 0)
        root->flag = root->right_array[0][col-1];
    else if(root->left_size > 0)
        root->flag = root->left_array[0][col-1];
    else
        root->flag = 0.0;// NOTE(review): no rows on either side; arbitrary default
    root->left = BuildTree(root->left_array, root->left_size, col, depth-1, width);
    root->right = BuildTree(root->right_array, root->right_size, col, depth-1, width);
    // The original had no return here: flowing off the end of a non-void
    // function is undefined behaviour.
    return root;
}

void print_tree(tree* root, int depth)
{
    // Pre-order dump of the tree: each node's class (flag) on its own line,
    // indented by two spaces per level of depth. A NULL subtree prints nothing.
    if(!root)
        return ;
    for(int i = 0; i < depth; i++){
        printf("  ");
    }
    printf("x = %f\n", root->flag);
    print_tree(root->left, depth+1);
    print_tree(root->right, depth+1);
}

double predict(tree* root, double* test, int col)
{
    // Classify one sample by walking down from root. At each node, go to the
    // side selected by test[index] < value. Descend only if that side held
    // more than one training row AND its subtree was actually built;
    // otherwise return this node's class (flag).
    // The original recursed on size alone and dereferenced NULL whenever
    // BuildTree stopped at the depth/width limit.
    const int goes_left = test[root->index] < root->value;
    tree *next = goes_left ? root->left : root->right;
    const int side_size = goes_left ? root->left_size : root->right_size;
    if(side_size > 1 && next != NULL)
        return predict(next, test, col);
    return root->flag;
}

double test(tree* root, double** data, int row, int col)
{
    // Predict every row except row 0 (the CSV header), print each prediction
    // next to its true label (column col-1), then print and return the
    // accuracy. Returns 0 when there are no data rows instead of dividing by zero.
    int i, count = 0;
    double predict_;
    printf("count = %d\n", count);
    for(i = 1; i < row; i++){
        predict_ = predict(root, data[i], col);
        printf("predict = %f\n", predict_);
        printf("data = %f\n", data[i][col-1]);
        if(predict_ == data[i][col-1])
            count++;
    }
    printf("row = %d\n", row - 1);
    printf("count = %d\n", count);
    float accuracy = (row > 1) ? (float)count / (float)(row - 1) : 0.0f;
    printf("accuracy = %f\n", accuracy);
    // The original fell off the end of this non-void function (undefined behaviour).
    return accuracy;
}

输出:

row = 7
col = 5
x = 1.000000
  x = 0.000000
    x = 0.000000
  x = 1.000000
count = 0
predict = 0.000000
data = 0.000000
predict = 0.000000
data = 0.000000
predict = 0.000000
data = 0.000000
predict = 1.000000
data = 1.000000
predict = 0.000000
data = 1.000000
predict = 1.000000
data = 1.000000
row = 6
count = 5
accuracy = 0.833333

对于训练集来说,这个正确率有点低。不过写这棵树只是为了学习,因此用了一些作弊手段,比如叶结点所代表类别的确定方式

参考文献:
[1] 李航. 统计学习方法. 清华大学出版社, 2012
[2] How To Implement The Decision Tree Algorithm From Scratch In Python

你可能感兴趣的:(c语言,机器学习)