运行环境:
Windows 10,Dev-C++ 5.11
决策树的实现中,除了叶结点的类别赋值作弊了一下之外,其余基本依照CART生成算法实现
阅读本文之前,最好对决策树有一个认知,下文不会提到具体的步骤,建议先百度一下
x1,x2,x3,x4,y
3.6216,8.6661,-2.8073,-0.44699,0
4.5459,8.1674,-2.4586,-1.4621,0
3.866,-2.6383,1.9242,0.10645,0
3.4566,9.5228,-4.0112,-3.5944,1
0.32924,-4.4552,4.5718,-0.9888,1
4.3684,9.6718,-3.9606,-3.1625,1
将以上数据保存为CSV文件,并基于这些数据构造决策树
因此,需要定义一个结构体
/* One node of the decision tree, together with the two row subsets its split produced. */
struct tree{
int index;// column index of the feature this node splits on
double flag;// class label stored for this node
float score;// Gini score of the chosen split
double value;// split threshold: rows with data[x][index] < value go to the left subset
double **left_array;// rows placed in the left subset
double **right_array;// rows placed in the right subset
int left_size;// number of rows in left_array
int right_size;// number of rows in right_array
struct tree *left;// left subtree
struct tree *right;// right subtree
};
/* One candidate split: the feature/threshold tried, its Gini score and the two resulting subsets. */
struct gini{
int index;// column index of the feature that was tried
double value;// threshold that was tried (value = data[x][index])
float score;// Gini index of this split
double **left;// rows placed in the left subset
double **right;// rows placed in the right subset
int left_size;// number of rows in left
int right_size;// number of rows in right
struct gini *next;// next candidate in the linked list
};
/* Singly linked list node; split_data stores a row index in data. */
struct node{
int data;
struct node *next;
};
$$\mathrm{Gini}(D, A) = \frac{|D_1|}{|D|}\,\mathrm{Gini}(D_1) + \frac{|D_2|}{|D|}\,\mathrm{Gini}(D_2)$$
$$\mathrm{Gini}(D_1) = 1 - \sum_{k=1}^{K}\left(\frac{|C_k|}{|D_1|}\right)^2$$
其中 $|C_k|$ 为集合 $D_1$ 中属于第 $k$ 类的样本个数,$K$ 为类别总数。
/*
 * Weighted Gini impurity of one side of a split.
 *
 * train_data  - rows of this group; the class label is read from the last
 *               column (train_data[j][col-1]) and truncated to int
 * target_data - class labels to count; must hold exactly two entries
 *               (split_data passes {0, 1})
 * row         - number of rows in the parent set (weight denominator)
 * group_size  - number of rows in train_data (size of D1 or D2)
 * col         - number of columns per row
 *
 * Returns (group_size / row) * sum_k p_k * (1 - p_k), where p_k is the share
 * of rows in the group whose label equals target_data[k].  An empty group,
 * or row == 0, yields 0.
 */
