决策树刚好大一期末课设就是决策树,今天刚好搬过来,顺便复习一下。
纯纯C写的。。
操作平台:Visual Studio 2022
(搬运自周志华<<西瓜书>>。。。。)
决策树(decision tree)是一类常见的机器学习方法.以二分类任务为例,我们希望从给定训练数据集学得一个模型用以对新示例进行分类,这个把样本分类的任务,可看作对“当前样本属于正类吗?”这个问题的“决策”或“判定”过程.顾名思义,决策树是基于树结构来进行决策的,这恰是人类在面临决策问题时一种很自然的处理机制.例如,我们要对“这是好瓜吗?”这样的问题进行决策时,通常会进行一系列的判断或“子决策”:我们先看“它是什么颜色?”,如果是“青绿色”,则我们再看“它的根蒂是什么形态?”,如果是“蜷缩”,我们再判断“它敲起来是什么声音?”,最后,我们得出最终决策:这是个好瓜.这个决策过程如图:
显然,决策过程的最终结论对应了我们所希望的判定结果,例如“是”或“不是”好瓜;决策过程中提出的每个判定问题都是对某个属性的“测试”,例如“色泽=?”“根蒂=?”;每个测试的结果或是导出最终结论,或是导出进一步的判定问题,其考虑范围是在上次决策结果的限定范围之内,例如若在“色泽=青绿”之后再判断“根蒂=?”,则仅在考虑青绿色瓜的根蒂.
一般的,一棵决策树包含一个根结点、若干个内部结点和若干个叶结点;叶结点对应于决策结果,其他每个结点则对应于一个属性测试;每个结点包含的样本集合根据属性测试的结果被划分到子结点中;根结点包含样本全集.从根结点到每个叶结点的路径对应了一个判定测试序列.决策树学习的目的是为了产生一棵泛化能力强,即处理未见示例能力强的决策树,其基本流程遵循简单且直观的“分而治之”(divide-and-conquer)策略,如图4.2所示.
“信息熵”(information entropy)是度量样本集合纯度最常用的一种指标.假定当前样本集合D中第k类样本所占的比例为 $p_k$ ($k = 1,2,\ldots,|\mathcal{Y}|$),则D的信息熵定义为
$$\mathrm{Ent}(D) = -\sum_{k=1}^{|\mathcal{Y}|} p_k \log_2 p_k$$
Ent(D)的值越小,则D的纯度越高.值越高信息越混乱。
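假定离散属性a有V个可能的取值 $\{a^1, a^2, \ldots, a^V\}$,若使用a来对样本集D进行划分,则会产生V个分支结点,其中第v个分支结点包含了D中所有在属性a上取值为 $a^v$ 的样本,记为 $D^v$.我们可根据上式计算出 $D^v$ 的信息熵,再考虑到不同的分支结点所包含的样本数不同,给分支结点赋予权重 $|D^v|/|D|$,即样本数越多的分支结点的影响越大,于是可计算出用属性a对样本集D进行划分所获得的“信息增益”(information gain)
$$\mathrm{Gain}(D,a) = \mathrm{Ent}(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v)$$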
一般而言,信息增益越大,则意味着使用属性a来进行划分所获得的“纯度提升”越大.因此,我们可用信息增益来进行决策树的划分属性选择,即选择属性 $a_* = \arg\max_{a \in A} \mathrm{Gain}(D,a)$.著名的ID3决策树学习算法就是以信息增益为准则来选择划分属性.
默认Visual Studio 2022 都装好了,太久之前,也是临时装的,代码不多,直接粘上去的。
1.新建一个控制台应用
2.关闭SDL检查的方法是:项目->属性->C/C++->SDL检查,选择“否”,就可以将其关闭(毕竟纯C嘛,都是课本里的函数)
head.h
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
/* Number of attribute columns per training sample. The class label is stored
   in the column right after them, i.e. at index LineSize. */
#define LineSize 4
/* Training-data table. Despite the graph-style naming, the code in this
   program uses arc as a plain matrix with one row per sample. */
typedef struct map
{
int arc[20][20];/* sample matrix: columns 0..LineSize-1 hold attribute values, column LineSize holds the class label */
int numNodes; /* number of rows of arc that are currently filled */
}MGraph;
/* Node of the binary decision tree. */
typedef struct Node
{ // binary tree node layout
char data[10]; /* name of the attribute tested at this node (copied from list[] by createtree) */
int judge; /* decision value; createtree sets it on leaf nodes */
struct Node* lchild, * rchild; /* outputdata follows lchild for answer 0 and rchild otherwise */
}*BiTree, BiTNode;
/* Per-attribute statistics used while choosing the split attribute. */
typedef struct entarray {
float entsum; /* entropy of the class label over the current sample set */
float connet; /* conditional entropy of the class label given this attribute */
float addent; /* information gain: entsum - connet (see Addnet) */
int flag; /* 1 once the attribute has been used for a split, 0 otherwise */
}Entarray;
/* One user answer collected by inputdata. */
typedef struct answer {
int an; /* the value typed by the user (0 or 1 for a valid answer) */
char data[10]; /* name of the attribute the answer belongs to (copied from list[]) */
}Answer;
/* Read the training data file and fill the sample matrix */
void CreateMGraph(MGraph* G);
/* Interactive function menu */
void menu(MGraph* G, BiTNode* T, Entarray arr[]);
/* Print the sample matrix */
void show(MGraph* G);
/* Print the exit message */
void Exit();
/* Return the index of the unused attribute with the largest information gain, or -1 */
int sort(Entarray arr[]);
/* Build the decision tree */
BiTree createtree(MGraph* G, Entarray arr[], BiTNode* T, int dex, int dir);
/* Conditional entropy of the class label given attribute column j */
void conentropy(MGraph* G, int j, Entarray arr[]);
/* Entropy of the class label; also fills conditional entropy and gain for every attribute */
void entropy(MGraph* G, Entarray arr[]);
/* Information gain of attribute dex */
void Addnet(int dex, Entarray arr[]);
/* Print the decision tree */
void printTree(BiTNode* T, int type, int level);
/* Copy into GG the rows of G whose column dex equals m */
void updataG(MGraph* G, MGraph* GG, int dex, int m);
/* Decide whether the current sample set should become a leaf node */
int ifleaf(MGraph G, Entarray arr[]);
/* Ask the user for one set of answers and show the decision */
void inputdata(BiTNode* T);
/* Walk the tree with the given answers and print the path taken */
void outputdata(BiTNode* T, Answer answer[]);
test_tree.cpp
#include"head.h"
/*
 * Program entry point: sets up an empty sample table, an empty root node and
 * the per-attribute statistics, then hands control to the interactive menu.
 *
 * Everything is zero-initialised so that choosing "show tree" or "classify"
 * before the tree has been built reads defined values instead of stack
 * garbage (previously T.judge, T.data and G.numNodes were indeterminate).
 */
int main()
{
    MGraph G = { 0 };              /* numNodes == 0: no samples loaded yet */
    BiTNode T = { 0 };             /* empty root: no children, judge == 0 */
    Entarray arr[LineSize] = { 0 }; /* flag == 0: no attribute used yet */

    menu(&G, &T, arr);
    return 0;
}
func.cpp
#include"head.h"
#include
/* Reference tokens, one per column of a training line. CreateMGraph compares
   each field of a line against the matching token with strcmp, so a field
   equal to its token is stored as 0. */
char str1[] = "fail", str2[] = "hungry", str3[] = "early", str4[] = "sleepy", str5[] = "no";
/* Attribute names indexed by column; they label tree nodes and user answers.
   Only the first four entries are given explicit values. */
char list[10][10] = { "exam","hungry","time","condition" };
/*
 * Loads the training samples from D:\traindata.txt into G and prints them.
 *
 * Each line of the file holds five whitespace-separated words. Every word is
 * compared with its reference token (str1..str5) and stored as 0 when it
 * matches the token and 1 otherwise:
 *   column 0: 0 = fail,   1 = pass
 *   column 1: 0 = hungry, 1 = not hungry
 *   column 2: 0 = early,  1 = late
 *   column 3: 0 = sleepy, 1 = sober
 *   column 4: class label, 0 = "no"
 *
 * G->numNodes receives the number of rows read; it is 0 when the file cannot
 * be opened.
 */
void CreateMGraph(MGraph* G)
{
    /* Row capacity of the table, derived from the array itself. */
    const int maxRows = (int)(sizeof G->arc / sizeof G->arc[0]);
    char a[20], b[20], c[20], d[20], e[20];
    int i = 0;
    FILE* fp = fopen("D:\\traindata.txt", "r");

    if (fp == NULL)
    {
        printf("无法打开训练数据文件 D:\\traindata.txt\n");
        G->numNodes = 0;
        return;
    }

    /* %19s keeps every word inside its 20-byte buffer. Requiring exactly five
       converted fields stops on a truncated last line instead of reusing
       stale words from the previous row. */
    while (i < maxRows &&
           fscanf(fp, "%19s %19s %19s %19s %19s", a, b, c, d, e) == 5)
    {
        printf("%s %s %s %s %s\n", a, b, c, d, e);
        /* strcmp only guarantees the sign of its result, while the rest of
           the program tests these cells against 0 and 1, so normalise. */
        G->arc[i][0] = (strcmp(a, str1) != 0);
        G->arc[i][1] = (strcmp(b, str2) != 0);
        G->arc[i][2] = (strcmp(c, str3) != 0);
        G->arc[i][3] = (strcmp(d, str4) != 0);
        G->arc[i][4] = (strcmp(e, str5) != 0);
        i++;
    }
    fclose(fp);

    G->numNodes = i;
    show(G);
}
/*
 * Interactive main menu. Repeats until the user chooses 0 (or input ends):
 *   1 - load and show the training data
 *   2 - build the decision tree and show the entropies
 *   3 - print the decision tree
 *   4 - classify one set of answers typed by the user
 *
 * G, T and arr are the sample table, the tree root and the per-attribute
 * statistics shared by all menu actions.
 */
void menu(MGraph* G, BiTNode* T, Entarray arr[])
{
    int choice = 1;
    while (choice != 0)
    {
        printf("--------------------------------------------------------------------------------------------------\n");
        printf("\n\n 该系统将帮你决定是否打游戏\n");
        printf(" ################################################\n");
        printf(" 按键1:导入训练样本,显示训练数据\n");
        printf(" 按键2:构建决策树,并展示熵\n");
        printf(" 按键3:展示决策树\n");
        printf(" 按键4: 输入一组数据,展示结果\n");
        printf(" 按键0: 退出\n");
        printf(" ################################################\n");
        printf("--------------------------------------------------------------------------------------------------\n");
        printf("\n ====>>>>>请输入你的选择:");

        int got = scanf("%d", &choice);
        if (got == EOF)
        {
            /* No more input: leave instead of redrawing the menu forever. */
            Exit();
            break;
        }
        if (got != 1)
        {
            /* Non-numeric input stays in the stream and would make every
               later scanf fail too; drop the rest of the line and report
               it as an invalid choice. */
            int ch;
            while ((ch = getchar()) != '\n' && ch != EOF)
            {
            }
            choice = -1;
        }

        switch (choice)
        {
        case 1:CreateMGraph(G); break;
        case 2:createtree(G, arr, T, 0, 0); break; /* the dex/dir arguments are recomputed for the root */
        case 3:printTree(T, 0, 0); break;
        case 4:inputdata(T); break;
        case 0:Exit(); break;
        /* was misspelled "defaul:", which made it an unused label */
        default:printf("错误选择,请重新选择。\n"); break;
        }
    }
}
/* Prints the farewell banner shown when the user leaves the menu. */
void Exit()
{
    static const char farewell[] = "\n =====>已成功退出<===== \n";

    printf("%s", farewell);
}
/*
 * Prints the sample matrix: a header line followed by one line per stored
 * row, each showing the LineSize attribute columns and the class column.
 */
void show(MGraph* G)
{
    const int columns = LineSize + 1;
    int row = 0;

    printf("\n\n是否通过考试 饿不饿 是否很晚 是否精神 是否玩游戏\n");
    while (row < G->numNodes)
    {
        const int* cells = G->arc[row];
        for (int col = 0; col < columns; ++col)
        {
            printf(" %d ", cells[col]);
        }
        printf("\n");
        ++row;
    }
}
/*
 * Picks the split attribute: among the attributes whose flag is 0, returns
 * the index of the one with the largest information gain (the first one on
 * ties). Returns -1 when no unused attribute exists or the best gain found
 * is negative.
 */
int sort(Entarray arr[])
{
    int best = -1;
    float bestGain = -1;

    for (int idx = 0; idx < LineSize; ++idx)
    {
        if (arr[idx].flag != 0)
        {
            continue; /* already used for a split */
        }
        if (arr[idx].addent > bestGain)
        {
            bestGain = arr[idx].addent;
            best = idx;
        }
    }
    return (bestGain < 0) ? -1 : best;
}
BiTree createtree(MGraph* G, Entarray arr[], BiTNode* T, int dex, int dir) /*创建决策树 0*/
{
int count = 0;
for (int i = 0; i < LineSize; i++)
{
if (arr[i].flag == 1)/*flag为0时表示都没有使用过*/
count++;
}
if (count == 0)/*判断是否为根节点*/
{
entropy(G, arr); /*信息增益最大的*/
dex = sort(arr);
printf("\n以下为第%d次计算增益熵,熵,条件熵\n", count);
for (int k = 0; k < 4; k++)
printf("%s: %f %f %f\n", list[k], arr[k].addent, arr[k].entsum, arr[k].connet);
printf("@@@@@@@@@增益熵最大位置:dex=%d\n\n", dex);
strcpy(T->data, list[dex]);//已完成根节点赋值
arr[dex].flag = 1;/*标注为1,记为已使用过*/
T->lchild = createtree(G, arr, T->lchild, dex, 0);/*dex=3*/
T->rchild = createtree(G, arr, T->rchild, dex, 1);
}
else if (count > 0 && count < 4)/*不为根节点*/
{
int out;
Entarray brr0[LineSize], brr1[LineSize];
for (int j = 0; j < 4; j++)/*将arr赋给brr $*/
{
brr1[j].addent = brr0[j].addent = arr[j].addent;
brr1[j].connet = brr0[j].connet = arr[j].connet;
brr1[j].entsum = brr0[j].entsum = arr[j].entsum;
brr1[j].flag = brr0[j].flag = arr[j].flag;
}
MGraph G0, G1;
int dex0, dex1;
if (dir == 0)
{
updataG(G, &G0, dex, 0);/*0否左,第一趟时dex=3*/
out = ifleaf(G0, brr0);
printf("左leaf:%d\n----------\n", out);/*-1*/
if (out >= 0)/*判断是否建立叶节点*/
{
BiTNode* t;
t = (BiTNode*)malloc(sizeof(BiTNode));
t->judge = out;
t->lchild = t->rchild = NULL;
return t;
}
entropy(&G0, brr0);/*这里brr会变化,且必须在update之后*/
dex0 = sort(brr0);/*sort要根据arr数组里的值来确定下一个dex*/
printf("\n这是G0:\n");
show(&G0);
printf("\n以下为第%d次计算增益熵,熵,条件熵\n", count);
for (int k = 0; k < LineSize; k++)
printf("%s: %f %f %f\n", list[k], brr0[k].addent, brr0[k].entsum, brr0[k].connet);
printf("@@@@@@@@@@增益熵最大位置:dex=%d\n\n", dex0);
if (dex0 == -1)
return NULL;
else
{
BiTNode* t;
t = (BiTNode*)malloc(sizeof(BiTNode));
strcpy(t->data, list[dex0]);
brr0[dex0].flag = 1;
t->lchild = createtree(&G0, brr0, t->lchild, dex0, 0);/*dex=1*/
t->rchild = createtree(&G0, brr0, t->rchild, dex0, 1);
return t;
}
}
else
{
updataG(G, &G1, dex, 1);/*1是右*/
out = ifleaf(G1, brr1);
printf("右leaf:%d\n----------\n", out);
if (out >= 0)
{
BiTNode* t;
t = (BiTNode*)malloc(sizeof(BiTNode));
t->judge = out;
t->lchild = t->rchild = NULL;
return t;
}
entropy(&G1, brr1);
dex1 = sort(brr1);
printf("\n这是G1:\n");
show(&G1);
printf("以下为第%d次计算增益熵,熵,条件熵\n", count);
for (int k = 0; k < LineSize; k++)
printf("%s: %f %f %f\n", list[k], brr1[k].addent, brr1[k].entsum, brr1[k].connet);
printf("增益熵最大位置:dex=%d\n\n", dex1);
if (dex1 == -1)
return NULL;
else
{
BiTNode* t;
t = (BiTNode*)malloc(sizeof(BiTNode));
strcpy(t->data, list[dex1]);
brr1[dex1].flag = 1;
t->lchild = createtree(&G1, brr1, t->lchild, dex1, 0);
t->rchild = createtree(&G1, brr1, t->rchild, dex1, 1);
return t;
}
}
}
else
{
return NULL;
}
}
/*
 * Computes, for the sample set G, the entropy of the class label and stores
 * it in arr[j].entsum for every attribute j; then fills arr[j].connet and
 * arr[j].addent through conentropy and Addnet.
 *
 * The class entropy does not depend on the attribute, so it is computed once
 * instead of once per attribute. It is 0 when the set is empty or contains
 * a single class.
 */
void entropy(MGraph* G, Entarray arr[])
{
    float num1 = 0, num0 = 0; /* samples with class label 1 / 0 */
    float sum, ent = 0;

    for (int i = 0; i < G->numNodes; i++)
    {
        if (G->arc[i][LineSize] == 1)
            num1++;
        if (G->arc[i][LineSize] == 0)
            num0++;
    }
    sum = num1 + num0;

    /* Both classes present implies sum > 0, so no 0/0 is evaluated. */
    if (num1 > 0 && num0 > 0)
        ent = -(num1 / sum) * log2(num1 / sum) - (num0 / sum) * log2(num0 / sum);

    for (int j = 0; j < LineSize; j++)
    {
        arr[j].entsum = ent;
        conentropy(G, j, arr);
        Addnet(j, arr);
    }
}
/* Entropy of a two-way split given the counts of its two outcomes; 0 when
   either outcome is absent. */
static float conentropy_branch(float first, float second)
{
    float total = first + second;

    if (first == 0 || first == total)
        return 0;
    return -(first / total) * log2(first / total) - (second / total) * log2(second / total);
}

/*
 * Stores in arr[j].connet the conditional entropy of the class label given
 * attribute column j: the entropies of the "attribute == 0" and
 * "attribute == 1" groups, weighted by group size.
 */
void conentropy(MGraph* G, int j, Entarray arr[])
{
    /* numAB: rows with attribute value A and class label B */
    float num00 = 0, num01 = 0, num10 = 0, num11 = 0;
    float sum0, sum1, ent0, ent1;

    for (int row = 0; row < G->numNodes; ++row)
    {
        int attr = G->arc[row][j];
        int label = G->arc[row][LineSize];

        if (attr == 0 && (label == 0 || label == 1))
        {
            if (label == 0)
                num00++;
            else
                num01++;
        }
        else if (attr == 1 && label == 0)
        {
            num10++;
        }
        else
        {
            num11++; /* every remaining combination is counted here */
        }
    }

    sum0 = num00 + num01;
    sum1 = num10 + num11;
    ent0 = conentropy_branch(num00, num01);
    ent1 = conentropy_branch(num10, num11);

    arr[j].connet = (sum0 / (sum0 + sum1)) * ent0 + (sum1 / (sum0 + sum1)) * ent1;
}
/* Stores the information gain of attribute dex: entropy minus conditional
   entropy. */
void Addnet(int dex, Entarray arr[])
{
    Entarray* stats = &arr[dex];

    stats->addent = stats->entsum - stats->connet;
}
/* Emits one tab per nesting level. */
static void printTree_indent(int depth)
{
    while (depth-- > 0)
        printf("\t");
}

/* Prints the decision value for a node without children, otherwise the
   attribute name. */
static void printTree_label(BiTNode* node)
{
    if (node->lchild != NULL || node->rchild != NULL)
        printf(" %s\n", node->data);
    else
        printf(" %d\n", node->judge);
}

/*
 * Prints the tree sideways: right subtree first, then the node, then the
 * left subtree.
 *
 * type   0 = root, 1 = left child, 2 = right child
 * level  depth of the node, used as the indentation width
 */
void printTree(BiTNode* T, int type, int level)
{
    if (T == NULL)
        return;

    printTree(T->rchild, 2, level + 1);

    if (type == 0)
    {
        printTree_label(T);
    }
    else if (type == 1)
    {
        /* left child: connector line first, label below it */
        printTree_indent(level);
        printf("\\\n");
        printTree_indent(level);
        printTree_label(T);
    }
    else if (type == 2)
    {
        /* right child: label first, connector line below it */
        printTree_indent(level);
        printTree_label(T);
        printTree_indent(level);
        printf("/\n");
    }

    printTree(T->lchild, 1, level + 1);
}
/*
 * Filters the sample matrix: copies into g, in order, every row of G whose
 * column dex holds the value m, and sets g->numNodes to the number of rows
 * copied.
 */
void updataG(MGraph* G, MGraph* g, int dex, int m)
{
    int kept = 0;
    int row = 0;

    g->numNodes = 0;
    while (row < G->numNodes)
    {
        if (G->arc[row][dex] == m)
        {
            /* copy the attribute columns and the class column */
            for (int col = 0; col <= LineSize; ++col)
            {
                g->arc[kept][col] = G->arc[row][col];
            }
            ++kept;
            g->numNodes++;
        }
        ++row;
    }
}
/*
 * Decides whether the sample set G should become a leaf.
 *
 * G    sample set of the branch (already filtered by updataG)
 * arr  per-attribute statistics; only the flag fields are read
 *
 * Returns the leaf's decision (0 or 1), or -1 when the branch must be split
 * further:
 *   - an empty set, or a set holding a single class, is a leaf with that
 *     class (0 for the empty set);
 *   - when LineSize - 1 attributes are already used, the branch becomes a
 *     majority-vote leaf (ties count as 1). createtree refuses to build a
 *     node once all LineSize attributes are flagged, so this is the deepest
 *     level at which a leaf can still be attached;
 *   - otherwise -1.
 */
int ifleaf(MGraph G, Entarray arr[])
{
    int count = 0;            /* attributes already used on this path */
    float num1 = 0, num0 = 0; /* samples with class label 1 / 0 */
    float sum, ent = 0;

    for (int i = 0; i < G.numNodes; i++)
    {
        if (G.arc[i][LineSize] == 1)
            num1++;
        if (G.arc[i][LineSize] == 0)
            num0++;
    }
    sum = num1 + num0;

    /* Both classes present implies sum > 0, so no 0/0 is evaluated; in every
       other case the entropy is 0. */
    if (num1 > 0 && num0 > 0)
        ent = -(num1 / sum) * log2(num1 / sum) - (num0 / sum) * log2(num0 / sum);

    for (int j = 0; j < LineSize; j++)
    {
        if (arr[j].flag == 1)
            count++;
    }
    printf("子函数中熵leaf:%f,数据数量:%f,count:%d\n", ent, sum, count);

    if (ent == 0)
    {
        /* pure (or empty) set: the leaf takes the only class present */
        return (num1 > 0) ? 1 : 0;
    }
    if (count == LineSize - 1)
    {
        /* last level that can hold a leaf: majority vote */
        return (num1 >= num0) ? 1 : 0;
    }
    return -1; /* keep splitting */
}
/*
 * Asks the user one question per attribute (in list[] order) and, when all
 * of them were answered with 0 or 1, classifies the answers with the tree T
 * through outputdata. Any other answer, or unreadable input, cancels the
 * dialogue.
 */
void inputdata(BiTNode* T)
{
    /* One prompt per attribute. The answer codes must match the encoding
       produced by CreateMGraph (0 = token matched). The Q2 choices used to
       be swapped: the training data stores 0 for "early" and 1 for "late". */
    static const char* const questions[LineSize] = {
        "Q0:期末考过了吗?\n[0]没过 [1]过了 [其他]退出\nanswer:",
        "\nQ1:现在饿不饿?\n[0]饿 [1]不饿 [其他]退出\nanswer:",
        "\nQ2:现在时间是否很晚了?\n[0]还早 [1]晚 [其他]退出\nanswer:",
        "\nQ3:状态怎么样?\n[0]想睡觉 [1]精神十足 [其他]退出\nanswer:",
    };
    Answer answer[LineSize];
    int a;

    printf("请回答以下问题:\n");
    for (a = 0; a < LineSize; a++)
    {
        printf("%s", questions[a]);
        if (scanf("%d", &answer[a].an) != 1)
        {
            /* Not a number: treat it as "exit" and drop the rest of the line
               so the menu does not trip over it afterwards. */
            int ch;
            while ((ch = getchar()) != '\n' && ch != EOF)
            {
            }
            break;
        }
        strcpy(answer[a].data, list[a]);
        if (answer[a].an > 1 || answer[a].an < 0)
            break;
    }

    if (a == LineSize)
    {
        outputdata(T, answer);
    }
    else
        printf("====>退出成功<====\n");
}
/*
 * Classifies one set of answers: walks the tree from T, at every inner node
 * following the left child when the answer for that node's attribute is 0
 * and the right child otherwise, printing each step and finally the decision
 * stored in the leaf reached.
 *
 * answer must hold LineSize entries, each tagged with its attribute name.
 * The walk stops with a message when a node's attribute has no answer or
 * when the chosen branch does not exist.
 */
void outputdata(BiTNode* T, Answer answer[])
{
    BiTNode* p = T;

    while (p != NULL)
    {
        int hit = -1;

        /* Test for a leaf before touching p->data: leaves carry no
           attribute name. */
        if (p->lchild == NULL && p->rchild == NULL)
        {
            printf("决策为:<%d> (0为否,1为是)\n", p->judge);
            return;
        }

        for (int i = 0; i < LineSize; i++)
        {
            if (strcmp(p->data, answer[i].data) == 0)
            {
                hit = i;
                break;
            }
        }
        if (hit < 0)
        {
            /* Without this the loop would spin forever on the same node. */
            printf("\n节点 %s 没有对应的回答,无法决策。\n", p->data);
            return;
        }

        printf("%s-%d-> ", p->data, answer[hit].an);
        p = (answer[hit].an == 0) ? p->lchild : p->rchild;
    }

    /* Reached a missing branch (or an empty tree). */
    printf("\n决策树在该分支上没有结论。\n");
}
大一时写的代码,不知道代码有没有错误,还望指正。
附代码和数据集:链接: link
参考文献:
周志华<<机器学习>>