A Simple GBDT Implementation from 统计学习方法 (Statistical Learning Methods)

Model: an additive model whose base learners are CART regression stumps (depth-1 regression trees)
Loss function: squared error
Stopping criteria: the number of base learners reaches maxIteration, or the overall loss falls below minLoss
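Why fit residuals? With squared error loss, the negative gradient of the loss with respect to the current prediction is exactly the residual, so gradient boosting reduces to repeatedly fitting the residuals of the model built so far. In the book's boosting-tree notation (a sketch, not the full derivation):

    f_M(x) = \sum_{m=1}^{M} T(x; \Theta_m)
    r_{mi} = y_i - f_{m-1}(x_i),  i = 1, \dots, N

Each new stump T(x; \Theta_m) is fit by least squares to the residuals r_{mi}; the split search in the code below does exactly this.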

import numpy as np
# Training data: the boosting-tree example from the book (Example 8.2)
x = np.arange(1, 11, 1)
# Candidate split points: the midpoints between adjacent x values
threshold = np.linspace(1.5, 9.5, num=9)
y = np.array([5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05])
# Squared error of a leaf: variance * count = sum of squared deviations from the leaf mean
def calcErr(label):
    return np.var(label) * label.shape[0]
# Split the labels into two leaves at splitVal
def binSplitDataSet(x, y, splitVal):
    dataSet1 = y[x < splitVal]
    dataSet2 = y[x >= splitVal]
    return dataSet1, dataSet2
# Subtract the new stump's predictions in place, so y becomes the residuals for the next round
def updateYValue(x, y, res):
    y[x < res['splitVal']] -= res['leftVal']
    y[x >= res['splitVal']] -= res['rightVal']
    return y
# Total squared error of the additive model built so far, measured on the original data
def calcLossError(x, y, trees):
    pred = np.zeros(x.shape[0])
    for stump in trees:
        pred[x < stump['splitVal']] += stump['leftVal']
        pred[x >= stump['splitVal']] += stump['rightVal']
    residual = (pred - y) ** 2
    return round(np.sum(residual), 2)
def simpleGBDT(x, y, maxIteration=10, minLoss=0.20):
    iteration = 0
    finalTree = []
    x_data = x.copy()
    y_data = y.copy()
    value = np.inf
    while iteration < maxIteration and value > minLoss:
        bestError = np.inf
        res = {}
        # Exhaustively search the candidate splits (global threshold) for the one
        # that minimizes the squared error on the current residuals
        for val in threshold:
            dataSet1, dataSet2 = binSplitDataSet(x_data, y_data, val)
            errorAfterSplit = calcErr(dataSet1) + calcErr(dataSet2)
            if errorAfterSplit < bestError:
                bestError = errorAfterSplit
                bestSplit = val
                leftVal = round(np.mean(dataSet1), 2)
                rightVal = round(np.mean(dataSet2), 2)
        res['splitVal'] = bestSplit
        res['leftVal'] = leftVal
        res['rightVal'] = rightVal
        # Turn the working labels into residuals for the next stump to fit
        y_data = updateYValue(x_data, y_data, res)
        iteration += 1
        finalTree.append(res)
        value = calcLossError(x, y, finalTree)
        print('iteration: %s, loss is: %s' % (iteration, value))
    return finalTree
finalTree = simpleGBDT(x, y)
print(finalTree)
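The listing above never shows how to apply the fitted model to a new input. Below is a minimal sketch of a prediction helper that mirrors the accumulation in calcLossError; the name predict is my own addition, not part of the original code.

def predict(finalTree, x_new):
    # predict() is not part of the original post; the additive model's
    # output is simply the sum of every stump's leaf value
    pred = 0.0
    for stump in finalTree:
        if x_new < stump['splitVal']:
            pred += stump['leftVal']
        else:
            pred += stump['rightVal']
    return pred

print(predict(finalTree, 5))

If the run matches the book's worked example, the first stump should split at 6.5 with leaf values 6.24 and 8.91, and the total loss should fall below 0.2 after six rounds, at which point the while condition stops training.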

This post does not cover the theoretical derivation; a more detailed GBDT tree implementation will follow in a later update, time permitting.
