Machine Learning in Action, Chapter 9: Errors in the Tree Regression Code

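I have been working through the book Machine Learning in Action (《机器学习实战》) lately. It is well written, the Python source code is easy to find online, and it is a real help for learning machine learning.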

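Recently I reached Chapter 9, tree regression, and kept running into problems with the code. A quick search shows that two errors are commonly reported online. But the source has more than those two: there is another error in prune, and the prediction part of the model tree is also badly written. The odd thing is that the code in the earlier chapters was basically error-free, yet this chapter's code breaks again and again, which is baffling.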

First, the two errors in the book that have already been confirmed. Both are simple coding slips.

The first error is in the binSplitDataSet function:

def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
    return mat0,mat1

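Here dataSet[nonzero(dataSet[:,feature] > value)[0],:] already selects the rows on one side of the split, so the two expressions together do split dataSet into two matrices. The superfluous part is the trailing [0], which makes each return value just the first row of its sub-matrix. That is obviously wrong. The fix is simple: delete the [0], as follows: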

def binSplitDataSet(dataMat,feature,value):
    mat0=dataMat[nonzero(dataMat[:,feature]>value)[0]]
    mat1=dataMat[nonzero(dataMat[:,feature]<=value)[0]]
    return mat0,mat1
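
To make the difference concrete, here is a minimal sketch. It assumes Python 3 with NumPy and uses a plain ndarray with a made-up 4x2 data set rather than the book's data; `bin_split` is an illustrative name, not the book's function.

```python
import numpy as np

def bin_split(data, feature, value):
    """Split the rows of `data` on column `feature` at threshold `value`."""
    col = data[:, feature]
    above = data[np.nonzero(col > value)[0]]   # rows where feature > value
    below = data[np.nonzero(col <= value)[0]]  # rows where feature <= value
    return above, below

# Made-up data: column 0 is the feature, column 1 is the target.
data = np.array([[1.0, 10.0],
                 [2.0, 20.0],
                 [3.0, 30.0],
                 [4.0, 40.0]])

above, below = bin_split(data, 0, 2.0)

# Appending [0], as the book's listing does, keeps only the first matching row.
first_only = data[np.nonzero(data[:, 0] > 2.0)[0], :][0]

print(above.shape, below.shape, first_only)
```

With the trailing [0], the "greater than" side shrinks from two rows to a single row, so every later error computation would see only one sample per side.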

The second error follows right after, in chooseBestSplit (line 58):

for splitVal in set(dataSet[:,featIndex]):

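Here set() is handed a matrix column. Iterating over it yields 1x1 matrices, which are not hashable, so this line raises a TypeError at runtime. It should be changed to the following (the variable names are the ones used in my rewritten version further down):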

for splitValue in set(dataMat[:,feat].T.tolist()[0]):
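
A small sketch of why the original line fails and why the flattened version works. It assumes Python 3 with NumPy and a made-up 3x2 matrix; the variable names are only for illustration.

```python
import numpy as np

# Made-up data; column 0 contains a repeated value on purpose.
data = np.matrix([[1.0, 5.0],
                  [2.0, 6.0],
                  [2.0, 7.0]])

# Iterating a matrix column yields 1x1 matrices, which are not hashable.
try:
    set(data[:, 0])
    error_name = None
except TypeError as exc:
    error_name = type(exc).__name__

# Flattening to a plain list of floats first makes set() work.
split_values = set(data[:, 0].T.tolist()[0])

print(error_name, sorted(split_values))
```

The transpose turns the column into a single row, tolist() gives a nested list with one inner list, and [0] picks that inner list of floats.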

The next errors are ones I noticed myself; I have not seen them reported anywhere else online.
First, the getMean function:

def getMean(tree):
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left']+tree['right'])/2.0

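Look closely at this function: it does not actually compute the mean of the tree's leaves.
For example, when tree={'left':3,'right':{'left':1,'right':2}},
calling getMean returns (3+(1+2)/2)/2=2.25, which is clearly not (1+2+3)/3=2.
Fixing this properly is more involved. My approach was to change the tree-building code as well: every node (leaf nodes excepted) stores the number of leaf nodes beneath it, together with the sum of those leaves' values. This also makes getMean more efficient (its complexity becomes O(1)).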
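
The example above can be checked with a short pure-Python sketch. `book_get_mean`, `annotate` and `leaf_mean` are illustrative names of my own; the sketch mirrors the idea of caching a leaf count and a leaf sum on each internal node, and is not the exact code used later in this post.

```python
def is_tree(obj):
    return isinstance(obj, dict)

def book_get_mean(tree):
    """The book's recursive averaging, written here without mutating the tree."""
    left, right = tree['left'], tree['right']
    if is_tree(right):
        right = book_get_mean(right)
    if is_tree(left):
        left = book_get_mean(left)
    return (left + right) / 2.0

def annotate(tree):
    """Return (leaf_count, leaf_sum) and cache both on every internal node."""
    if not is_tree(tree):
        return 1, tree
    n_left, s_left = annotate(tree['left'])
    n_right, s_right = annotate(tree['right'])
    tree['leafN'] = n_left + n_right
    tree['total'] = s_left + s_right
    return tree['leafN'], tree['total']

def leaf_mean(tree):
    """O(1) mean of all leaves once the node has been annotated."""
    return tree['total'] / float(tree['leafN']) if is_tree(tree) else tree

tree = {'left': 3, 'right': {'left': 1, 'right': 2}}
nested = book_get_mean(tree)   # (3 + (1 + 2) / 2) / 2
annotate(tree)
flat = leaf_mean(tree)         # (1 + 2 + 3) / 3

print(nested, flat)
```

The nested average weights the single leaf 3 as heavily as the whole right subtree, which is where the 2.25 comes from.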

Next, the treeForeCast function,
line 124 of the code:

if inData[tree['spInd']] > tree['spVal']:

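Here is another trap. spInd clearly denotes a column, but indexing the single-row input matrix with it this way selects a row instead. Besides, a whole row matrix cannot be compared with a single number. So this line has to be changed to: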

if float(inMat[:,tree['spInd']])>tree['spVal']:
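
A minimal sketch of the indexing problem, assuming Python 3 with NumPy. The sample row and the split (`sp_ind`, `sp_val`) are made up, and `.item()` is used here in place of the float() call above to pull the scalar out of the 1x1 matrix.

```python
import numpy as np

row = np.matrix([[0.3, 7.0, 1.5]])   # one sample: 1 row, 3 columns
sp_ind, sp_val = 1, 5.0              # hypothetical split: column 1, threshold 5.0

# Indexing with a single integer selects a ROW, and there is only row 0.
try:
    row[sp_ind]
    row_error = None
except IndexError as exc:
    row_error = type(exc).__name__

# Column indexing returns a 1x1 matrix; item() turns it into a plain float.
feature_value = row[:, sp_ind].item()
goes_left = feature_value > sp_val

print(row_error, feature_value, goes_left)
```

Note that with sp_ind equal to 0 the row index would not fail, but it would return the entire row, and comparing that with a number gives a boolean matrix rather than a single True or False.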

Then on to the modelTreeEval function.
This one is also written badly.

def modelTreeEval(model, inDat):
    n = shape(inDat)[1]
    X = mat(ones((1,n+1)))
    X[:,1:n+1]=inDat
    return float(X*model)

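Here model is only a 2x1 matrix. Written the book's way, X gets n+1 columns, so when inDat is a full data row that still contains the target column (n=2, which is how my test code calls it), the final return multiplies a 1x3 matrix by a 2x1 matrix and raises an error: the two shapes simply cannot be multiplied.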

Fix:

def modelTreeEval(model,inMat):
    n=inMat.shape[1]
    X=mat(ones((1,n)))
    X[:,1:n]=inMat[:,:-1]
    return float(X*model)
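
The shape argument can be verified with a small sketch, assuming Python 3 with NumPy. The weights and the sample row are made up, `model_tree_eval` is an illustrative name, and `.item()` stands in for float() on the 1x1 result.

```python
import numpy as np

model = np.matrix([[2.0], [0.5]])     # hypothetical weights: intercept 2.0, slope 0.5
sample = np.matrix([[4.0, 99.0]])     # full row: feature 4.0, target 99.0

def model_tree_eval(model, in_mat):
    """Predict from a full row by dropping its last (target) column."""
    n = in_mat.shape[1]
    x = np.matrix(np.ones((1, n)))    # column 0 stays 1.0 for the intercept
    x[:, 1:n] = in_mat[:, :-1]        # copy the features, skip the target
    return (x * model).item()

# The book's layout builds a 1x(n+1) row, which cannot multiply a 2x1 model here.
x_book = np.matrix(np.ones((1, sample.shape[1] + 1)))
x_book[:, 1:] = sample
try:
    x_book * model
    shape_error = None
except ValueError as exc:
    shape_error = type(exc).__name__

prediction = model_tree_eval(model, sample)   # 2.0 + 0.5 * 4.0

print(shape_error, prediction)
```

If the caller passes only the feature columns, the book's n+1 layout lines up with the model again, so which version is right depends on what the input row contains.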

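Those are all the errors I have found so far; I did not look at the plotting part that comes afterwards.
Here is the complete code after my changes (the run part is a test function I wrote myself, and it only tests model-tree prediction):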

# -*- coding:utf-8 -*-
import math
from numpy import *
import matplotlib.pyplot as plt

def loadDataSet(fileName):
    fr=open(fileName)
    dataSet=[]
    for line in fr.readlines():
        items=line.strip().split('\t')
        dataSet.append(list(map(float,items)))  # list() so this also works where map() is lazy
    return dataSet

def regLeaf(dataMat):
    return mean(dataMat[:,-1])

def regErr(dataMat):
    return var(dataMat[:,-1])*dataMat.shape[0]

def modelLeaf(dataMat):
    ws,X,Y=linearSolve(dataMat)
    return ws

def modelErr(dataMat):
    ws,X,Y=linearSolve(dataMat)
    YHat=X*ws
    return sum(power(YHat-Y,2))

def binSplitDataSet(dataMat,feature,value):
    mat0=dataMat[nonzero(dataMat[:,feature]>value)[0]]
    mat1=dataMat[nonzero(dataMat[:,feature]<=value)[0]]
    return mat0,mat1

def chooseBestFeature(dataMat,leafType,errType,ops):
    tolS=ops[0];tolN=ops[1]
    if len(set(dataMat[:,-1].T.tolist()[0]))==1:
        return None,leafType(dataMat)
    m,n=shape(dataMat);S=errType(dataMat)
    bestS=inf;bestVal=0;bestFeature=0
    for feat in range(n-1):
        for splitValue in set(dataMat[:,feat].T.tolist()[0]):
            mat0,mat1=binSplitDataSet(dataMat,feat,splitValue)
            if (mat0.shape[0]<tolN) or (mat1.shape[0]<tolN):
                continue
            nowErr=errType(mat0)+errType(mat1)
            if nowErr<bestS:
                bestFeature=feat
                bestVal=splitValue
                bestS=nowErr
    if abs(S-bestS)<tolS:
        return None,leafType(dataMat)
    mat0,mat1=binSplitDataSet(dataMat,bestFeature,bestVal)
    if (mat0.shape[0]<tolN) or (mat1.shape[0]<tolN):
        return None,leafType(dataMat)
    return bestFeature,bestVal

def isTree(obj):
    return (type(obj).__name__=='dict')

def createTree(dataMat,leafType=modelLeaf,errType=modelErr,ops=(1,4)):
    feat,val=chooseBestFeature(dataMat,leafType,errType,ops)
    if feat==None:
        return val
    retTree={}
    retTree['spInd']=feat
    retTree['spVal']=val
    leftMat,rightMat=binSplitDataSet(dataMat,feat,val)
    retTree['lTree']=createTree(leftMat,leafType,errType,ops)
    retTree['rTree']=createTree(rightMat,leafType,errType,ops)
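    # While building the tree, record on each node the number of leaves beneath it
    # and the sum of those leaves' values,
    # so that post-pruning can collapse a subtree quickly.
    # This differs substantially from the way the book writes it.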
    if isTree(retTree['lTree']) and isTree(retTree['rTree']):
        retTree['leafN']=retTree['lTree']['leafN']+retTree['rTree']['leafN']
        retTree['total']=retTree['lTree']['total']+retTree['rTree']['total']
    elif (not isTree(retTree['lTree'])) and isTree(retTree['rTree']):
        retTree['leafN']=1+retTree['rTree']['leafN']
        retTree['total']=retTree['lTree']+retTree['rTree']['total']
    elif isTree(retTree['lTree']) and (not isTree(retTree['rTree'])):
        retTree['leafN']=retTree['lTree']['leafN']+1
        retTree['total']=retTree['lTree']['total']+retTree['rTree']
    else:
        retTree['leafN']=2
        retTree['total']=retTree['lTree']+retTree['rTree']
    return retTree

def getMean(tree):
    if isTree(tree):
        if isTree(tree['lTree']):
            tree['lTree']=tree['lTree']['total']
        if isTree(tree['rTree']):
            tree['rTree']=tree['rTree']['total']
        return tree['total']*1.0/tree['leafN']
    else:
        return tree

def prune(tree,testData):
    if testData.shape[0]==0:
        return getMean(tree)
    if isTree(tree['lTree']) or isTree(tree['rTree']):
        lSet,rSet=binSplitDataSet(testData,tree['spInd'],tree['spVal'])
    if isTree(tree['lTree']):
        tree['lTree']=prune(tree['lTree'],lSet)
    if isTree(tree['rTree']):
        tree['rTree']=prune(tree['rTree'],rSet)
    if not isTree(tree['lTree']) and not isTree(tree['rTree']):
        lSet,rSet=binSplitDataSet(testData,tree['spInd'],tree['spVal'])
        errNoMerge=sum(power(lSet[:,-1]-tree['lTree'],2))+sum(power(rSet[:,-1]-tree['rTree'],2))
        treeMean=tree['total']/tree['leafN']
        errMerge=sum(power(testData[:,-1]-treeMean,2))
        if errNoMerge>errMerge:
            print("merging")
            return treeMean
        else:
            return tree
    else:
        return tree

def linearSolve(dataMat):
    m,n=shape(dataMat)
    X=mat(ones((m,n)));Y=mat(zeros((m,1)))
    X[:,1:n]=dataMat[:,0:n-1];Y=dataMat[:,-1]
    xTx=X.T*X
    if linalg.det(xTx)==0.0:
        raise NameError("singular matrix")
    ws=xTx.I*X.T*Y
    return ws,X,Y

# Regression tree prediction
def regTreeEval(model,inMat):
    return float(model)

# Model tree prediction
def modelTreeEval(model,inMat):
    n=inMat.shape[1]
    X=mat(ones((1,n)))
    X[:,1:n]=inMat[:,:-1]
    return float(X*model)

def treeForeCast(tree,inMat,modelEval=modelTreeEval):
    if not isTree(tree):
        return modelEval(tree,inMat)
    if float(inMat[:,tree['spInd']])>tree['spVal']:
        if not isTree(tree['lTree']):
            return modelEval(tree['lTree'],inMat)
        else:
            return treeForeCast(tree['lTree'],inMat,modelEval)
    else:
        if not isTree(tree['rTree']):
            return modelEval(tree['rTree'],inMat)
        else:
            return treeForeCast(tree['rTree'],inMat,modelEval)

def createForeCast(tree,testMat,modelEval=modelTreeEval):
    m=testMat.shape[0]
    yHat=mat(zeros((m,1)))
    for i in range(m):
        yHat[i]=treeForeCast(tree,testMat[i],modelEval)
    return yHat

def run():
    dataSet=loadDataSet('bikeSpeedVsIq_train.txt')
    testSet=loadDataSet('bikeSpeedVsIq_test.txt')
    tree=createTree(mat(dataSet),ops=(1,20))
    yHat=createForeCast(tree,mat(testSet))
    print(corrcoef(yHat.T,mat(testSet)[:,1].T))
    fig=plt.figure()
    ax=fig.add_subplot(111)
    ax.scatter(array(dataSet)[:,0],array(dataSet)[:,1],c='cyan',marker='o')
    plt.show()

run()

The final result of running it:
[Image 1: run output]

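And the most amazing part: even though the book's code is riddled with errors, the final answer it reports is still the same as mine. That is the real kicker.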

Done!
