第 8 章中介绍了线性回归的一些强大的方法,但这些方法创建的模型需要拟合所有的样本点(局部加权线性回归除外)。当数据拥有众多特征并且特征之间关系十分复杂时,构建全局模型的想法就显得太难了,也略显笨拙。而且,实际生活中很多问题都是非线性的,不可能使用全局线性模型来拟合任何数据。
一种可行的方法是将数据集切分成很多份易建模的数据,然后利用我们的线性回归技术来建模。如果首次切分后仍然难以拟合线性模型就继续切分。在这种切分方式下,树回归和回归法就相当有用。
本章介绍 CART(Classification And Regression Trees, 分类回归树) 的树构建算法。该算法既可以用于分类还可以用于回归。
第3章 中使用的树构建算法是 ID3 。ID3 的做法是每次选取当前最佳的特征来分割数据,并按照该特征的所有可能取值来切分。也就是说,如果一个特征有 4 种取值,那么数据将被切分成 4 份。一旦按照某特征切分后,该特征在之后的算法执行过程中将不会再起作用,所以有观点认为这种切分方式过于迅速。另外一种方法是二分切分法,即每次把数据集切分成两份。如果数据的某特征值等于切分所要求的值,那么这些数据就进入树的左子树,反之则进入树的右子树。
除了切分过于迅速外, ID3 算法还存在另一个问题,它不能直接处理连续型特征。只有事先将连续型特征转换成离散型,才能在 ID3 算法中使用。但这种转换过程会破坏连续型变量的内在性质。而使用二元切分法则易于对树构造过程进行调整以处理连续型特征。具体的处理方法是: 如果特征值大于给定值就走左子树,否则就走右子树。另外,二分切分法也节省了树的构建时间,但这点意义也不是特别大,因为这些树构建一般是离线完成,时间并非需要重点关注的因素。
CART 是十分著名且广泛记载的树构建算法,它使用二元切分来处理连续型变量。对 CART 稍作修改就可以处理回归问题。第 3 章中使用香农熵来度量集合的无组织程度。如果选用其他方法来代替香农熵,就可以使用树构建算法来完成回归。
回归树与分类树的思路类似,但是叶节点的数据类型不是离散型,而是连续型。
CART算法原理:
假设X与Y分别为输入和输出变量,并且Y是连续变量,给定训练数据集:
其中,D表示整个数据集合,n为特征数。
一个回归树对应着输入空间(即特征空间)的一个划分以及在划分的单元上的输出值。假设已将输入空间划分为M个单元R1,R2,…Rm,并且在每个单元Rm上有一个固定的输出值Cm,于是回归树模型可表示为:
这样就可以计算模型输出值与实际值的误差:
我们希望每个单元上的Cm,可以是的这个平方误差最小化。易知,当Cm为相应单元的所有实际值的均值时,可以到最优:
那么如何生成这些单元划分?
假设,我们选择变量 xj 为切分变量,它的取值 s 为切分点,那么就会得到两个区域:
当j和s固定时,我们要找到两个区域的代表值c1,c2使各自区间上的平方差最小:
前面已经知道c1,c2为区间上的平均:
那么对固定的 j 只需要找到最优的s,然后通过遍历所有的变量,我们可以找到最优的j,这样我们就可以得到最优对(j,s),并得到两个区间。
这样的回归树通常称为最小二乘回归树(least squares regression tree)。
CART算法步骤:
除此之外,我们再定义两个参数,tolS和tolN,分别用于控制误差变化限制和切分特征最少样本数。这两个参数的意义是什么呢?就是防止过拟合,提前设置终止条件,实际上是在进行一种所谓的预剪枝(prepruning)操作
构建决策树算法,常用到的是三个方法: ID3, C4.5, CART。三种方法区别是划分树的分支的方式:
算法 | 分支方法 |
---|---|
ID3 | 信息增益 |
C4.5 | 信息增益比 |
CART | gini系数 |
在树的构建过程中,需要解决多种类型数据的存储问题。与第3章类似,这里讲实用一部字典来存储树的数据结构,该字典将包含以下4个元素。
这与第3章的树结构有一点不同。第3章用一部字典来存储每个切分,但该字典可以包含两个或两个以上的值。而CART算法只做二元切分,所以这里可以固定树的数据结构。树包含左键和右键,可以存储另一棵子树或者单个值。字典还包含特征和特征值这两个键,它们给出切分算法所有的特征和特征值。接下来,新建regTree.py文件,并在其中编写代码。
先载入数据集ex00.txt,然后进行可视化显示,观察数据,新建regTrees.py文件编写代码如下:
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
""" 函数说明:加载数据 Parameters: fileName:文件名 Returns: dataMat:数据矩阵 """
def loadDataSet(fileName):
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = list(map(float, curLine)) #转化为float类型
dataMat.append(fltLine)
return dataMat
""" 函数说明:加载数据 Parameters: fileName:文件名 Returns: 无 """
def plotDataSet(fileName):
dataMat = loadDataSet(fileName) #加载数据集
n = len(dataMat) #样本个数
xcord = []
ycord = []
for i in range(n):
xcord.append(dataMat[i][0])
ycord.append(dataMat[i][1])
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(xcord, ycord, s=20, c='blue', alpha=0.5)
plt.title('DataSet')
plt.xlabel('X')
plt.show()
if __name__ == '__main__':
filename = 'ex00.txt'
plotDataSet(filename)
运行结果如下图所示:
可以看到,这是一个很简单的数据集,我们先利用这个数据集测试我们的CART算法。
创建方法很简单,我们根据切分的特征和特征值切分出两个数据集,然后将两个数据集分别用于左子树的构建和右子树的构建,直到无法找到切分的特征为止。因此,我们可以使用递归实现这个过程,在regTree.py文件中继续编写代码如下:
""" 函数说明:根据特征二元切分数据集合 Parameters: dataSet:数据集合 feature:待切分的特征 value:该特征的值 Returns: mat0: 切分的数据集合0 mat1: 切分的数据集合1 """
def binSplitDataSet(dataSet, feature, value):
mat0 = dataSet[np.nonzero(dataSet[:,feature] > value)[0],:]
mat1 = dataSet[np.nonzero(dataSet[:,feature] <= value)[0],:]
return mat0, mat1
""" 函数说明:生成叶结点 Parameters: dataSet: 数据集合 Returns: 目标变量的均值 """
def regLeaf(dataSet):
return np.mean(dataSet[:,-1])
""" 函数说明:误差估计函数 Parameters: dataSet: 数据集合 Returns: 目标变量的总方差 """
def regErr(dataSet):
return np.var(dataSet[:,-1]) * np.shape(dataSet)[0]
""" 函数说明:找到数据的最佳二元切分方式函数 Parameters: dataSet: 数据集合 leafType: 生成叶结点 regErr: 误差估计函数 ops: 用户定义的参数构成的元组 Returns: bestIndex: 最佳切分特征 bestValue: 最佳特征值 """
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
#tolS允许的误差下限值,tolN切分的最少样本数
tolS = ops[0]
tolN = ops[1]
#如果当前所有值相等,则退出。(根据set的特性)
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
return None, leafType(dataSet)
#统计数据集合的行m和列n
m, n = np.shape(dataSet)
S = errType(dataSet) #默认最后一个特征为最佳切分特征,计算其误差估计
bestS = np.inf #最佳误差
bestIndex = 0 #最佳特征切分的索引值
bestValue = 0 #最近特征值
for featIndex in range(n-1): #遍历所有特征列
for splitVal in set(dataSet[:, featIndex].T.A.tolist()[0]): #遍历所有特征值
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal) #根据特征和特征值切分数据集
if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): #如果数据少于tolN,则退出
continue
newS = errType(mat0) + errType(mat1) #计算误差估计
if newS < bestS: #如果误差估计更小,则更新特征索引值和特征值
bestIndex = featIndex
bestValue = splitVal
bestS = newS
if (S - bestS) < tolS: #如果误差减少不大则退出
return None, leafType(dataSet)
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue) #根据最佳的切分特征和特征值切分数据集合
if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): #如果切分出的数据集很小则退出
return None, leafType(dataSet)
return bestIndex, bestValue #返回最佳切分特征和特征值
""" 函数说明:树构建函数 Parameters: dataSet - 数据集合 leafType - 建立叶结点的函数 errType - 误差计算函数 ops - 包含树构建所有其他参数的元组 Returns: retTree - 构建的回归树 """
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
#选择最佳切分特征和特征值
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
if feat == None: #r如果没有特征,则返回特征值
return val
retTree = {} #初始化回归树
retTree['spInd'] = feat
retTree['spVal'] = val
#分成左数据集和右数据集
lSet, rSet = binSplitDataSet(dataSet, feat, val)
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree
if __name__ == '__main__':
# filename = 'ex00.txt'
# plotDataSet(filename)
myDat = loadDataSet('ex00.txt')
myMat = np.mat(myDat)
# feat,val = chooseBestSplit(myMat, regLeaf, regErr, (1,4))
# print(feat)
# print(val)
print(createTree(myMat))
运行结果如下:
这里要特别注意误差估计函数中是如何计算总方差的。总方差可以通过均方差乘以数据集中样本点的个数来得到。
从运行结果可知,这棵树只有两个叶结点。
我们换一个复杂一点的数据集ex0.txt,分段常数数据集。先看下数据:
第一列的数据都是1.0,为了可视化方便,我们将第1列作为x轴数据,第2列作为y轴数据。对数据进行可视化,修改plotDataSet函数中红线圈起来的代码如下:
if __name__ == '__main__':
filename = 'ex0.txt'
plotDataSet(filename)
运行结果如下图所示
可以看到,这个数据集是分段的。我们针对此数据集创建回归树。构建树的代码同上,运行结果如下图所示
可以看到,该树的结构中包含5个叶结点。
现在为止,已经完成回归树的构建,但是需要某种措施来检查构建过程是否得当。这个技术就是剪枝(tree pruning)技术。
一棵树如果结点过多,表明该模型可能对数据进行了“过拟合”。
通过降低树的复杂度来避免过拟合的过程称为剪枝(pruning)。上小节我们也已经提到,设置tolS和tolN就是一种预剪枝操作。另一种形式的剪枝需要使用测试集和训练集,称作后剪枝(postpruning)。接下来我们先分析后剪枝的有效性,但首先来看一下预剪枝的不足之处。
预剪枝有一定的局限性,比如我们现在使用一个新的数据集ex2.txt 。首先我们用图形来观察一下数据集,绘制出的数据集图形如下图所示
可以看到,对于这个数据集与我们使用的第一个数据集很相似,但是区别在于y的数量级差100倍,数据分布相似,因此构建出的树应该也是只有两个叶结点。但是我们使用默认tolS和tolN参数创建树,你会发现运行结果如下所示:
if __name__ == '__main__':
myDat = loadDataSet('ex2.txt')
myMat = np.mat(myDat)
print(createTree(myMat))
可以看到,构建出的树有很多叶结点。产生这个现象的原因在于,停止条件tolS对误差的数量级十分敏感。如果在选项中花费时间并对上述误差容忍度取平均值,或许也能得到仅有两个叶结点组成的树:
if __name__ == '__main__':
myDat = loadDataSet('ex2.txt')
myMat = np.mat(myDat)
print(createTree(myMat))
运行结果如下:
可以看到,将参数tolS修改为10000后,构建的树就是只有两个叶结点。然而,显然这个值,需要我们经过不断测试得来,显然通过不断修改停止条件来得到合理结果并不是很好的办法。事实上,我们常常甚至不确定到底需要寻找什么样的结果。因为对于一个很多维度的数据集,你也不知道构建的树需要多少个叶结点。
可见,预剪枝有很大的局限性。接下来,我们讨论后剪枝,即利用测试集来对树进行剪枝。由于不需要用户指定参数,后剪枝是一个更理想化的剪枝方法。
使用后剪枝方法需要将数据集分成测试集和训练集。首先指定参数,使得构建出的树足够大、足够复杂,便于剪枝。接下来从上而下找到叶结点,用测试集来判断这些叶结点合并是否能降低测试集误差。如果是的话就合并。
为了演示后剪枝,我们使用ex2.txt文件作为训练集,而使用的新数据集ex2test.txt文件作为测试集。
现在我们使用ex2.txt训练回归树,然后利用ex2test.txt对回归树进行剪枝。我们需要创建三个函数isTree()、getMean()、prune()。其中isTree()用于测试输入变量是否是一棵树,返回布尔类型的结果。换句话说,该函数用于判断当前处理的结点是否是叶结点。第二个函数getMean()是一个递归函数,它从上往下遍历树直到叶结点为止。如果找到两个叶结点则计算它们的平均值。该函数对树进行塌陷处理(即返回树平均值)。而第三个函数prune()则为后剪枝函数。编写代码如下:
""" 函数说明:后剪枝 Parameters: tree: 树 testData:测试集 Return: 树的平均值 """
def prune(tree, testData):
if np.shape(testData)[0] == 0: #如果测试集为空,则对树进行塌陷处理
return getMean(tree)
#如果有左子树或者右子树,则切分数据集
if (isTree(tree['right']) or isTree(tree['left'])):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
#处理左子树(剪枝)
if isTree(tree['left']):
tree['left'] = prune(tree['left'], lSet)
#处理右子树(剪枝)
if isTree(tree['right']):
tree['right'] = prune(tree['right'],rSet)
#如果当前结点的左右结点为叶结点
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
#计算没有合并的误差
errorNoMerge = sum(np.power(lSet[:, -1] - tree['left'], 2)) +\
sum(np.power(rSet[:,-1] - tree['right'], 2))
#计算合并的均值
treeMean = (tree['left']+tree['right'])/2.0
#计算合并的误差
errorMerge = sum(np.power(testData[:,-1] - treeMean, 2))
#如果合并的误差小于没有合并的误差,则合并
if errorMerge < errorNoMerge:
print("merging")
return treeMean
else:
return tree
else:
return tree
if __name__ == '__main__':
train_filename = 'ex2.txt'
train_Data = loadDataSet(train_filename)
train_Mat = np.mat(train_Data)
tree = createTree(train_Mat)
print(tree)
test_filename = 'ex2test.txt'
test_Data = loadDataSet(test_filename)
test_Mat = np.mat(test_Data)
print(prune(tree, test_Mat))
运行剪枝后的结果如下:
可以看到,树的大量结点已经被剪枝掉了,但没有像预期的那样剪枝成两部分,这说明后剪枝可能不如预剪枝有效。一般地,为了寻求最佳模型可以同时使用两种剪枝技术。
现在,可能你会问了,这叶结点只是简单的数值。这也没有拟合数据啊?回归树到底啥样啊?别急,下篇文章继续讲解。
用树建模,除了把叶节点简单地设定为常数值外,还可把叶节点设定为分段线性函数,这里的分段线性是指模型由多个线性片段组成。模型树的可解释性是它优于回归树的特点之一,另外,模型树还具有更高的预测准确度。
先看一下数据集exp2.txt,绘制图形如下:
if __name__ == '__main__':
plotDataSet('exp2.txt')
运行结果如下图
从上述的数据可以看出,使用两条直线拟合比较合适,可以设计两条分别从(0.0~0.3), (0.3~1.0)的直线,于是就可以得到两个线性模型。那么前面的代码要稍加修改就可以在叶节点生成线性模型而不是常数值。
对于给定的数据集,应该先用线性的模型来对它进行拟合,然后计算真实的目标值与模型预测值之间的差值。最后将这些差值的平方求和就得到了所需的误差。
在regTree.py文件中写入如下代码:
""" 函数说明:线性求解函数 Parameter: dataSet:数据集 Return: ws:线性权重系数 X:自变量X Y:目标变量Y """
def linearSolve(dataSet):
m, n = np.shape(dataSet) #数据集的大小,m行n列
X = np.mat(np.ones((m,n))) #初始化自变量X均为1
Y = np.mat(np.ones((m,1))) #初始化目标变量Y,均为1, 共m行
X[:,1:n] = dataSet[:,0:n-1] #将数据集的前n-1列赋值给X
Y = dataSet[:,-1] #将数据集的最后一列赋值给Y
xTx = X.T*X
if np.linalg.det(xTx) == 0.0:
raise NameError("This matrix is singular, cannot do inverse,\n\
try increasing th second value of ops")
ws = xTx.I * (X.T *Y) #计算线性权重系数ws
return ws, X, Y
""" 函数说明:生成叶节点的模型函数 Parameter: dataSet:数据集 Return: ws:线性权重系数 """
def modelLeaf(dataSet):
ws, X, Y = linearSolve(dataSet)
return ws
""" 函数说明:计算误差函数 Parameter: dataSet:数据集 Return: 平方误差 """
def modelErr(dataSet):
ws, X, Y = linearSolve(dataSet)
yHat = X * ws
return sum(np.power(Y - yHat, 2))
if __name__ == '__main__':
myMat2 = np.mat(loadDataSet('exp2.txt'))
tree = createTree(myMat2, leafType=modelLeaf, errType=modelErr, ops=(1,10))
print(tree)
运行结果如下图所示:
从结果可以看出,该代码以0.285477为界创建了两个模型,生成的两个线性模型分别是y=3.468+1.1852x 和 y=0.0016985+11.96477x,与用于生成该数据的真实模型非常接近。该数据实际是由模型y=3.5+1.0x 和 y=0+12x再加上高斯噪声生成的。
模型树、回归树到底哪一个模型更好?一个比较客观的方法是计算相关系数,也称为R^2^值。该相关系数可以通过调用Numpy库中的命令correcoef(yHat, y, rowvar=0) 来求解,其中的yHat是预测值,y是目标变量的实际值。 |
前面介绍了模型树、回归树和一般的回归方法,下面测试一下哪个模型最好。本节首先给出一些函数,它们可以在树构建好的情况下对给定的输入进行预测,之后利用这些函数来计算三种回归模型的测试误差。这些模型将在某个数据集上进行测试,该数据涉及人的智力水平和自行车的速度的关系。
这里的数据是非线性的,不能简单地使用第8章的全局线性模型建模。当然这里需要声明一下,此数据纯属虚构。
在regTree.py文件中加入如下代码:
""" 函数说明:回归树叶节点的预测 Parameter: model:生成的树 inDat:数据集 Return: 预测值 """
def regTreeEval(model, inDat):
return float(model)
""" 函数说明:模型树叶节点的预测 Parameter: model:生成的树 inDat:数据集 Return: 预测值 """
def modelTreeEval(model, inDat):
n = np.shape(inDat)[1]
X = np.mat(np.ones(1, n+1))
X[:, 1: n+1] = inDat
return float(X*model)
""" 函数说明:对给定树进行预测的函数 Parameter: tree:生成的树 inDat:数据集 modelEval:叶节点预测类型 Return: 预测值 """
def treeForeCast(tree, inData, modelEval=regTreeEval):
if not isTree(tree):
return modelEval(tree, inData)
if inData[tree['spInd']] > tree['spVal']:
if isTree(tree['left']):
return treeForeCast(tree['left'], inData, modelEval)
else:
return modelEval(tree['left'], inData)
else:
if isTree(tree['right']):
return treeForeCast(tree['right'], inData, modelEval)
else:
return modelEval(tree['right'], inData)
""" 函数说明:对给定树在整个测试集进行预测 Parameter: tree:生成的树 testData:测试数据集 modelEval:叶节点预测类型 Return: 预测值 """
def createForeCast(tree, testData, modelEval=regTreeEval):
m = len(testData)
yHat = np.mat(np.zeros((m,1)))
for i in range(m):
yHat[i,0] = treeForeCast(tree, np.mat(testData[i]), modelEval)
return yHat
我们首先看一下数据集bikeSpeedVsIq_train.txt,可视化结果如下图:
然后构建不同的模型,比较各个模型的相关系数。代码如下:
if __name__ == '__main__':
trainMat = np.mat(loadDataSet('bikeSpeedVsIq_train.txt'))
testMat = np.mat(loadDataSet('bikeSpeedVsIq_test.txt'))
myTree1 = createTree(trainMat, ops=(1,20))
yHat1 = createForeCast(myTree1, testMat[:,0])
print("创建回归树,他们的相关系数为:",np.corrcoef(yHat1, testMat[:,1], rowvar=0)[0,1])
myTree2 = createTree(trainMat, leafType=modelLeaf, errType=modelErr, ops=(1,20))
yHat2 = createForeCast(myTree2, testMat[:,0], modelEval=modelTreeEval)
print("创建模型树,他们的相关系数为:",np.corrcoef(yHat2, testMat[:,1], rowvar=0)[0,1])
ws, X, Y = linearSolve(trainMat)
yHat3 = np.mat(np.zeros((np.shape(testMat)[0],1)))
for i in range(np.shape(testMat)[0]):
yHat3[i] = testMat[i,0]*ws[1,0]+ws[0,0]
print("创建线性回归模型,他们的相关系数为:",np.corrcoef(yHat3, testMat[:,1], rowvar=0)[0,1])
运行结果如下图所示:
从运行结果可以看到R2越接近1.0越好,所以可以看出这里模型树的结果比回归树好,比线性回归模型也好。所以,树回归方法在预测复杂数据时会比简单的线性模型更有效。
机器学习给我们提供了一些强大的工具,能从未知数据中抽取出有用的信息。因此,能否将这些信息以易于人们理解的方式呈现十分重要。再者,假如人们可以直接与算法和数据交互,将可以比较轻松地进行解释。如果仅仅只是绘制出一副静态图像,或者只是在python命令行中输出一些数字,那么对结果做分析和交流将非常困难。如果能让用户不需要任何指令就可以按照他们自己的方式来分析数据,就不需要对数据做出过多解释。其中一个能同时支持数据呈现和用户交互的方式就是构建一个图形用户界面(GUI, Graphical user interface).
接下来将介绍如何用python来构建GUI,首先介绍利用一个现有的模块Tkinter来构建GUI, 之后介绍如何在Tkinter和绘图库之间交互,最后通过创建GUI使人们能够自己探索模型树和回归树的奥秘。
python有很多的GUI框架,其中一个易于使用的Tkinter,是随python标准编译版本发布的。Tkinter 可以在 Windows、Mac OS和大多数的 Linux 平台上使用。
特别注意:大写T开头的Tkinter包用于python2, 而小写的tkinter包用于python3,
如果没有注意这个就会报错,说找不到这个包…
MatPlotlib 的构建程序包含一个前端,也就是面向用户的一些代码,如 plot() 和 scatter() 方法等。事实上,它同时创建了一个后端,用于实现绘图和不同应用之间接口。
通过改变后端可以将图像绘制在PNG、PDF、SVG等格式的文件上。下面将设置后端为 TkAgg (Agg 是一个 C++ 的库,可以从图像创建光栅图)。TkAgg可以在所选GUI框架上调用Agg,把 Agg 呈现在画布上。我们可以在Tk的GUI上放置一个画布,并用 .grid()来调整布局。
新建treeExplor.py 文件,写入代码如下:
# -*- coding: utf-8 -*-
import numpy as np
import tkinter as tk
"""大写T开头的Tkinter包用于python2, 而小写的tkinter包用于python3, 如果没有注意这个就会报错,说找不到这个包......"""
import regTrees
import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
def reDraw(tolS, tolN):
reDraw.f.clf()
reDraw.a = reDraw.f.add_subplot(111)
if chkBtnVar.get():
if tolN<2:
tolN = 2
myTree = regTrees.createTree(reDraw.rawDat, regTrees.modelLeaf, regTrees.modelErr, (tolS, tolN))
yHat = regTrees.createForeCast(myTree, reDraw.testDat, regTrees.modelTreeEval)
else:
myTree = regTrees.createTree(reDraw.rawDat, ops=(tolS, tolN))
yHat = regTrees.createForeCast(myTree, reDraw.testDat)
reDraw.a.scatter(reDraw.rawDat[:,0].flatten().A[0], reDraw.rawDat[:,1].flatten().A[0], s=5)
reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)
reDraw.canvas.show()
def getInputs():
try: tolN = int(tolNentry.get())
except:
tolN = 10
print("enter Integer for tolN")
tolNentry.delete(0, END)
tolNentry.insert(0,'10')
try: tolS = float(tolSentry.get())
except:
tolS = 1.0
print("enter Float for tolS")
tolSentry.delete(0, END)
tolSentry.insert(0,'1.0')
return tolN, tolS
def drawNewTree():
tolN, tolS = getInputs()
reDraw(tolS, tolN)
root = tk.Tk()
tk.Label(root, text="Plot Place Holder").grid(row=0, columnspan=3)
reDraw.f = Figure(figsize=(5,4), dpi=100) #create canvas
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.show()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)
tk.Label(root, text="tolN").grid(row=1, column=0)
tolNentry = tk.Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0, '10')
tk.Label(root, text="tolS").grid(row=2, column=0)
tolSentry = tk.Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')
tk.Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)
chkBtnVar = tk.IntVar()
chkBtn = tk.Checkbutton(root, text="Model Tree", variable=chkBtnVar)
reDraw.rawDat = np.mat(regTrees.loadDataSet('sine.txt'))
reDraw.testDat = np.arange(min(reDraw.rawDat[:,0]), max(reDraw.rawDat[:,0]),0.01)
reDraw(1.0,10)
root.mainloop()
运行结果如下:
在上述界面中,可以尝试不同的tolN和tolS值。整个数据集包含200个样本,可以将tolN设为150后观察执行效果。为了构建尽可能大的树,应当将tolN设为1, 将tolS设为0.