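Lately I've been working through *Machine Learning in Action*. It is well written, and its Python source code is easy to find online, which makes it very helpful for learning machine learning.
I recently reached Chapter 9, Tree-Based Regression, and found that the code runs into problems again and again. Searching online, the commonly cited errors are two. However, the source has more than those two: there is another error in prune, and the prediction code for model trees is poorly written. Oddly, the code in the earlier chapters hardly had any mistakes, yet this chapter's code is riddled with problems, which is baffling.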
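First, the two errors that others have already confirmed online. Both are simple coding mistakes rather than conceptual ones.
The first is in the binSplitDataSet function: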
```python
def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
    return mat0,mat1
```
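Here `dataSet[nonzero(dataSet[:,feature] > value)[0],:]` does split the matrix into two matrices, but the trailing `[0]` is superfluous: it makes the function return only the first row of each submatrix, which is obviously wrong. The fix is simply to delete the `[0]`: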
```python
def binSplitDataSet(dataMat,feature,value):
    mat0=dataMat[nonzero(dataMat[:,feature]>value)[0]]
    mat1=dataMat[nonzero(dataMat[:,feature]<=value)[0]]
    return mat0,mat1
```
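To make the effect of the stray `[0]` concrete, here is a small sketch that runs both versions on a made-up 4×2 matrix. The function names `bin_split_buggy` and `bin_split_fixed` and the sample data are hypothetical; `np.asmatrix` is used in place of `mat` because `np.mat` no longer exists in NumPy 2.0.

```python
import numpy as np

def bin_split_buggy(data, feature, value):
    # book's version: the trailing [0] keeps only the first matching row
    mat0 = data[np.nonzero(data[:, feature] > value)[0], :][0]
    mat1 = data[np.nonzero(data[:, feature] <= value)[0], :][0]
    return mat0, mat1

def bin_split_fixed(data, feature, value):
    # corrected version: keep every row that satisfies the condition
    mat0 = data[np.nonzero(data[:, feature] > value)[0]]
    mat1 = data[np.nonzero(data[:, feature] <= value)[0]]
    return mat0, mat1

# hypothetical data: column 0 is the feature, column 1 the target
data = np.asmatrix([[0.1, 1.0],
                    [0.5, 2.0],
                    [0.7, 3.0],
                    [0.9, 4.0]])

b0, b1 = bin_split_buggy(data, 0, 0.4)
f0, f1 = bin_split_fixed(data, 0, 0.4)
print(b0.shape, b1.shape)  # (1, 2) (1, 2)
print(f0.shape, f1.shape)  # (3, 2) (1, 2)
```

Three rows have a feature value above 0.4, but the buggy version keeps only one of them.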
The second error follows right after, in chooseBestSplit (line 58 of the source):
```python
for splitVal in set(dataSet[:,featIndex]):
```
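Here `set()` receives a matrix column. Iterating over an `np.matrix` column yields 1×1 matrices, which are unhashable, so this line raises a TypeError at runtime (strictly speaking it is not a syntax error). It should be changed to (using my variable names):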
```python
for splitValue in set(dataMat[:,feat].T.tolist()[0]):
```
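A quick sketch of why the original line fails and why the `.T.tolist()[0]` form works, on a hypothetical 3×2 matrix: the column is transposed to a 1×m row, converted to a nested Python list, and the single inner list is passed to `set()`.

```python
import numpy as np

# hypothetical data with a repeated feature value in column 0
data = np.asmatrix([[0.1, 1.0],
                    [0.5, 2.0],
                    [0.5, 3.0]])

try:
    set(data[:, 0])  # book's form: each element is a 1x1 matrix, which is unhashable
    book_form_ok = True
except TypeError:
    book_form_ok = False

# fixed form: 1xm row -> [[...]] -> plain list of floats
split_values = set(data[:, 0].T.tolist()[0])
print(book_form_ok)          # False
print(sorted(split_values))  # [0.1, 0.5]
```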
The remaining errors are ones I found myself; I haven't seen them mentioned anywhere else online.
First, the getMean function:
```python
def getMean(tree):
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left']+tree['right'])/2.0
```
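Look closely at this function: it does not actually compute the mean of the tree's leaves.
For example, with `tree={'left':3,'right':{'left':1,'right':2}}`,
calling getMean returns (3 + (1+2)/2)/2 = 2.25, which clearly differs from (1+2+3)/3 = 2. Each level of recursion halves the weight of the leaves beneath it, so deeper leaves count for less.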
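Fixing this properly is more involved, so my approach changes the tree-building code as well. Every internal (non-leaf) node now stores the number of leaves beneath it and the sum of those leaves' values. With these two fields, getMean becomes a single division, which is also more efficient: its cost drops to O(1).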
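The difference between the two approaches can be checked on the example tree above. This is a minimal sketch using the `'left'`/`'right'` keys from the example; the helpers `annotate` and `get_mean_counts` are hypothetical names, and in my full code below the same bookkeeping happens inside createTree (fields `leafN` and `total`) rather than in a separate pass.

```python
def is_tree(obj):
    return isinstance(obj, dict)

def get_mean_book(tree):
    # book's getMean: averages the two children pairwise, level by level
    if is_tree(tree['right']):
        tree['right'] = get_mean_book(tree['right'])
    if is_tree(tree['left']):
        tree['left'] = get_mean_book(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

def annotate(tree):
    # post-order pass: store leaf count and leaf sum on every internal node
    if not is_tree(tree):
        return 1, tree
    ln, lt = annotate(tree['left'])
    rn, rt = annotate(tree['right'])
    tree['leafN'] = ln + rn
    tree['total'] = lt + rt
    return tree['leafN'], tree['total']

def get_mean_counts(tree):
    # O(1) once annotate() has run: true mean over all leaves
    return tree['total'] / tree['leafN'] if is_tree(tree) else tree

print(get_mean_book({'left': 3, 'right': {'left': 1, 'right': 2}}))  # 2.25
t = {'left': 3, 'right': {'left': 1, 'right': 2}}
annotate(t)
print(get_mean_counts(t))  # 2.0
```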
Next, the treeForeCast function, at line 124 of the source:
```python
if inData[tree['spInd']] > tree['spVal']:
```
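This is another trap. spInd is clearly a column (feature) index, but a single subscript on a row matrix selects rows instead. For a 1×n input row, `inData[0]` returns the whole row, and comparing a 1×n matrix with a number yields a 1×n boolean matrix that `if` cannot evaluate when n > 1 (ValueError); any spInd ≥ 1 raises an IndexError. So this line must be changed to: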
```python
if inMat[0,tree['spInd']]>tree['spVal']:
```
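A short sketch of the indexing difference on a hypothetical 1×2 input row (one feature followed by the target). The values of `sp_ind` and `sp_val` are made up for illustration.

```python
import numpy as np

# hypothetical input row: feature in column 0, target in column 1
in_mat = np.asmatrix([[3.5, 120.0]])
sp_ind, sp_val = 0, 3.0

row_pick = in_mat[sp_ind]  # book's indexing: row 0, i.e. the whole 1x2 row
try:
    if row_pick > sp_val:  # 1x2 boolean matrix in an if statement
        pass
    ambiguous = False
except ValueError:
    ambiguous = True

try:
    in_mat[1]  # a split on feature 1 would index a row that does not exist
    index_error = False
except IndexError:
    index_error = True

col_pick = in_mat[0, sp_ind]  # fixed indexing: row 0, column spInd, a scalar
print(row_pick.shape, ambiguous, index_error)  # (1, 2) True True
print(col_pick > sp_val)                       # True
```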
Then the modelTreeEval function.
This one is painfully written too:
```python
def modelTreeEval(model, inDat):
    n = shape(inDat)[1]
    X = mat(ones((1,n+1)))
    X[:,1:n+1]=inDat
    return float(X*model)
```
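Here model is only a 2×1 matrix (an intercept and one slope, since this data set has a single feature). If each input row also carries the target value in its last column, as it does in my test code below, X ends up 1×3, and the `X*model` in the return statement raises an error because the two shapes cannot be multiplied.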
Fixed version, which treats the last column of the input row as the target and drops it:
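```python
def modelTreeEval(model,inMat):
    # inMat is a 1×n row whose last column is the target value
    n=inMat.shape[1]
    X=mat(ones((1,n)))      # column 0 stays 1 for the intercept
    X[:,1:n]=inMat[:,:-1]   # copy the features, dropping the target column
    return float((X*model)[0,0])
```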
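The shape mismatch can be reproduced with a hypothetical one-feature leaf (intercept 1.0, slope 2.0) and an input row that still contains its target. The function names below are hypothetical copies of the two versions, written with `np.asmatrix` so they run on NumPy 2.0.

```python
import numpy as np

model = np.asmatrix([[1.0], [2.0]])    # hypothetical leaf: y = 1 + 2x
in_mat = np.asmatrix([[3.0, 7.5]])     # hypothetical row: x = 3.0, then the target

def model_tree_eval_book(model, in_dat):
    # book's version: assumes in_dat contains feature columns only
    n = in_dat.shape[1]
    X = np.asmatrix(np.ones((1, n + 1)))
    X[:, 1:n + 1] = in_dat
    return float((X * model)[0, 0])

def model_tree_eval_fixed(model, in_mat):
    # fixed version: drops the last (target) column before multiplying
    n = in_mat.shape[1]
    X = np.asmatrix(np.ones((1, n)))
    X[:, 1:n] = in_mat[:, :-1]
    return float((X * model)[0, 0])

try:
    model_tree_eval_book(model, in_mat)  # X is 1x3, model is 2x1
    book_ok = True
except ValueError:
    book_ok = False

print(book_ok)                                      # False
print(model_tree_eval_fixed(model, in_mat))         # 7.0
print(model_tree_eval_book(model, in_mat[:, :-1]))  # 7.0
```

The last line shows the book's version gives the same answer when it is handed only the feature columns, so the two versions differ only in what they expect the input row to contain.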
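That's all the errors I've found so far; I haven't looked at the plotting part later in the chapter.
Here is the full corrected code (run() is my own test function, and it only tests prediction with a model tree):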
```python
# -*- coding:utf-8 -*-
from numpy import *
from numpy import linalg
from numpy import asmatrix as mat  # np.mat was removed in NumPy 2.0; asmatrix is the same function
import matplotlib.pyplot as plt

def loadDataSet(fileName):
    dataSet=[]
    with open(fileName) as fr:
        for line in fr:
            items=line.strip().split('\t')
            dataSet.append([float(x) for x in items])
    return dataSet

def regLeaf(dataMat):
    return mean(dataMat[:,-1])

def regErr(dataMat):
    return var(dataMat[:,-1])*dataMat.shape[0]

def modelLeaf(dataMat):
    ws,X,Y=linearSolve(dataMat)
    return ws

def modelErr(dataMat):
    ws,X,Y=linearSolve(dataMat)
    YHat=X*ws
    return sum(power(YHat-Y,2))

def binSplitDataSet(dataMat,feature,value):
    mat0=dataMat[nonzero(dataMat[:,feature]>value)[0]]
    mat1=dataMat[nonzero(dataMat[:,feature]<=value)[0]]
    return mat0,mat1

def chooseBestFeature(dataMat,leafType,errType,ops):
    tolS=ops[0];tolN=ops[1]
    # all target values identical: make a leaf
    if len(set(dataMat[:,-1].T.tolist()[0]))==1:
        return None,leafType(dataMat)
    m,n=shape(dataMat);S=errType(dataMat)
    bestS=inf;bestVal=0;bestFeature=0
    for feat in range(n-1):
        for splitValue in set(dataMat[:,feat].T.tolist()[0]):
            mat0,mat1=binSplitDataSet(dataMat,feat,splitValue)
            if (mat0.shape[0]<tolN) or (mat1.shape[0]<tolN):
                continue
            nowErr=errType(mat0)+errType(mat1)
            if nowErr<bestS:
                bestS=nowErr;bestVal=splitValue;bestFeature=feat
    # error reduction too small: make a leaf
    if abs(S-bestS)<tolS:
        return None,leafType(dataMat)
    mat0,mat1=binSplitDataSet(dataMat,bestFeature,bestVal)
    if (mat0.shape[0]<tolN) or (mat1.shape[0]<tolN):
        return None,leafType(dataMat)
    return bestFeature,bestVal

def isTree(obj):
    return isinstance(obj,dict)

def createTree(dataMat,leafType=modelLeaf,errType=modelErr,ops=(1,4)):
    feat,val=chooseBestFeature(dataMat,leafType,errType,ops)
    if feat is None:
        return val
    retTree={}
    retTree['spInd']=feat
    retTree['spVal']=val
    leftMat,rightMat=binSplitDataSet(dataMat,feat,val)
    retTree['lTree']=createTree(leftMat,leafType,errType,ops)
    retTree['rTree']=createTree(rightMat,leafType,errType,ops)
    # While building the tree, record for every node the number of leaves
    # beneath it ('leafN') and the sum of those leaves ('total'),
    # so that pruning can collapse a subtree quickly.
    # This part differs substantially from the book's version.
    if isTree(retTree['lTree']) and isTree(retTree['rTree']):
        retTree['leafN']=retTree['lTree']['leafN']+retTree['rTree']['leafN']
        retTree['total']=retTree['lTree']['total']+retTree['rTree']['total']
    elif (not isTree(retTree['lTree'])) and isTree(retTree['rTree']):
        retTree['leafN']=1+retTree['rTree']['leafN']
        retTree['total']=retTree['lTree']+retTree['rTree']['total']
    elif isTree(retTree['lTree']) and (not isTree(retTree['rTree'])):
        retTree['leafN']=retTree['lTree']['leafN']+1
        retTree['total']=retTree['lTree']['total']+retTree['rTree']
    else:
        retTree['leafN']=2
        retTree['total']=retTree['lTree']+retTree['rTree']
    return retTree

def getMean(tree):
    # mean of all leaves under this node, using the counts stored by createTree
    if isTree(tree):
        return tree['total']*1.0/tree['leafN']
    return tree

def prune(tree,testData):
    # assumes a regression tree (scalar leaves), as in the book
    if testData.shape[0]==0:
        return getMean(tree)
    if isTree(tree['lTree']) or isTree(tree['rTree']):
        lSet,rSet=binSplitDataSet(testData,tree['spInd'],tree['spVal'])
    if isTree(tree['lTree']):
        tree['lTree']=prune(tree['lTree'],lSet)
    if isTree(tree['rTree']):
        tree['rTree']=prune(tree['rTree'],rSet)
    if not isTree(tree['lTree']) and not isTree(tree['rTree']):
        lSet,rSet=binSplitDataSet(testData,tree['spInd'],tree['spVal'])
        errNoMerge=sum(power(lSet[:,-1]-tree['lTree'],2))+sum(power(rSet[:,-1]-tree['rTree'],2))
        treeMean=tree['total']*1.0/tree['leafN']
        errMerge=sum(power(testData[:,-1]-treeMean,2))
        # merge only if merging lowers the error on the test data
        if errNoMerge>errMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree

def linearSolve(dataMat):
    m,n=shape(dataMat)
    X=mat(ones((m,n)));Y=mat(zeros((m,1)))
    X[:,1:n]=dataMat[:,0:n-1];Y=dataMat[:,-1]
    xTx=X.T*X
    if linalg.det(xTx)==0.0:
        raise NameError("singular matrix")
    ws=xTx.I*X.T*Y
    return ws,X,Y

# regression-tree prediction
def regTreeEval(model,inMat):
    return float(model)

# model-tree prediction
def modelTreeEval(model,inMat):
    # inMat is a 1×n row whose last column is the target value
    n=inMat.shape[1]
    X=mat(ones((1,n)))
    X[:,1:n]=inMat[:,:-1]
    return float((X*model)[0,0])

def treeForeCast(tree,inMat,modelEval=modelTreeEval):
    # inMat is a 1×n row matrix; spInd is a column index
    if not isTree(tree):
        return modelEval(tree,inMat)
    if inMat[0,tree['spInd']]>tree['spVal']:
        return treeForeCast(tree['lTree'],inMat,modelEval)
    else:
        return treeForeCast(tree['rTree'],inMat,modelEval)

def createForeCast(tree,testMat,modelEval=modelTreeEval):
    m=testMat.shape[0]
    yHat=mat(zeros((m,1)))
    for i in range(m):
        yHat[i,0]=treeForeCast(tree,testMat[i],modelEval)
    return yHat

def run():
    dataSet=loadDataSet('bikeSpeedVsIq_train.txt')
    testSet=loadDataSet('bikeSpeedVsIq_test.txt')
    tree=createTree(mat(dataSet),ops=(1,20))
    yHat=createForeCast(tree,mat(testSet))
    print(corrcoef(yHat.T,mat(testSet)[:,1].T))
    fig=plt.figure()
    ax=fig.add_subplot(111)
    ax.scatter(array(dataSet)[:,0],array(dataSet)[:,1],c='cyan',marker='o')
    plt.show()

if __name__=='__main__':
    run()
```