算法原理
决策树(Decision Tree)是一种简单但广泛使用的分类器。通过训练数据构建决策树,可以高效的对未知的数据进行分类。
决策数有两大优点:
1)决策树模型可读性好,具有描述性,有助于人工分析;
2)效率高,决策树只需要一次构建,反复使用,每一次预测的最大计算次数不超过决策树的深度
一棵决策树的生成过程主要分为以下3个部分:
特征选择:特征选择是指从训练数据中众多的特征中选择一个特征作为当前节点的分裂标准,如何选择特征有着很多不同量化评估标准标准,从而衍生出不同的决策树算法。
决策树生成: 根据选择的特征评估标准,从上至下递归地生成子节点,直到数据集不可分则停止决策树停止生长。 树结构来说,递归结构是最容易理解的方式。
构建的基本步骤:
1. 开始,所有记录看作一个节点
2. 遍历每个变量的每一种分割方式,找到最好的分割点
3. 分割成两个节点N1和N2
4. 对N1和N2分别继续执行2-3步,直到每个节点足够“纯”为止
决策树的变量可以有两种:
1) 数字型(Numeric):变量类型是整数或浮点数;
2) 名称型(Nominal):类似编程语言中的枚举类型;
如何评估分割点的好坏?如果一个分割点可以将当前的所有节点分为两类,使得每一类都很“纯”,也就是同一类的记录较多,那么就是一个好分割点。构建决策树采用贪心算法,只考虑当前纯度差最大的情况作为分割点。
伪代码实现:# ==============================================
# 输入:
# 数据集
# 输出:
# 构造好的决策树(也即训练集)
# ==============================================
def 创建决策树:
'创建决策树'
if (数据集中所有样本分类一致):
创建携带类标签的叶子节点
else:
寻找划分数据集的最好特征
根据最好特征划分数据集
for 每个划分的数据集:
创建决策子树(递归方式)
从上述步骤可以看出,决策生成过程中有两个重要的问题:
(1)数据如何分割
(2)如何选择分裂的属性
(3)什么时候停止分裂
量化纯度
决策树是根据“纯度”来构建的,这里介绍三种纯度计算方法。如果记录被分为n类,每一类的比例P(i)=第i类的数目/总数目。
Gini不纯度 | 熵(Entropy) | 错误率 |
![]() |
![]() |
![]() |
上面的三个公式均是值越大,表示越 “不纯”,越小表示越“纯”。三种公式只需要取一种即可,实践证明三种公式的选择对最终分类准确率的影响并不大,一般使用熵公式。
信息增益(Information Gain)(纯度差):当前节点的不纯度减去子节点不纯度的加权平均数,权重由子节点记录数与当前节点记录数的比例决定
其中,I:不纯度,K:分割的节点数(K = 2)。vj :子节点中的记录数目。。
ID3算法用的是信息增益,C4.5算法用信息增益率;CART算法使用基尼系数
ID3的缺点,倾向于选择水平数量较多的变量,可能导致训练得到一个庞大且深度浅的树;另外输入变量必须是分类变量(连续变量必须离散化);最后无法处理空值。
C4.5选择了信息增益率替代信息增益。
CART以基尼系数替代熵;最小化不纯度而不是最大化信息增益。
例:以熵作为节点复杂度的统计量,分别求出下面例子的信息增益,图3.1表示节点选择属性1进行分裂的结果,图3.2表示节点选择属性2进行分裂的结果,通过计算两个属性分裂后的信息增益,选择最优的分裂属性。
属性1: | 属性2: |
![]() |
![]() |
由于 ,所以属性1与属性2相比是更优的分裂属性,故选择属性1作为分裂的属性。
(2)信息增益率
使用信息增益作为选择分裂的条件有一个不可避免的缺点:倾向选择分支比较多的属性进行分裂。为了解决这个问题,引入了信息增益率这个概念。信息增益率是在信息增益的基础上除以分裂节点数据量的信息增益(听起来很拗口),其计算公式如下:
其中Info_Gain表示信息增益,InstrinsicInfo表示分裂子节点数据量的信息增益,其计算公式为:
其中m表示子节点的数量,表示第i个子节点的数据量,N表示父节点数据量, 其实InstrinsicInfo是分裂节点的熵,如果节点的数据链越接近,InstrinsicInfo越大,如果子节点越大,InstrinsicInfo越大,而Info_Ratio就会越小,能够降低节点分裂时选择子节点多的分裂属性的倾向性。信息增益率越高,说明分裂的效果越好。
还是信息增益中提及的例子为例:
属性1的信息增益率 | 属性2的信息增益率 |
![]() |
![]() |
由于Info_Ratio2>Info_Ratio1 ,故选择属性2作为分裂的属性。
停止条件
决策树的构建过程是一个递归的过程,所以需要确定停止条件。
(1)、一种最直观的方式是当每个子节点只有一种类型的记录时停止,但是这样往往会使得树的节点过多,导致过拟合(Overfitting);
(2)、另一种可行的方法是当前节点中的记录数低于一个最小的阀值,那么就停止分割,将max(P(i))对应的分类作为当前叶节点的分类。
过渡拟合
采用上面算法生成的决策树在事件中往往会导致过滤拟合。原因有以下几点:
优化方案1:修剪枝叶
决策树过渡拟合往往是因为节点过多,所以需要裁剪(Prune Tree)枝叶。裁剪枝叶的策略对决策树正确率的影响很大。主要有两种裁剪策略。
前置裁剪 在构建决策树的过程时,提前停止。那么,会将切分节点的条件设置的很苛刻,导致决策树很短小。结果就是决策树无法达到最优。
后置裁剪 决策树构建好后,然后才开始裁剪。采用两种方法:1)用单一叶节点代替整个子树,叶节点的分类采用子树中最主要的分类;2)将一个字数完全替代另外一颗子树。后置裁剪有个问题就是计算效率,有些节点计算后就被裁剪了,导致有点浪费。
优化方案2:K-Fold Cross Validation
首先计算出整体的决策树T,叶节点个数记作N,设i属于[1,N]。对每个i,使用K-Fold Validataion方法计算决策树,并裁剪到i个节点,计算错误率,最后求出平均错误率。这样可以用具有最小错误率对应的i作为最终决策树的大小,对原始决策树进行裁剪,得到最优决策树。
优化方案3:Random Forest
Random Forest是用训练数据随机的计算出许多决策树,形成了一个森林。然后用这个森林对未知数据进行预测,选取投票最多的分类。实践证明,此算法的错误率得到了进一步的降低。一颗树预测正确的概率可能不高,但是集体预测正确的概率却很高。
准确率估计
决策树T构建好后,需要估计预测准确率。直观说明,比如N条测试数据,X预测正确的记录数,那么可以估计acc = X/N为T的准确率。但是,这样不是很科学。因为我们是通过样本估计的准确率,很有可能存在偏差。所以,比较科学的方法是估计一个准确率的区间,这里就要用到统计学中的置信区间(Confidence Interval)。
设T的准确率p是一个客观存在的值,X的概率分布为X ~ B(N,p),即X遵循概率为p,次数为N的二项分布(Binomial Distribution),期望E(X) = N*p,方差Var(X) = N*p*(1-p)。由于当N很大时,二项分布可以近似有正太分布(Normal Distribution)计算,一般N会很大,所以X ~ N(np,n*p*(1-p))。可以算出,acc = X/N的期望E(acc) = E(X/N) = E(X)/N = p,方差Var(acc) = Var(X/N) = Var(X) / N2 = p*(1-p) / N,所以acc ~ N(p,p*(1-p)/N)。这样,就可以通过正太分布的置信区间的计算方式计算执行区间了。
正太分布的置信区间求解如下:
1) 将acc标准化,即
2) 选择置信水平α= 95%,或其他值,这取决于你需要对这个区间有多自信。一般来说,α越大,区间越大。
3) 求出 α/2和1-α/2对应的标准正太分布的统计量 和
(均为常量)。然后解下面关于p的不等式。acc可以有样本估计得出。即可以得到关于p的执行区间。
代码练习1:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#计算给定数据集的熵
#导入log运算符
import operator
import matplotlib.pyplot as plt
from math import log
import sys
reload(sys)
sys.setdefaultencoding('utf8')
def calEntropy(dataSet):
#获取数据集的行数
numEntries=len(dataSet)
#设置字典的数据结构
labelCounts={}
#提取数据集的每一行的特征向量
for featureVecor in dataSet:
#获取特征向量的最后一列的标签
currentLabel=featureVecor[-1]
#检测字典的关键字key中是否存在该标签
#如果不存在keys()关键字
if currentLabel not in labelCounts.keys():
#将当前标签/0键值对存入字典中
labelCounts[currentLabel]=0
#否则将当前标签对应的键值加1
labelCounts[currentLabel]+=1
#初始化熵为0
Entropy=0.0
#对于数据集中所有的分类类别
for key in labelCounts:
#计算各个类别出现的频率
prob=float(labelCounts[key])/numEntries
#计算各个类别信息期望值
Entropy-=prob*log(prob,2)
#返回信息熵
return Entropy
#创建一个简单的数据集
#数据集中包含两个特征'height','sex';
#数据的类标签有两个'yes','no'
def creatDataSet():
dataSet=[[1,1,'yes'],
[1,1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
labels=['height','sex']
#返回数据集和类标签
return dataSet,labels
#划分数据集:按照最优特征划分数据集
#@dataSet:待划分的数据集
#@axis:划分数据集的特征
#@value:特征的取值
def splitDataSet(dataSet,axis,value):
'''需要说明的是,python语言传递参数列表时,传递的是列表的引用
如果在函数内部对列表对象进行修改,将会导致列表发生变化,为了
不修改原始数据集,创建一个新的列表对象进行操作'''
retDataSet=[]
#提取数据集的每一行的特征向量
for featVec in dataSet:
#针对axis特征不同的取值,将数据集划分为不同的分支
#如果该特征的取值为value
if featVec[axis]==value:
#将特征向量的0~axis-1列存入列表reducedFeatVec
reducedFeatVec=featVec[:axis]
#将特征向量的axis+1~最后一列存入列表reducedFeatVec
#extend()是将另外一个列表中的元素(以列表中元素为对象)一一添加到当前列表中,构成一个列表
#比如a=[1,2,3],b=[4,5,6],则a.extend(b)=[1,2,3,4,5,6]
reducedFeatVec.extend(featVec[axis+1:])
#简言之,就是将原始数据集去掉当前划分数据的特征列
#append()是将另外一个列表(以列表为对象)添加到当前列表中
##比如a=[1,2,3],b=[4,5,6],则a.extend(b)=[1,2,3,[4,5,6]]
retDataSet.append(reducedFeatVec)
return retDataSet
#如何选择最好的划分数据集的特征
#使用某一特征划分数据集,信息增益最大,则选择该特征作为最优特征
def chooseBestFeatureToSplit(dataSet):
#获取数据集特征的数目(不包含最后一列的类标签)
numFeatures=len(dataSet[0])-1
#计算未进行划分的信息熵
baseEntropy=calEntropy(dataSet)
#最优信息增益 最优特征
bestInfoGain=0.0;bestFeature=-1
#利用每一个特征分别对数据集进行划分,计算信息增益
for i in range(numFeatures):
#得到特征i的特征值列表
featList=[example[i] for example in dataSet]
#利用set集合的性质--元素的唯一性,得到特征i的取值
uniqueVals=set(featList)
#信息增益0.0
newEntropy=0.0
#对特征的每一个取值,分别构建相应的分支
for value in uniqueVals:
#根据特征i的取值将数据集进行划分为不同的子集
#利用splitDataSet()获取特征取值Value分支包含的数据集
subDataSet=splitDataSet(dataSet,i,value)
#计算特征取值value对应子集占数据集的比例
prob=len(subDataSet)/float(len(dataSet))
#计算占比*当前子集的信息熵,并进行累加得到总的信息熵
newEntropy+=prob*calEntropy(subDataSet)
#计算按此特征划分数据集的信息增益
#公式特征A,数据集D
#则H(D,A)=H(D)-H(D/A)
infoGain=baseEntropy-newEntropy
#比较此增益与当前保存的最大的信息增益
if (infoGain>bestInfoGain):
#保存信息增益的最大值
bestInfoGain=infoGain
#相应地保存得到此最大增益的特征i
bestFeature=i
#返回最优特征
return bestFeature
#当遍历完所有的特征属性后,类标签仍然不唯一(分支下仍有不同分类的实例)
#采用多数表决的方法完成分类
def majorityCnt(classList):
#创建一个类标签的字典
classCount={}
#遍历类标签列表中每一个元素
for vote in classList:
#如果元素不在字典中
if vote not in classCount.keys():
#在字典中添加新的键值对
classCount[vote]=0
#否则,当前键对于的值加1
classCount[vote]+=1
#对字典中的键对应的值所在的列,按照又大到小进行排序
#@classCount.items 列表对象
#@key=operator.itemgetter(1) 获取列表对象的第一个域的值
#@reverse=true 降序排序,默认是升序排序
sortedClassCount=sorted(classCount.items,\
key=operator.itemgetter(1),reverse=True)
#返回出现次数最多的类标签
return sortedClassCount[0][0]
# 创建树
def createTree(dataSet, labels):
# 获取数据集中的最后一列的类标签,存入classList列表
classList = [example[-1] for example in dataSet]
# 通过count()函数获取类标签列表中第一个类标签的数目
# 判断数目是否等于列表长度,相同表面所有类标签相同,属于同一类
if classList.count(classList[0]) == len(classList):
return classList[0]
# 遍历完所有的特征属性,此时数据集的列为1,即只有类标签列
if len(dataSet[0]) == 1:
# 多数表决原则,确定类标签
return majorityCnt(classList)
# 确定出当前最优的分类特征
bestFeat = chooseBestFeatureToSplit(dataSet)
# 在特征标签列表中获取该特征对应的值
bestFeatLabel = labels[bestFeat]
# 采用字典嵌套字典的方式,存储分类树信息
myTree = {bestFeatLabel: {}}
# 复制当前特征标签列表,防止改变原始列表的内容
subLabels = labels[:]
# 删除属性列表中当前分类数据集特征
del (subLabels[bestFeat])
# 获取数据集中最优特征所在列
featValues = [example[bestFeat] for example in dataSet]
# 采用set集合性质,获取特征的所有的唯一取值
uniqueVals = set(featValues)
# 遍历每一个特征取值
for value in uniqueVals:
'''
采用递归的方法利用该特征对数据集进行分类
@bestFeatLabel 分类特征的特征标签值
@dataSet 要分类的数据集
@bestFeat 分类特征的标称值
@value 标称型特征的取值
@subLabels 去除分类特征后的子特征标签列表
'''
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
#------------------------测试算法------------------------------
#完成决策树的构造后,采用决策树实现具体应用
#@intputTree 构建好的决策树
#@featLabels 特征标签列表
#@testVec 测试实例
def classify(inputTree,featLabels,testVec):
#找到树的第一个分类特征,或者说根节点'no surfacing'
#注意python2.x和3.x区别,2.x可写成firstStr=inputTree.keys()[0]
#而不支持3.x
firstStr=list(inputTree.keys())[0]
#从树中得到该分类特征的分支,有0和1
secondDict=inputTree[firstStr]
#根据分类特征的索引找到对应的标称型数据值
#'no surfacing'对应的索引为0
featIndex=featLabels.index(firstStr)
#遍历分类特征所有的取值
for key in secondDict.keys():
#测试实例的第0个特征取值等于第key个子节点
if testVec[featIndex]==key:
#type()函数判断该子节点是否为字典类型
if type(secondDict[key]).__name__=='dict':
#子节点为字典类型,则从该分支树开始继续遍历分类
classLabel=classify(secondDict[key],featLabels,testVec)
#如果是叶子节点,则返回节点取值
else: classLabel=secondDict[key]
return classLabel
'''def testDataSet():
dataSet = [[1, 1 ],[1, 1],[1, 0],[0, 1],[0, 1]]
# 返回数据集
return dataSet
'''
#决策树的存储:python的pickle模块序列化决策树对象,使决策树保存在磁盘中
#在需要时读取即可,数据集很大时,可以节省构造树的时间
#pickle模块存储决策树
def storeTree(inputTree,filename):
#导入pickle模块
import pickle
#创建一个可以'写'的文本文件
#这里,如果按树中写的'w',将会报错write() argument must be str,not bytes
#所以这里改为二进制写入'wb'
fw=open(filename,'wb')
#pickle的dump函数将决策树写入文件中
pickle.dump(inputTree,fw)
#写完成后关闭文件
fw.close()
#取决策树操作
def grabTree(filename):
import pickle
#对应于二进制方式写入数据,'rb'采用二进制形式读出数据
fr=open(filename,'rb')
return pickle.load(fr)
#-------------------------------------------绘制------------------------
# ===============================================
# 输入:
# myTree: 决策树
# 输出:
# numLeafs: 决策树的叶子数
# ===============================================
def getNumLeafs(myTree):
'计算决策树的叶子数'
# 叶子数
numLeafs = 0
# 节点信息
firstStr = myTree.keys()[0]
# 分支信息
secondDict = myTree[firstStr]
for key in secondDict.keys(): # 遍历所有分支
# 子树分支则递归计算
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
# 叶子分支则叶子数+1
else:
numLeafs += 1
return numLeafs
# ===============================================
# 输入:
# myTree: 决策树
# 输出:
# maxDepth: 决策树的深度
# ===============================================
def getTreeDepth(myTree):
'计算决策树的深度'
# 最大深度
maxDepth = 0
# 节点信息
firstStr = myTree.keys()[0]
# 分支信息
secondDict = myTree[firstStr]
for key in secondDict.keys(): # 遍历所有分支
# 子树分支则递归计算
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
# 叶子分支则叶子数+1
else:
thisDepth = 1
# 更新最大深度
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
# ==================================================
# 输入:
# nodeTxt: 终端节点显示内容
# centerPt: 终端节点坐标
# parentPt: 起始节点坐标
# nodeType: 终端节点样式
# 输出:
# 在图形界面中显示输入参数指定样式的线段(终端带节点)
# ==================================================
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
'画线(末端带一个点)'
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# =================================================================
# 输入:
# cntrPt: 终端节点坐标
# parentPt: 起始节点坐标
# txtString: 待显示文本内容
# 输出:
# 在图形界面指定位置(cntrPt和parentPt中间)显示文本内容(txtString)
# =================================================================
def plotMidText(cntrPt, parentPt, txtString):
'在指定位置添加文本'
# 中间位置坐标
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
# ===================================
# 输入:
# myTree: 决策树
# parentPt: 根节点坐标
# nodeTxt: 根节点坐标信息
# 输出:
# 在图形界面绘制决策树
# ===================================
def plotTree(myTree, parentPt, nodeTxt):
'绘制决策树'
# 当前树的叶子数
numLeafs = getNumLeafs(myTree)
# 当前树的节点信息
firstStr = myTree.keys()[0]
# 定位第一棵子树的位置(这是蛋疼的一部分)
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
# 绘制当前节点到子树节点(含子树节点)的信息
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
# 获取子树信息
secondDict = myTree[firstStr]
# 开始绘制子树,纵坐标-1。
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys(): # 遍历所有分支
# 子树分支则递归
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
# 叶子分支则直接绘制
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
# 子树绘制完毕,纵坐标+1。
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# ==============================
# 输入:
# myTree: 决策树
# 输出:
# 在图形界面显示决策树
# ==============================
def createPlot(inTree):
'显示决策树'
# 创建新的图像并清空 - 无横纵坐标
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
# 树的总宽度 高度
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
# 当前绘制节点的坐标
plotTree.xOff = -0.5 / plotTree.totalW;
plotTree.yOff = 1.0;
# 绘制决策树
plotTree(inTree, (0.5, 1.0), '')
plt.show()
if __name__ == '__main__':
myDat,labels=creatDataSet()
myTree = createTree(myDat, labels)
print(myTree)
storeTree(myTree, 'newTxt')
myStoreTree= grabTree('newTxt')
classLabel1 = classify(myTree,labels,[0,1])
print(classLabel1)
classLabel2 = classify(myStoreTree, labels, [0, 1])
print(classLabel2)
createPlot(myTree)
代码练习2:
#!/usr/local/bin/python
# -*- coding : utf-8 -*-
import sys
import os
import math
import argparse
did2label = {}
wid2word = {}
didwordlist = {}
widdoclist = {}
def load_train_data( file_path ):
fp = open(file_path)
did = 0
word_idx = {}
wid = 0
doc_list = set()
while True :
line = fp.readline()
if len(line) <= 0 :
break
arr = line.strip('\r\n').split('\t')
label = int(arr[0])
did2label[did] = label
didwordlist[did] = set()
for w in arr[1:]:
if len(w) <= 3 :
continue
if w not in word_idx:
word_idx[w] = wid
wid2word[wid] = w
widdoclist[wid] = set()
wid += 1
widdoclist[word_idx[w]].add(did)
didwordlist[did].add(word_idx[w])
doc_list.add(did)
did += 1
return doc_list
def entropy( num, den ):
if num == 0 :
return 0
p = float(num)/float(den)
return -p*math.log(p,2)
class DecisionTree :
def __init__(self) :
self.word = None
self.doc_count = 0
self.positive = 0
self.negative = 0
self.child = {}
def predict(self, word_list ):
if len(self.child) == 0 :
return float(self.positive)/(self.positive+self.negative)
if self.word in word_list :
return self.child["left"].predict(word_list)
else :
return self.child["right"].predict(word_list)
def visualize(self, d) :
"visualize the tree"
for i in range (0, d) :
print "-",
print "(%s,%d,%d)" % ( self.word,self.positive, self.negative)
if len(self.child) != 0 :
self.child["left"].visualize(d + 1)
self.child["right"].visualize(d + 1)
def build_dt(self, doc_list ) :
self.doc_count = len(doc_list)
for did in doc_list :
if did2label[did] > 0 :
self.positive += 1
else :
self.negative += 1
if self.doc_count <= 10 or self.positive * self.negative == 0 :
return True
wid = info_gain( doc_list )
if wid == -1 :
return True
self.word = wid2word[wid]
left_list = set()
right_list = set()
for did in doc_list :
if did in widdoclist[wid] :
left_list.add(did)
else :
right_list.add(did)
self.child["left"] = DecisionTree()
self.child["right"] = DecisionTree()
self.child["left"].build_dt( left_list )
self.child["right"].build_dt(right_list )
def info_gain(doc_list):
collect_word = set()
total_positive = 0
total_negative = 0
for did in doc_list :
for wid in didwordlist[did] :
collect_word.add(wid)
if did2label[did] > 0 :
total_positive += 1
else :
total_negative += 1
total = len(doc_list)
info = entropy( total_positive, total )
info += entropy( total_negative, total )
ig = []
for wid in collect_word :
positive = 0
negative = 0
for did in widdoclist[wid]:
if did not in doc_list :
continue
if did2label[did] > 0 :
positive += 1
else :
negative += 1
df = negative + positive
a = info
b = entropy( positive, df )
b += entropy( negative, df )
a -= b * df / total
b = entropy( total_positive - positive, total - df)
b += entropy( total_negative - negative, total - df )
a -= b * ( total - df ) / total
a = a * 100000.0
ig.append( (a, wid))
ig.sort()
ig.reverse()
for i,wid in ig :
left = 0
right = 0
for did in doc_list :
if did in widdoclist[wid] :
left += 1
else :
right += 1
if left >= 5 and right >= 5 :
return wid
return -1
if __name__ == "__main__" :
parser = argparse.ArgumentParser( description = "Decision Tree training and testing" )
parser.add_argument( "-i", "--train_data", help = "training data")
parser.add_argument( "-t", "--test_data", help = "testing data")
args = parser.parse_args()
train_file = args.train_data
test_file = args.test_data
if not train_file or not os.path.exists(train_file) :
parser.print_help()
sys.exit()
if not test_file or not os.path.exists(test_file) :
parser.print_help()
sys.exit()
doc_list = load_train_data( train_file )
dt = DecisionTree()
dt.build_dt(doc_list)
#dt.visualize(0)
fp = open(test_file)
true_positive = 0
false_positive = 0
positive = 0
true_negative = 0
false_negative = 0
negative = 0
total = 0
while True :
line = fp.readline()
if len( line ) <= 0 :
break
arr = line.strip('\r\n').split('\t')
label = int(arr[0])
word_list = set()
for w in arr[1:] :
if len(w) <= 3 :
continue
word_list.add( w )
p = dt.predict(word_list)
print label, p
if label == 1 :
positive += 1
else :
negative += 1
if p >= 0.5 :
if label == 1 :
true_positive += 1
else :
false_positive += 1
else :
negative += 1
if label == -1 :
true_negative += 1
else :
false_negative += 1
total += 1
print "Positive recall :%f" % (true_positive*100.0/(positive))
print "Positive precision :%f" % (true_positive*100.0/(true_positive+false_positive))
print "Accuary : %f" % ( (true_positive + true_negative)*100.0/total)
常用的决策树算法有ID3,C4.5,CART三种。3种算法的模型构建思想都十分类似,只是采用了不同的指标。决策树模型的构建过程大致如下:
输入:训练集D,特征集A,阈值eps 输出:决策树T
这里只简单介绍下CART与ID3和C4.5的区别。
相关参考资料:
算法:http://www.cnblogs.com/muchen/p/6141978.html
算法:http://www.cnblogs.com/zy230530/p/6813250.html
代码形成过程:http://blog.csdn.net/zzxvictory/article/details/73250685
计算过程:http://www.cnblogs.com/yonghao/p/5061873.html