float Gini(double **train_data, int* target_data, int row, int group_size, int col)
{
    /* target_data is a pointer here, so sizeof(target_data)/sizeof(target_data[0])
       would give sizeof(int *)/sizeof(int) (1 on 32-bit builds), not the number
       of labels.  The label count is therefore stated explicitly. */
    enum { TARGET_CLASSES = 2 };
    float local_gini = 0;
    float ratio;
    int i, j;

    for (i = 0; i < TARGET_CLASSES; i++) {
        float share;
        int count = 0;

        /* count the rows of this group that carry label target_data[i] */
        for (j = 0; j < group_size; j++) {
            if ((int)train_data[j][col - 1] == target_data[i])
                count++;
        }

        if (count == 0 || group_size == 0)
            share = 0.0f;               /* avoid 0/0 on an empty group */
        else
            share = (float)count / (float)group_size;

        local_gini += share * (1.0 - share);
    }

    if (group_size == 0 || row == 0)
        ratio = 0;
    else
        ratio = (float)group_size / (float)row;

    return ratio * local_gini;          /* impurity weighted by |D1| / |D| */
}
先上一个Python版本的伪代码,这样更容易理解如何拆分,比文字叙述更直观
def test_split(index, value, dataset):
left, right = list(), list()
for row in dataset:
if row[index] < value:
left.append(row)
else:
right.append(row)
return left, right
float split_data(double **data, int index, double value, int row, int col, gini* linker)
{
float gini;
int target_data[2] = {0, 1};
node *left_node, *right_node;
left_node = (node *)malloc(sizeof(node));// store left index
right_node = (node *)malloc(sizeof(node));// store right index
left_node->next = NULL;
right_node->next = NULL;
int count_left_size=0, i, count_right_size=0, j, k;
for(i=1; iif(data[i][index] < value)
count_left_size++;
else
count_right_size++;
}// compute left array and right array size
printf("left size = %d\n", count_left_size);
printf("right size = %d\n", count_right_size);
double **left_array, **right_array;// store left array and right array
left_array = (double **)malloc(count_left_size * sizeof(double));
right_array = (double ** )malloc(count_right_size * sizeof(double));
for(i=0; idouble *)malloc(col * sizeof(double));
}
for(i=0; idouble *)malloc(col * sizeof(double));
}
for(i=1; iif(data[i][index] < value){
node *p;
p = (node *)malloc(sizeof(node));
p->data = i;
p->next = left_node->next;
left_node->next = p;
}// record left index
else{
node *q;
q = (node *)malloc(sizeof(node));
q->data = i;
q->next = right_node->next;
right_node->next = q;
}// record right index
}
left_node = left_node->next;
right_node = right_node->next;
i = 0;
while(left_node){
for(j=0; j data][j];
}
left_node = left_node->next;
i++;
}// evaluate left array depend on left_node
i = 0;
while(right_node){
for(j=0; j data][j];
}
right_node = right_node->next;
i++;
}// evaluate right array depend on right_node
gini = Gini(left_array, target_data, row, count_left_size, col);
gini += Gini(right_array, target_data, row, count_right_size, col);// compute complete gini
printf("gini = %f\n", gini);
linker->left = left_array;
linker->right = right_array;
linker->left_size = count_left_size;
linker->right_size = count_right_size;
return gini;
}
需要申请一个链表 head,来存储不同切分值所对应的基尼指数及拆分后的数据集合等;遍历这个链表得到最小的基尼指数及其对应的数据集,并存储到结点 child 中
tree* get_least_gini(double** data, int row, int col)
{
int i, j, index, control=0;
double value;
float least_gini = 100.0, compare_gini, middle;
gini *head;
tree *child;
head = (gini *)malloc(sizeof(gini));
child = (tree *)malloc(sizeof(tree));
for(i=0; ifor(j=0; j1; j++){
gini *linker;
linker = (gini *)malloc(sizeof(gini));
index = j;
value = data[i][j];
linker->index = index;
linker->value = value;
compare_gini = split_data(data, index, value, row, col, linker);/* linker parameter store gini index , left_array, right_array etc.
Why? Because after invoke split_data function, compiler will delete data, so use linker to store addr and other*/
linker->score = compare_gini;
linker->next = head->next;
head->next = linker;
}
}
head = head->next;
i = row * (col - 1) - 1;// if don't do this, code will be cross the border.int i is after i run code, i limit it.
while(control < i){
if(head->score < least_gini){
least_gini = head->score;
child->left_array = head->left;
child->right_array = head->right;
child->index = head->index;
child->score = head->score;
child->value = head->value;
child->left_size = head->left_size;
child->right_size = head->right_size;
}
control++;
head = head->next;
}// find min gini and store in child
return child;
}
构造决策树时,在处理 root->flag 时作了一下弊……为了简单起见,我是这样做的,不过可以改进
tree* BuildTree(double**data, int row, int col, int depth, int width)
{
if(depth < 1)
return NULL;
if(row < width)
return NULL;
tree *root;
root = (tree *)malloc(sizeof(tree));
root = get_least_gini(data, row, col);
root->flag = root->right_array[0][col-1];
/*why root->flag = root->right_array[0][col-1]
Because when i run code, i found when depth is more 2, left_array will be empty. And i found when decision tree depth is more and more depth, right_array[0] is
enough represent parent node class
*/
root->left = BuildTree(root->left_array, root->left_size, col, depth-1, width);
root->right = BuildTree(root->right_array, root->right_size, col, depth-1, width);
}
double predict(tree* root, double* test, int col)
{
if(test[root->index] < root->value){
if(root->left_size > 1)
return predict(root->left, test, col);
else
return root->flag;
}
else{
if(root->right_size > 1)
return predict(root->right, test, col);
else
return root->flag;
}
}
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Singly linked list node; split_data stores a row index in data. */
struct node{
int data;
struct node *next;
};
/* One candidate split: the feature/threshold tried, its Gini score and the two resulting subsets. */
struct gini{
int index;// column index of the feature that was tried
double value;// threshold that was tried
float score;// Gini index of this split
double **left;// rows placed in the left subset
double **right;// rows placed in the right subset
int left_size;// number of rows in left
int right_size;// number of rows in right
struct gini *next;// next candidate in the linked list
};
/* One node of the decision tree, together with the two row subsets its split produced. */
struct tree{
int index;// column index of the feature this node splits on
double flag;// class label stored for this node
float score;// Gini score of the chosen split
double value;// split threshold: rows with a smaller feature value go to the left subset
double **left_array;// rows placed in the left subset
double **right_array;// rows placed in the right subset
int left_size;// number of rows in left_array
int right_size;// number of rows in right_array
struct tree *left;// left subtree
struct tree *right;// right subtree
};
/* CSV loading and debugging helpers. */
void get_two_dimension(char* line, double** data, char *filename);
void print_two_dimension(double** data, int row, int col);
int get_row(char *filename);
int get_col(char *filename);
/* Split scoring and selection. */
float Gini(double** train_data, int* target_data, int total_size, int group_size, int col);
float split_data(double** data, int index, double value, int row, int col, gini* linker);
tree* get_least_gini(double** data, int row, int col);
/* Tree construction, printing, prediction and evaluation. */
tree* BuildTree(double**data, int row, int col, int depth, int width);
void print_tree(tree* root, int depth);
double predict(tree* root, double* test, int col);
double test(tree* root, double** data, int row, int col);
/*
 * Load the CSV file, build a decision tree of depth 5 over it, print the
 * tree, and report the accuracy of the tree on the same rows.
 * Returns 0 on success and 1 when the data cannot be read or allocated.
 */
