

问题描述:每条数据共有17个属性(列),背景是美国选举;任务是根据除 Class Name(类别标签列)之外的16个属性,判断该记录对应的人属于哪个党派(republican 或 democrat)。

// Column names of the data set, in file order. Index 0 ("Class Name") names the
// label column (the code compares column 0 against "republican"); indices 1-16
// are the attributes used for prediction.
// NOTE(review): this global array is not referenced by any visible code. main()
// declares a local `temp` with identical contents, which shadows this one, and
// Input()/Inputtest() declare their own local `temp` vectors. Consider deleting
// it, or having main() reuse it instead of duplicating the list.
std::string temp[17] = { "Class Name", "handicapped-infants", "water-project-cost-sharing",
"adoption-of-the-budget-resolution", "physician-fee-freeze",
"el-salvador-aid", "religious-groups-in-schools", "anti-satellite-test-ban",
"aid-to-nicaraguan-contras", "mx-missile", "immigration", "synfuels-corporation-cutback",
"education-spending", "superfund-right-to-use", "crime", "duty-free-exports",
"export-administration-act-south-africa" };


// NOTE(review): the template arguments on the declarations below appear to have been
// stripped. For example, `vector  > state;` and `map > ...` are not valid C++ as written.
// Judging by usage, they are presumably vector<vector<string> >, vector<string> and
// map<string, vector<string> >. Restore them before compiling.
using namespace std;
#define MAXLEN 17// number of fields in each data row (label column + 16 attributes)


