// 决策树 (Decision tree)
// 信息增益 (Information gain)
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <exception>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>
using namespace std;
// Attribute names as column indices into a sample's feature vector.
// Column 0 of every feature vector built in main() holds a running sample
// number, and main() evaluates columns 1..6, so the attributes start at 1.
// NOTE(review): the order presumably matches the dataset's column order —
// confirm against the data source.
enum { color = 1, root, sound, texture, umbilical, touch };
/// Base class for a labelled sample: a vector of integer feature values plus a
/// boolean label.
///
/// `feature` and `label` are public read-only views of the protected
/// `_feature` / `_label` members. They are raw pointers into the object
/// itself, so the compiler-generated copy/move operations would copy the
/// *addresses* and leave every copy pointing at the source object's storage
/// (dangling once the source is destroyed). The special members below copy or
/// move only the data and keep the pointers bound to this object's own members.
class Object
{
protected:
    std::vector<int> _feature;
    bool _label = false;
public:
    /// Number of features; set by derived classes, 0 until then.
    int num_feature = 0;
    /// Points at this object's own _feature (after construction, copy or move).
    const std::vector<int>* feature = &_feature;
    /// Points at this object's own _label (after construction, copy or move).
    const bool* label = &_label;

    Object() = default;

    // The pointer members are deliberately absent from the initializer lists:
    // their default member initializers bind them to *this* object.
    Object(const Object& other)
        : _feature(other._feature), _label(other._label), num_feature(other.num_feature)
    {
    }

    Object(Object&& other) noexcept
        : _feature(std::move(other._feature)), _label(other._label), num_feature(other.num_feature)
    {
    }

    // Assignment transfers data only; feature/label already point at our members.
    Object& operator=(const Object& other)
    {
        if (this != &other) {
            _feature = other._feature;
            _label = other._label;
            num_feature = other.num_feature;
        }
        return *this;
    }

    Object& operator=(Object&& other) noexcept
    {
        if (this != &other) {
            _feature = std::move(other._feature);
            _label = other._label;
            num_feature = other.num_feature;
        }
        return *this;
    }

    ~Object() = default;
};
class Cucumber : public Object
{
public:
Cucumber(vector<int> feature, bool label)
{
this->_feature = feature;
this->_label = label;
num_feature = feature.size() - 1;
this->feature = &_feature;
this->label = &_label;
}
};
class Entropy
{
public:
static double calculate(vector<Cucumber> v)
{
try {
if (!v.size()) {
throw "no object";
}
int positive = 0, negative = 0;
for (auto i : v) {
if ((*i.label)) {
positive++;
}
else {
negative++;
}
}
int total = v.size();
double w1 = positive * 1.0 / total, w2 = negative * 1.0 / total;
if (0 == w1) return -w2 * log2(w2);
if (0 == w2) return -w1 * log2(w1);
return -w1 * log2(w1) - w2 * log2(w2);
}
catch(exception e) {
cout << e.what();
}
}
};
class Classify
{
public:
static vector<vector<Cucumber>> do_classify(vector<Cucumber> v, int rank_feature)
{
try
{
if (!v.size()) {
throw "no object";
}
if (v[0].num_feature < rank_feature) {
throw "rank is out of range";
}
vector<vector<Cucumber>> res;
for (auto i : v) {
while (res.size() < (*i.feature)[rank_feature]+1){
res.push_back(*new vector<Cucumber > ());
}
res[(*i.feature)[rank_feature]].push_back(i);
}
return res;
}
catch (const std::exception& e)
{
cout << e.what();
}
}
};
class InformationGain
{
public:
static double calculate(vector<vector<Cucumber>> v, int total, double entropy_s)
{
double res = 0;
for (auto i : v) {
res += double(i.size()) * Entropy::calculate(i) / total;
}
return entropy_s - res;
}
};
int main()
{
vector<Cucumber> v;
v.push_back(*new Cucumber({ 1, 0, 0, 0, 0, 0, 0 }, true));
v.push_back(*new Cucumber({ 2, 1, 0, 1, 0, 0, 0 }, true));
v.push_back(*new Cucumber({ 3, 1, 0, 0, 0, 0, 0 }, true));
v.push_back(*new Cucumber({ 4, 0, 0, 1, 0, 0, 0 }, true));
v.push_back(*new Cucumber({ 5, 2, 0, 0, 0, 0, 0 }, true));
v.push_back(*new Cucumber({ 6, 0, 1, 0, 0, 1, 1 }, true));
v.push_back(*new Cucumber({ 7, 1, 1, 0, 1, 1, 1 }, true));
v.push_back(*new Cucumber({ 8, 1, 1, 0, 0, 1, 0 }, true));
v.push_back(*new Cucumber({ 9, 1, 1, 1, 1, 1, 0 }, false));
v.push_back(*new Cucumber({ 10, 0, 2, 2, 0, 2, 1 }, false));
v.push_back(*new Cucumber({ 11, 2, 2, 2, 2, 2, 0 }, false));
v.push_back(*new Cucumber({ 12, 2, 0, 0, 2, 2, 1 }, false));
v.push_back(*new Cucumber({ 13, 0, 1, 0, 1, 0, 0 }, false));
v.push_back(*new Cucumber({ 14, 2, 1, 1, 1, 0, 0 }, false));
v.push_back(*new Cucumber({ 15, 1, 1, 0, 0, 1, 1 }, false));
v.push_back(*new Cucumber({ 16, 2, 0, 0, 2, 2, 0 }, false));
v.push_back(*new Cucumber({ 17, 0, 0, 1, 1, 1, 0 }, false));
double entropy_s = Entropy::calculate(v);
for (int i = 1; i < 7; i++) {
vector<vector<Cucumber>> tmp = Classify::do_classify(v, i);
double ig = InformationGain::calculate(tmp, 17, entropy_s);
printf("%d : %.3f\n", i, ig);
}
cin.get();
}