int main()
{
    char filename[] = "C:\\Users\\q\\titanic\\csvtest.csv";
    char line[1024];
    double **data;
    struct tree *root;
    int row, col, i;
    int depth = 5, width = 2;

    row = get_row(filename);
    col = get_col(filename);
    if (row <= 0 || col <= 0) {
        fprintf(stderr, "cannot read any data from %s\n", filename);
        return 1;
    }

    /* one pointer per row, so the element size is sizeof(double *) */
    data = (double **)malloc(row * sizeof(double *));
    if (data == NULL) {
        fprintf(stderr, "out of memory\n");
        return 1;
    }
    for (i = 0; i < row; ++i) {
        data[i] = (double *)malloc(col * sizeof(double));
        if (data[i] == NULL) {
            fprintf(stderr, "out of memory\n");
            return 1;
        }
    }

    get_two_dimension(line, data, filename);
    printf("row = %d\n", row);
    printf("col = %d\n", col);

    /* BuildTree allocates the nodes itself */
    root = BuildTree(data, row, col, depth, width);
    print_tree(root, 0);
    if (root != NULL)
        test(root, data, row, col);

    for (i = 0; i < row; ++i)
        free(data[i]);
    free(data);
    return 0;
}
/*
 * Parse a comma-separated file into a preallocated matrix.
 *
 * line     - scratch buffer of at least 1024 bytes used to read each line
 * data     - matrix to fill; it must already have one row per file line and
 *            one column per field, because no bounds are checked here
 * filename - path of the file to read
 *
 * Each field is converted with atof, so non-numeric fields (for example a
 * header line) are stored as 0.0.  If the file cannot be opened, data is
 * left untouched.
 */
