计算信息增益

# -*- coding: UTF-8 -*-
from math import log
from collections import Counter
import csv
import numpy as np
 

def createDataSet():
    """Return the loan-application sample table as a NumPy array.

    Row 0 holds the column names; the 15 rows after it are samples made of
    four integer-coded features followed by a 'yes'/'no' label. Because the
    rows mix integers and strings, NumPy stores every cell as a string.
    """
    header = ['年龄', '有工作', '有自己的房子', '信贷情况', 'vqa']
    samples = [
        [0, 0, 0, 0, 'no'],
        [0, 0, 0, 1, 'no'],
        [0, 1, 0, 1, 'yes'],
        [0, 1, 1, 0, 'yes'],
        [0, 0, 0, 0, 'no'],
        [1, 0, 0, 0, 'no'],
        [1, 0, 0, 1, 'no'],
        [1, 1, 1, 1, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [2, 0, 1, 2, 'yes'],
        [2, 0, 1, 1, 'yes'],
        [2, 1, 0, 1, 'yes'],
        [2, 1, 0, 2, 'yes'],
        [2, 0, 0, 0, 'no'],
    ]
    return np.array([header] + samples)
 
def calcShannonEnt(dataSet, axis=-1):
    """Return the Shannon entropy, in bits, of one column of ``dataSet``.

    Args:
        dataSet: 2-D table of samples. Either a NumPy array or a nested
            sequence (e.g. a list of rows) that ``np.asarray`` can turn into
            a 2-D array. Note that NumPy coerces rows mixing ints and strings
            to strings.
        axis: Index of the column whose value distribution is measured.
            Defaults to the last column.

    Returns:
        ``-sum(p * log2(p))`` over the distinct values found in the column,
        where ``p`` is each value's relative frequency. A 2-D array with
        zero rows yields ``0.0``.
    """
    # Accept plain nested lists as well as arrays; column slicing below
    # requires an ndarray.
    data = np.asarray(dataSet)
    numEntries = len(data)
    valueCounts = Counter(data[:, axis])
    shannonEnt = 0.0
    for count in valueCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

def subDataSet(dataSet, axis, value):
    """Return the rows of ``dataSet`` whose column ``axis`` equals ``value``.

    Args:
        dataSet: 2-D table of samples. Either a NumPy array or a nested
            sequence that ``np.asarray`` can turn into a 2-D array.
        axis: Index of the column to test.
        value: Value to match. It is compared with ``==`` against the stored
            cells, so for a string-typed table pass the string form
            (e.g. ``'0'`` rather than ``0``).

    Returns:
        A 2-D array holding the matching rows in their original order; it
        has zero rows when nothing matches.
    """
    data = np.asarray(dataSet)
    # Row positions where the chosen column matches; indexing with them
    # keeps the result 2-D even when there are no matches.
    matchingRows = np.flatnonzero(data[:, axis] == value)
    return data[matchingRows]
 
def entropyGain(dataSet, axis=1, baseAxis=-1):
    """Return the information gain of column ``axis`` about column ``baseAxis``.

    The gain is the entropy of the ``baseAxis`` column minus the weighted
    average of that column's entropy within each group of rows sharing the
    same value in column ``axis``; each group is weighted by its share of
    the rows.

    Args:
        dataSet: 2-D table of samples. Either a NumPy array or a nested
            sequence that ``np.asarray`` can turn into a 2-D array.
        axis: Index of the feature column used to split the rows.
        baseAxis: Index of the column whose entropy is measured. Defaults to
            the last column.

    Returns:
        The information gain as a float (in bits).
    """
    # Convert once so column slicing works for nested lists too; the helpers
    # below then always receive an ndarray.
    data = np.asarray(dataSet)
    numEntries = len(data)
    baseEntropy = calcShannonEnt(data, baseAxis)
    featureCounts = Counter(data[:, axis])
    conditionalEntropy = 0.0
    for featureValue, count in featureCounts.items():
        weight = float(count) / numEntries
        subset = subDataSet(dataSet=data, axis=axis, value=featureValue)
        conditionalEntropy += weight * calcShannonEnt(dataSet=subset, axis=baseAxis)
    return baseEntropy - conditionalEntropy
 
 
# Build the table and drop row 0, which holds column names rather than a sample.
dataset = createDataSet()[1:, :]

# Entropy of the label column (calcShannonEnt defaults to the last column).
Entropy = calcShannonEnt(dataset)
print('Entropy is {Entropy:0.6f}'.format(Entropy=Entropy))

# Information gain of each of the four feature columns (0-3) with respect
# to the label stored in column 4.
for i in (0, 1, 2, 3):
    EntropyGain = entropyGain(dataset, axis=i, baseAxis=4)
    print('EntropyGain is {EntropyGain}'.format(EntropyGain=EntropyGain))

Entropy is 0.970951
EntropyGain is 0.0830074998558
EntropyGain is 0.323650198152
EntropyGain is 0.419973094022
EntropyGain is 0.362989562537

你可能感兴趣的:(机器学习)