vector  > state;    // training instances, indexed as state[row][column] (presumably filled by Input() from train.txt)
vector > teststate;  // test instances (presumably filled by Inputtest() from test.txt)
vector  item(MAXLEN);// one row of an instance; NOTE(review): not referenced by any visible code
vector  attribute_row;// header row: the attribute (column) names
string blank("");// empty string: the "unset" value for Node fields and a dummy value argument to ComputeEntropy()
map > map_attribute_values;// attribute name -> every distinct value seen for it (filled by ComputeMap())
int tree_size = 0;// node count printed by main(); presumably accumulated by Treesize()
// A node of the decision tree. Internal nodes store the name of the attribute they
// split on. Leaves store a party label instead ("republican" or "democrat").
struct Node{// decision-tree node
	string attribute;// attribute tested at this node, or the party label for a leaf
	string arrived_value;// value of the parent's split attribute on the edge into this node (left blank for the root)
	vector childs;// child nodes, presumably one per value of `attribute` (none for leaves)
	// NOTE(review): the two assignments below look like the body of a default constructor.
	// Its `Node(){` header and the closing `}` / `};` lines are missing. The template
	// argument of `childs` (presumably Node *) is also missing.
		attribute = blank;
		arrived_value = blank;
// Root of the tree built by main(). As a namespace-scope pointer it is zero-initialized
// (NULL), which BulidDecisionTreeDFS() treats as "allocate a new node".
Node * root;

// Fills map_attribute_values. For each of the MAXLEN columns i, it collects the distinct
// values of state[j][i] over all rows j of `state`, then stores that list under the
// column's name attribute_row[i].
// NOTE(review): two things are missing here: the statement inside `if (!exited){`
// (presumably `values.push_back(state[j][i]);`) and the closing braces.
// NOTE(review): j starts at 0, whereas ComputeEntropy() starts at row 1. If state[0]
// holds a header row, each attribute's own name is recorded as one of its values.
// Confirm this.
void ComputeMap(){
	unsigned int i, j, k;
	bool exited = false;// set when state[j][i] is already present in `values`
	vector values;// distinct values found so far for the current column
	for (i = 0; i < MAXLEN ; i++){// walk the data column by column
		for (j = 0; j < state.size(); j++){
			for (k = 0; k < values.size(); k++){// linear search for a duplicate
				if (!values[k].compare(state[j][i])) exited = true;
			if (!exited){
			exited = false;// reset the flag for the next row
		map_attribute_values[attribute_row[i]] = values;
		values.erase(values.begin(), values.end());// clear before processing the next column

// Binary entropy, in bits, of the party labels among the rows of remain_state.
//   ifparent == true : every row is counted (entropy of the whole node);
//   ifparent == false: only rows whose `attribute` column equals `value` are counted.
// Rows are scanned from index 1, so row 0 is skipped (presumably the header row).
// The attribute's column is looked up by name in attribute_row[1..MAXLEN-1]. Even with
// ifparent == true, rows are only counted once that column is found. If the column is
// not found, nothing is counted and 0 is returned.
// Returns 0 when the counted rows are all "republican", all non-"republican", or empty.
// NOTE(review): two things appear to be missing: the statement that increments count[0]
// for "republican" rows, and the closing braces. remain_state is passed by value, so the
// whole data set is copied on every call; a const reference would avoid that.
double ComputeEntropy(vector  > remain_state, string attribute, string value, bool ifparent){
	//vector count(2, 0);
	int count[2] = { 0 };// [0]: "republican" rows (see note above), [1]: all other rows
	unsigned int i, j;
	bool done_flag = false;// sentinel: set once the attribute's column has been processed
	for (j = 1; j < MAXLEN; j++){
		if (done_flag) break;
		if (!attribute_row[j].compare(attribute)){
			for (i = 1; i < remain_state.size(); i++){
				if ((!ifparent&&!remain_state[i][j].compare(value)) || ifparent){// ifparent: count every row (parent node) instead of filtering by value
					if (!remain_state[i][0].compare("republican")){
					else count[1]++;
			done_flag = true;
	if (count[0] == 0 || count[1] == 0) return 0;// all instances positive or all negative (or none counted)
	// Entropy from the counts [+count[0], -count[1]]. log2 is obtained from the natural log
	// via the change-of-base formula log2(x) = ln(x) / ln(2).
	double sum = count[0] + count[1];
	double entropy = -count[0] / sum*log(count[0] / sum) / log(2.0) - count[1] / sum*log(count[1] / sum) / log(2.0);
	return entropy;

// Information gain of splitting remain_state on `attribute`:
//   gain = H(parent) - sum over values v of (|rows with v| / |data rows|) * H(rows with v)
// Candidate values come from the global map_attribute_values, i.e. all values seen by
// ComputeMap(). A value that is absent from remain_state therefore gets weight 0.
// The denominator is remain_state.size() - 1, so one row (presumably the header row 0)
// is not counted as data. This assumes at least one data row; otherwise it divides by 0.
// NOTE(review): several lines appear to be missing: the body of the innermost `if`
// (presumably `tempint++;`), the `count_values.push_back(tempint);` after the k loop,
// and the closing braces.
// map_attribute_values[attribute] inserts an empty entry if the attribute is unknown.
// remain_state is copied on every call.
double ComputeGain(vector  > remain_state, string attribute){
	unsigned int j, k, m;
	double parent_entropy = ComputeEntropy(remain_state, attribute, blank, true);// entropy before the split
	double children_entropy = 0;// weighted entropy after the split
	vector values = map_attribute_values[attribute];
	vector ratio;// fraction of data rows taking each value
	vector count_values;// number of rows taking each value
	int tempint;
	for (m = 0; m < values.size(); m++){
		tempint = 0;
		for (k = 1; k < MAXLEN ; k++){// locate the attribute's column
			if (!attribute_row[k].compare(attribute)){
				for (j = 0; j < remain_state.size(); j++){
					if (!remain_state[j][k].compare(values[m])){

	for (j = 0; j < values.size(); j++){
		ratio.push_back((double)count_values[j] / (double)(remain_state.size() - 1));
	double temp_entropy;
	for (j = 0; j < values.size(); j++){
		temp_entropy = ComputeEntropy(remain_state, attribute, values[j], false);
		children_entropy += ratio[j] * temp_entropy;
	return (parent_entropy - children_entropy);

// Returns the column index of attribute `attri` in attribute_row, searching columns
// 0..MAXLEN-1. If the name is not found, it prints an error to cerr and returns 0. That
// is the label column, so callers cannot tell "not found" apart from a real match on
// column 0.
// NOTE(review): the closing braces of the loop and the function are missing here.
int FindAttriNumByName(string attri){
	for (int i = 0; i < MAXLEN; i++){
		if (attribute_row[i]==attri) return i;
	cerr << "can't find the numth of attribute" << endl;
	return 0;

// Returns the majority party label among the rows of remain_state. The result is
// "republican" if at least as many rows are labelled "republican" as not (ties go to
// "republican"), otherwise "democrat". Any label other than "republican" counts toward
// "democrat".
// NOTE(review): the loop starts at row 0, whereas ComputeEntropy() skips row 0. If row 0
// is a header row (label cell "Class Name"), it is counted as a democrat vote here. That
// turns an exact tie among the data rows into "democrat". Consider starting at 1.
// The closing braces of the loop and the function are missing here.
string MostCommonLabel(vector  > remain_state){
	int p = 0, n = 0;// p: "republican" rows, n: all other rows
	for (unsigned i = 0; i < remain_state.size(); i++){
		if (!remain_state[i][0].compare("republican")) p++;
		else n++;
	if (p >= n) return "republican";
	else return "democrat";

// Returns true when exactly remain_state.size() - 1 rows have `label` in column 0.
// This is intended to mean "all data rows carry `label`", on the assumption that row 0
// is a header row whose label cell never matches. Two edge cases follow from that:
//   - a remain_state holding only such a header row returns true for any label;
//   - an empty remain_state returns false, because size() - 1 wraps around.
// NOTE(review): `mark` is unused, and `count` (int) is compared with an unsigned size_t.
// The closing braces of the loop and the function are missing here.
bool AllTheSameLabel(vector  > remain_state, string label){
	int count = 0;
	bool mark = false;
	for (unsigned int i = 0; i < remain_state.size(); i++){
		if (!remain_state[i][0].compare(label)) count++;
	if (count == remain_state.size() - 1) return true;
	else return false;

// Recursively grows an ID3-style decision tree at p (allocated here if p is NULL) from
// the rows of remain_state, splitting only on the names in remain_attribute:
//   - AllTheSameLabel(..., "republican" / "democrat") -> p becomes a leaf with that label;
//   - no attributes left -> p becomes a leaf with the majority label;
//   - otherwise p splits on the attribute with the highest information gain, and one
//     child node is created per value of that attribute in map_attribute_values.
// Returns p.
// NOTE(review): several lines appear to be missing from this block:
//   - the closing braces;
//   - the statement adding matching rows to new_state (presumably
//     `new_state.push_back(remain_state[i]);`);
//   - the `} else` between the leaf case and the recursive call;
//   - the line attaching new_node to p->childs.
// NOTE(review): the erase at the end keeps new_state[0], so new_state is presumably
// seeded with a header row that is not visible here. In that case new_state.size() is
// never 0, and an empty branch recurses with only the header row. There,
// AllTheSameLabel(..., "republican") returns true, so empty branches become "republican"
// leaves rather than the parent's majority. Consider testing `new_state.size() <= 1`.
// Conversely, if new_state were ever empty at the erase, new_state.begin() + 1 would be
// out of range (undefined behavior).
// NOTE(review): nodes are allocated with raw `new` and are never freed in visible code.
// remain_state and remain_attribute are copied on every call.
Node * BulidDecisionTreeDFS(Node * p, vector  > remain_state, vector  remain_attribute){
	if (p == NULL)
		p = new Node();
	if (AllTheSameLabel(remain_state, "republican")){// pure republican node -> leaf
		p->attribute = "republican";
		return p;
	if (AllTheSameLabel(remain_state, "democrat")){// pure democrat node -> leaf
		p->attribute = "democrat";
		return p;
	if (remain_attribute.size() == 0){// all attributes have been used, but both parties are still present
		string label = MostCommonLabel(remain_state);
		p->attribute = label;
		return p;

	// Pick the remaining attribute with the largest information gain. max_gain starts at 0
	// and max_it at the first attribute, so that attribute is used if no gain is positive.
	double max_gain = 0, temp_gain;
	vector ::iterator max_it = remain_attribute.begin();
	vector ::iterator it1;
	for (it1 = remain_attribute.begin(); it1 < remain_attribute.end(); it1++){
		temp_gain = ComputeGain(remain_state, (*it1));
		if (temp_gain > max_gain) {
			max_gain = temp_gain;
			max_it = it1;
	// Attributes still available to the children: everything except the chosen one.
	vector  new_attribute;
	vector  > new_state;
	for (vector ::iterator it2 = remain_attribute.begin(); it2 < remain_attribute.end(); it2++){
		if ((*it2).compare(*max_it)) new_attribute.push_back(*it2);
	p->attribute = *max_it;
	vector  values = map_attribute_values[*max_it];
	int attribue_num = FindAttriNumByName(*max_it);
	// One child per value of the chosen attribute, built from the rows taking that value.
	for (vector ::iterator it3 = values.begin(); it3 < values.end(); it3++){
		for (unsigned int i = 0; i < remain_state.size(); i++){
			if (!remain_state[i][attribue_num].compare(*it3)){
		Node * new_node = new Node();
		new_node->arrived_value = *it3;
		if (new_state.size() == 0){// no instances for this branch: new_node becomes a leaf
			new_node->attribute = MostCommonLabel(remain_state);
			BulidDecisionTreeDFS(new_node, new_state, new_attribute);
		// After the recursion returns (backtracking): attach the new node to the parent's children and clear new_state.
		new_state.erase(new_state.begin() + 1, new_state.end());// drop the previous value's rows (keeping element 0) before collecting the next value's rows
	return p;

// Reads the training file "train.txt". Each whitespace-separated token is treated as one
// record and split on ',' into the fields temp[0..16]. Records whose first field is
// non-empty are presumably appended to `state` and counted in `lines`; that body is not
// visible. Finally, the function prints the number of records read.
// NOTE(review): missing lines: the per-field `i1++` (presumably), the body of
// `if (temp[0] != "")`, and the closing braces.
// NOTE(review): further issues:
//   * The stream is opened with in|out. That fails if train.txt does not exist, and the
//     write mode is not needed for reading. Open failure is not checked. On a failed
//     stream, `in >> buffer` never sets eofbit, so `while (!in.eof())` never terminates.
//   * `while (!in.eof())` tests EOF before reading, so the last iteration can see an
//     empty buffer. The `temp[0] != ""` check is what filters that case out.
//   * Nothing bounds i1 by 17, so a record with more than 17 fields would write past the
//     end of temp.
//   * `len` is unused, and this parsing loop is duplicated in Inputtest().
void Input(){
	std::fstream in("train.txt", std::fstream::in | std::fstream::out);
	int lines = 0;
	while (!in.eof()) {
		std::vector temp(17);// fields of the current record
		std::string buffer;
		int begin_sign = 0;// start index of the current field in buffer
		int len;
		in >> buffer;
		int i1 = 0;// index of the field being filled
		for (int i = 0; i <= buffer.length(); i++) {// i == length() acts as a final separator
			if (i == buffer.length() || buffer[i] == ',') {
				temp[i1] = buffer.substr(begin_sign, i - begin_sign);
				begin_sign = i + 1;
		if (temp[0] != "") {
	cout << lines << endl;
// Returns true if the tree rooted at `root` classifies the test row `teststate`
// correctly, i.e. following teststate's values leads to a leaf labelled teststate[0].
// For each child of root:
//   - if the child is a leaf (its attribute is a party label), the result is true when
//     its label equals teststate[0] and its arrived_value equals teststate's value for
//     root's split attribute;
//   - otherwise, if the child's arrived_value matches that value, the search recurses
//     into the child and returns the child's result.
// Returns false when nothing matches, e.g. for a value that never appeared in training.
// NOTE(review): if root itself is a leaf (e.g. a tree that is a single leaf), it has no
// children, so this returns false for every row even when root's label is correct.
// A NULL root prints an error and calls exit(0), i.e. exits with a success status.
// `istrue` is redundant: it is never true when reaching the final `return istrue;`.
// FindAttriNumByName(root->attribute) is recomputed on every iteration.
// The closing braces are missing here.
// The parameter shadows the global `teststate` and is copied on every call.
bool Judge(Node* root, vector  teststate){
	if (!root){ cout << "error tree!"; exit(0); }
	bool istrue = false;
	for (vector::iterator it = root->childs.begin(); it != root->childs.end(); it++){         // visit each child
		if (((*it)->attribute == "republican" || (*it)->attribute == "democrat")){                  // leaf node, i.e. a party label
			int num = FindAttriNumByName(root->attribute);
			if ((*it)->attribute == teststate[0] && (*it)->arrived_value==teststate[num]) {         
				istrue = true;
				return true;
		int sub = FindAttriNumByName(root->attribute);                                          // non-leaf: continue down the branch whose arrived_value matches
		if ((*it)->arrived_value == teststate[sub]){
			return	Judge(*it, teststate);
	return istrue;

// Reads the test file "test.txt" with the same parsing as Input(). Records are presumably
// appended to `teststate` and counted in `lines`. The function then prints the record
// count and returns the fraction of test records for which Judge(root, ...) is true,
// i.e. the accuracy.
// NOTE(review): missing lines: the per-field `i1++` (presumably), the body of
// `if (temp[0] != "")`, the `count++` controlled by the Judge check (presumably), and the
// closing braces.
// NOTE(review): the same caveats as Input() apply: the in|out open mode (a missing file
// makes the eof loop spin forever), the loop controlled by eof(), the unbounded i1, and
// the unused `len`. If no records are read, 0.0 / 0 yields NaN.
// The parameter `root` shadows the global of the same name.
double Inputtest(Node* root){
	std::fstream in("test.txt", std::fstream::in | std::fstream::out);
	int lines = 0;
	while (!in.eof()) {
		std::vector temp(17);// fields of the current record
		std::string buffer;
		int begin_sign = 0;// start index of the current field in buffer
		int len;
		in >> buffer;
		int i1 = 0;// index of the field being filled
		for (int i = 0; i <= buffer.length(); i++) {// i == length() acts as a final separator
			if (i == buffer.length() || buffer[i] == ',') {
				temp[i1] = buffer.substr(begin_sign, i - begin_sign);
				begin_sign = i + 1;
		if (temp[0] != "") {
	int count = 0;// number of correctly classified test records
	cout << lines << endl;
	for (int a = 0; a < lines; a++){
		if (Judge(root, teststate[a]))
	return count*1.0 / lines*1.0;

// Prints the subtree rooted at p to cout, one item per line, indented with one tab per
// level. If the node has an arrived_value, it is printed at `depth` tabs and the node's
// attribute (split attribute or leaf label) at depth + 1 tabs. A node without an
// arrived_value (e.g. the root) prints its attribute at `depth` tabs. Children are
// printed at depth + 1.
// NOTE(review): the closing braces are missing here. p is dereferenced unchecked, so it
// must not be NULL.
void PrintTree(Node *p, int depth){
	for (int i = 0; i < depth; i++) cout << '\t';// indent by the node's depth first
	if (!p->arrived_value.empty()){
		cout << p->arrived_value << endl;
		for (int i = 0; i < depth + 1; i++) cout << '\t';// indent one level deeper for the attribute line
	cout << p->attribute << endl;
	for (vector::iterator it = p->childs.begin(); it != p->childs.end(); it++){
		PrintTree(*it, depth + 1);

// Presumably counts the nodes of the subtree rooted at p into the global tree_size.
// NOTE(review): most of the body is not visible. Missing are:
//   - the statement guarded by `if (p == NULL)` (presumably `return;`);
//   - the increment of tree_size;
//   - the loop body (presumably a recursive call on *it);
//   - the closing braces.
// As shown, the `for` loop would be the statement controlled by `if (p == NULL)` and
// would dereference NULL.
// NOTE(review): no call to Treesize() is visible in main(). If it is never called,
// main() prints tree_size as 0.
void Treesize(Node *p){
	if (p == NULL)
	for (vector::iterator it = p->childs.begin(); it != p->childs.end(); it++){

// Entry point. It sets up the attribute names, builds the decision tree from the training
// rows, and prints the tree followed by tree_size.
// NOTE(review): the loop bodies on the `for` lines below and the closing brace are not
// visible. Presumably:
//   - the first loop fills attribute_row with all 17 names;
//   - the second fills remain_attribute with the 16 predictive attributes (index 0,
//     "Class Name", is skipped);
//   - the third copies `state` into remain_state.
// No calls to Input(), ComputeMap(), Treesize() or Inputtest() are visible.
// BulidDecisionTreeDFS() reads rows derived from `state` and reads map_attribute_values,
// so Input() and ComputeMap() must run first. Confirm they are called in the missing lines.
// The local `temp` duplicates, and shadows, the global array of the same name.
int main(){
	vector  remain_attribute;// attributes still available for splitting
	std::string temp[17] = { "Class Name", "handicapped-infants", "water-project-cost-sharing",
		"adoption-of-the-budget-resolution", "physician-fee-freeze",
		"el-salvador-aid", "religious-groups-in-schools", "anti-satellite-test-ban",
		"aid-to-nicaraguan-contras", "mx-missile", "immigration", "synfuels-corporation-cutback",
		"education-spending", "superfund-right-to-use", "crime", "duty-free-exports",
		"export-administration-act-south-africa" };
	for (int a = 0; a < 17; a++){
	for (int a = 1; a < 17; a++){

	vector  > remain_state;// working copy of the training rows
	for (unsigned int i = 0; i < state.size(); i++){
	root = BulidDecisionTreeDFS(root, remain_state, remain_attribute);
	cout << "the decision tree is :" << endl;
	PrintTree(root, 0);
	cout << endl;
	cout << "tree_size:" << tree_size << endl;