void get_two_dimension(char* line, double** data, char *filename)
{
    FILE *stream = fopen(filename, "r");
    int i = 0;

    if (stream == NULL)
        return;

    while (fgets(line, 1024, stream))
    {
        int j = 0;
        char *tok;

        /* strtok cuts line in place, one field per call */
        for (tok = strtok(line, ","); tok && *tok; j++, tok = strtok(NULL, ",\n")) {
            data[i][j] = atof(tok);
        }
        i++;
    }
    fclose(stream);
}
/*
 * Print rows 1..row-1 of data, one row per line with tab-separated values.
 * Row 0 is skipped.
 */
void print_two_dimension(double** data, int row, int col)
{
    int i, j;

    for (i = 1; i < row; i++) {
        for (j = 0; j < col; j++) {
            printf("%f\t", data[i][j]);
        }
        printf("\n");
    }
}
/*
 * Count the lines of a text file.
 *
 * filename - path of the file to read
 *
 * Returns the number of fgets reads needed to consume the file, which is the
 * line count as long as every line is shorter than the 1024-byte buffer.
 * Returns 0 if the file cannot be opened.
 */
int get_row(char *filename)
{
    char line[1024];
    int rows = 0;   /* must start at zero: it is only ever incremented */
    FILE *stream = fopen(filename, "r");

    if (stream == NULL)
        return 0;

    while (fgets(line, sizeof line, stream)) {
        rows++;
    }
    fclose(stream);
    return rows;
}
/*
 * Count the comma-separated fields on the first line of a text file.
 *
 * filename - path of the file to read
 *
 * Returns the number of fields, or 0 if the file cannot be opened or is
 * empty.
 */
int get_col(char *filename)
{
    char line[1024];
    int columns = 0;
    char *token;
    FILE *stream = fopen(filename, "r");

    if (stream == NULL)
        return 0;

    /* only tokenize when a line was actually read into the buffer */
    if (fgets(line, sizeof line, stream) != NULL) {
        for (token = strtok(line, ","); token != NULL; token = strtok(NULL, ","))
            columns++;
    }
    fclose(stream);
    return columns;
}
tree* get_least_gini(double** data, int row, int col)
{
int i, j, index, control=0;
double value;
float least_gini = 100.0, compare_gini, middle;
gini *head;
tree *child;
head = (gini *)malloc(sizeof(gini));
child = (tree *)malloc(sizeof(tree));
for(i=0; ifor(j=0; j1; j++){
gini *linker;
linker = (gini *)malloc(sizeof(gini));
index = j;
value = data[i][j];
linker->index = index;
linker->value = value;
compare_gini = split_data(data, index, value, row, col, linker);
linker->score = compare_gini;
linker->next = head->next;
head->next = linker;
}
}
head = head->next;
i = row * (col - 1) - 1;
while(control < i){
if(head->score < least_gini){
least_gini = head->score;
child->left_array = head->left;
child->right_array = head->right;
child->index = head->index;
child->score = head->score;
child->value = head->value;
child->left_size = head->left_size;
child->right_size = head->right_size;
}
control++;
head = head->next;
}
return child;
}
float split_data(double **data, int index, double value, int row, int col, gini* linker)
{
float gini;
int target_data[2] = {0, 1};
node *left_node, *right_node;
left_node = (node *)malloc(sizeof(node));
right_node = (node *)malloc(sizeof(node));
left_node->next = NULL;
right_node->next = NULL;
int count_left_size=0, i, count_right_size=0, j, k;
for(i=1; iif(data[i][index] < value)
count_left_size++;
else
count_right_size++;
}
double **left_array, **right_array;
left_array = (double **)malloc(count_left_size * sizeof(double));
right_array = (double ** )malloc(count_right_size * sizeof(double));
for(i=0; idouble *)malloc(col * sizeof(double));
}
for(i=0; idouble *)malloc(col * sizeof(double));
}
for(i=1; iif(data[i][index] < value){
node *p;
p = (node *)malloc(sizeof(node));
p->data = i;
p->next = left_node->next;
left_node->next = p;
}
else{
node *q;
q = (node *)malloc(sizeof(node));
q->data = i;
q->next = right_node->next;
right_node->next = q;
}
}
left_node = left_node->next;
right_node = right_node->next;
i = 0;
while(left_node){
for(j=0; j data][j];
}
left_node = left_node->next;
i++;
}
i = 0;
while(right_node){
for(j=0; j data][j];
}
right_node = right_node->next;
i++;
}
gini = Gini(left_array, target_data, row, count_left_size, col);
gini += Gini(right_array, target_data, row, count_right_size, col);
linker->left = left_array;
linker->right = right_array;
linker->left_size = count_left_size;
linker->right_size = count_right_size;
return gini;
}
/*
 * Weighted Gini impurity of one side of a split.
 *
 * train_data  - rows of this group; the class label is read from the last
 *               column (train_data[j][col-1]) and truncated to int
 * target_data - class labels to count; must hold exactly two entries
 *               (split_data passes {0, 1})
 * row         - number of rows in the parent set (weight denominator)
 * group_size  - number of rows in train_data
 * col         - number of columns per row
 *
 * Returns (group_size / row) * sum_k p_k * (1 - p_k), where p_k is the share
 * of rows in the group whose label equals target_data[k].  An empty group,
 * or row == 0, yields 0.
 */
