CART分类回归树主要是通过灵活的分类方式,而不是死板的维度,对样本集进行划分,使用基尼系数计算分类后的信息增益,然后找到最大的增益方向并进行分类。代码如下:
#ifndef _CART_T_HPP_
#define _CART_T_HPP_
#include "decision_tree.hpp"
#include
#include
double gini(const std::vector& samples) {
double gini = 1.0;
double total = 0.0;
for (double sample : samples) {
total += sample;
}
for (double sample : samples) {
double p = sample / total;
gini -= p * p;
}
return gini;
}
template
void _gen_cart_tree(dt_node* p_cur_node, const std::vector >& S, v_f_pc& vpc, cc_t& cc)
{
int i_cur_lbl = 0;
if (same_class(i_cur_lbl, S, cc))
{
p_cur_node->lbl = i_cur_lbl;
p_cur_node->is_leave = true;
return;
}
using mt = mat;
int i_max_pc_idx = 0, idx = 0;
double max_gini_gain = -1e10;
std::map > mp_max_sub; // 最大基尼系数的分类方式
std::map > mp_sub;
for(auto pc: vpc)
{
mp_sub.clear();
for (auto s:S)
{
int i_class = pc(s);
if (mp_sub.count(i_class) == 0)
{
mp_sub.insert(std::make_pair(i_class, std::vector()));
}
mp_sub[i_class].push_back(s);
}
double d_all_cnt = S.size();
double d_gini_e = 0.; // 基尼系数期望
std::map mp_count;
for (auto itr = mp_sub.begin(); itr != mp_sub.end(); ++itr)
{
std::map mp_sub_count;
std::vector& vec_cur_sub = itr->second;
for (auto cur_s : vec_cur_sub)
{
int i_cur_s_class = cc(cur_s); // 判断样本的目标类别
mp_sub_count[i_cur_s_class] = (mp_sub_count.count(i_cur_s_class) == 0 ? 1 : mp_sub_count[i_cur_s_class] + 1);
mp_count[i_cur_s_class] = (mp_count.count(i_cur_s_class) == 0 ? 1 : mp_count[i_cur_s_class] + 1);
}
std::vector vec_sub_cnt(mp_sub_count.size());
std::transform(mp_sub_count.begin(), mp_sub_count.end(), vec_sub_cnt.begin(), [](const std::pair& p) { return p.second; });
double d_sub_gini = gini(vec_sub_cnt);
d_gini_e += (vec_cur_sub.size() / d_all_cnt * d_sub_gini);
}
std::vector vec_cnt(mp_count.size());
std::transform(mp_count.begin(), mp_count.end(), vec_cnt.begin(), [](const std::pair& p) { return p.second; });
double d_gini_a = gini(vec_cnt);
double d_gini_gain = d_gini_a - d_gini_e;
if (max_gini_gain < d_gini_gain)
{
mp_max_sub = mp_sub;
max_gini_gain = d_gini_gain;
i_max_pc_idx = idx;
}
idx++;
}
p_cur_node->idx = i_max_pc_idx;
if (mp_max_sub.size() == 1)
{
p_cur_node->is_leave = true;
// 找到最大的概率,判断其类型
std::map mp_sub_count;
std::vector& vec_cur_sub = mp_max_sub.begin()->second;
for (auto cur_s : vec_cur_sub)
{
int i_cur_s_class = cc(cur_s); // 判断样本的目标类别
mp_sub_count[i_cur_s_class] = (mp_sub_count.count(i_cur_s_class) == 0 ? 1 : mp_sub_count[i_cur_s_class] + 1);
}
int i_max_lbl = -1;
double d_max_num = -1e10;
for (auto p : mp_sub_count)
{
if (p.second > d_max_num)
{
d_max_num = p.second;
i_max_lbl = p.first;
}
}
p_cur_node->lbl = i_max_lbl;
p_cur_node->rate = (d_max_num) / vec_cur_sub.size();
return;
}
for (auto itr = mp_max_sub.begin(); itr != mp_max_sub.end(); ++itr) // 循环判断子集合的决策树
{
struct dt_node* p_sub_node = new struct dt_node(); // 创建一个新的节点
_gen_cart_tree(p_sub_node, itr->second, vpc, cc); // 生成子数据集的决策树
p_cur_node->mp_sub.insert(std::make_pair(itr->first, p_sub_node)); // 将子决策树加到当前决策树的下面
}
}
template
dt_node* gen_cart_tree(const std::vector >& vdata, v_f_pc& vpc, cc_t& cc)
{
struct dt_node* p_tree = new struct dt_node();
_gen_cart_tree(p_tree, vdata, vpc, cc);
return p_tree;
}
template
std::tuple judge_cart(struct dt_node* p_cur_node, const mat& data, v_f_pc& vpc, const int& def_value)
{
if (p_cur_node->is_leave)
{
return std::tie(p_cur_node->lbl, p_cur_node->rate);
}
int i_next_idx = vpc[p_cur_node->idx](data);
if (p_cur_node->mp_sub.count(i_next_idx) == 0) // 之前训练时候没有遇到过的分类
{
return std::tuple(def_value, 1.);
}
return judge_cart(p_cur_node->mp_sub[i_next_idx], data, vpc, def_value);
}
#endif
试验代码如下:
#include "cart_t.hpp"
/* CART决策树 */
#include
template
std::function&)> gen_pc(const int& idx)
{
return [idx](const mat& mt)->int
{
return mt[idx];
};
}
int cc(const mat<4, 1, double>& mt)
{
return mt[3];
}
int main(int argc, char** argv)
{
std::vector > vec_dat;
vec_dat.push_back({ -1, 1, -1, -1 });
vec_dat.push_back({ 1, -1, 1, 1 });
vec_dat.push_back({ 1, -1, -1, 1 });
vec_dat.push_back({ -1, -1, -1, -1 });
vec_dat.push_back({ -1, -1, 1, 1 });
vec_dat.push_back({ -1, 1, 1, -1 });
vec_dat.push_back({ 1, 1, 1, -1 });
vec_dat.push_back({ 1, -1, -1, 1 });
vec_dat.push_back({ -1, 1, -1, -1 });
vec_dat.push_back({ 1, -1, 1, 1 });
std::vector&)> > vpc = { gen_pc<4>(0), gen_pc<4>(1), gen_pc<4>(2), gen_pc<4>(3) };
auto p_tree = gen_cart_tree(vec_dat, vpc, cc);
for (auto itr = vec_dat.begin(); itr != vec_dat.end(); ++itr)
{
int i_id3_class = judge_cart(p_tree, *itr, vpc, -2);
printf("CART:%d\tLABEL:%d\r\n", i_id3_class, cc(*itr));
}
return 0;
}
得到结果如下:
结果完全正确。