float Gini(double **train_data, int* target_data, int row, int group_size, int col)
{
    /* target_data is a pointer here, so sizeof(target_data)/sizeof(target_data[0])
       would give sizeof(int *)/sizeof(int) (1 on 32-bit builds), not the number
       of labels.  The label count is therefore stated explicitly. */
    enum { TARGET_CLASSES = 2 };
    float local_gini = 0;
    float ratio;
    int i, j;

    for (i = 0; i < TARGET_CLASSES; i++) {
        float share;
        int count = 0;

        /* count the rows of this group that carry label target_data[i] */
        for (j = 0; j < group_size; j++) {
            if ((int)train_data[j][col - 1] == target_data[i])
                count++;
        }

        if (count == 0 || group_size == 0)
            share = 0.0f;               /* avoid 0/0 on an empty group */
        else
            share = (float)count / (float)group_size;

        local_gini += share * (1.0 - share);
    }

    if (group_size == 0 || row == 0)
        ratio = 0;
    else
        ratio = (float)group_size / (float)row;

    return ratio * local_gini;          /* impurity weighted by group share */
}
tree* BuildTree(double**data, int row, int col, int depth, int width)
{
if(depth < 1)
return NULL;
if(row < width)
return NULL;
tree *root;
root = (tree *)malloc(sizeof(tree));
root = get_least_gini(data, row, col);
root->flag = root->right_array[0][col-1];
root->left = BuildTree(root->left_array, root->left_size, col, depth-1, width);
root->right = BuildTree(root->right_array, root->right_size, col, depth-1, width);
}
void print_tree(tree* root, int depth)
{
if(!root)
return ;
for(int i=0; iprintf(" ");
}
printf("x = %f\n", root->flag);
print_tree(root->left, depth+1);
print_tree(root->right, depth+1);
}
double predict(tree* root, double* test, int col)
{
if(test[root->index] < root->value){
if(root->left_size > 1)
return predict(root->left, test, col);
else
return root->flag;
}
else{
if(root->right_size > 1)
return predict(root->right, test, col);
else
return root->flag;
}
}
double test(tree* root, double** data, int row, int col)
{
int i, count=0;
double predict_;
printf("count = %d\n", count);
for(i=1; iprintf("predict = %f\n", predict_);
printf("data = %f\n", data[i][col-1]);
if(predict_ == data[i][col-1])
count++;
}
printf("row = %d\n", row - 1);
printf("count = %d\n", count);
printf("accuracy = %f\n", float(count) / float(row - 1));
}
输出:
row = 7
col = 5
x = 1.000000
x = 0.000000
x = 0.000000
x = 1.000000
count = 0
predict = 0.000000
data = 0.000000
predict = 0.000000
data = 0.000000
predict = 0.000000
data = 0.000000
predict = 1.000000
data = 1.000000
predict = 0.000000
data = 1.000000
predict = 1.000000
data = 1.000000
row = 6
count = 5
accuracy = 0.833333
对于训练集来说这个正确率有点低了,不过写这棵树只是为了学习使用,因此用了一些作弊手段,比如叶结点所代表类别的赋值方式
参考文献:
[1] 李航. 统计学习方法. 清华大学出版社, 2012
[2] How To Implement The Decision Tree Algorithm From Scratch In Python