# GBDT 回归(决策树桩)的 Python 实现

## 1. 代码
import pandas as pd
import numpy as np
def load_data(filename):
data = pd.read_table(filename, sep ='\t', header = None)
X_data = data.loc[:][0]
Y_data = data.loc[:][1]
return X_data, Y_data
class weakLearner():
def __init__(self, min_sample, min_err):
self.min_sample = min_sample
self.min_err = min_err
def __mse(self, left_node):
return np.sum((np.average(left_node) - np.array(left_node)) ** 2)
def Get_stump_list(self, X):
tmp1 = list(X.copy())
tmp2 = list(X.copy())
tmp1.insert(0, 0)
tmp2.append(0)
stump_list = ((np.array(tmp1) + np.array(tmp2)) / float(2))[1:-1]
return stump_list
def __binSplitData(self, stump_list, X, Y):
left_node = []
left_node_x = []
right_node = []
right_node_y = []
for j in range(np.shape(X)[0]):
if X[j] < stump_list:
left_node.append(Y[j])
right_node
else:
right_node.append(Y[j])
return left_node, right_node
def __bestSplit(self, stump_list, X, Y):
best_mse = np.inf
for i in range(np.shape(stump_list)[0]):
left_node, right_node = self.__binSplitData(stump_list[i], X, Y)
left_mse = self.__mse(left_node)
right_mse = self.__mse(right_node)
if best_mse > (left_mse + right_mse):
best_mse = left_mse + right_mse
best_f_val = stump_list[i]
return best_f_val
def __CART(self, X, Y):
tree = dict()
if len(X) <= self.min_sample:
return np.mean(Y)
stump_list = self.Get_stump_list(X)
best_f_val = self.__bestSplit(stump_list, X, Y)
tree['cut_f_val'] = best_f_val
left_node, right_node = self.__binSplitData(best_f_val, X, Y)
tree['left_tree'] = left_node
tree['right_tree'] = right_node
left_mse = self.__mse(left_node)
right_mse = self.__mse(right_node)
tree['left_mse'] = left_mse
tree['right_mse'] = right_mse
now_mse = left_mse + right_mse
if now_mse <= self.min_err:
return np.mean(Y)
return tree
def train(self, X, Y):
self.tree = self.__CART(X, Y)
return self.tree
def predict(self, X):
return np.array([self.__predict_one(x, self.tree) for x in X])
def __predict_one(self, x, tree):
cut_val = tree['cut_f_val']
Y_left = np.average(tree['left_tree'])
Y_right = np.average(tree['right_tree'])
result = Y_left if x <= cut_val else Y_right
return result
class GBDT():
def __init__(self, n_estimators: int = 10, classifier = weakLearner):
self.n_estimators = n_estimators
self.weakLearner = classifier
self.Trees = []
self.learn_rate = 1
self.init_value = None
def get_init_value(self, Y):
"""
:param y: 样本标签列表
:return: average(float)"样本标签的平均值
"""
average = sum(Y) / len(Y)
return average
def get_residuals(self, Y, y_hat):
y_residuals = []
for i in range(len(Y)):
y_residuals.append(Y[i] - y_hat[i])
return y_residuals
def fit(self, X, Y, n_estimators, learn_rate, min_sample, min_err):
self.n_estimators = n_estimators
self.learn_rate = learn_rate
X = np.array(X)
Y = np.array(Y)
self.init_value = self.get_init_value(Y)
n = len(Y)
y_hat = [self.init_value] * n
residual = self.get_residuals(Y, y_hat)
self.Trees = []
for num in range(self.n_estimators):
wl = self.weakLearner(min_sample, min_err)
Tree = wl.train(X, residual)
Y_left = np.average(Tree['left_tree'])
Y_right = np.average(Tree['right_tree'])
left_residual = self.get_residuals(np.array(Tree['left_tree']), [Y_left] * n)
right_residual = self.get_residuals(np.array(Tree['right_tree']), [Y_right] * n)
residual = np.append(left_residual, right_residual)
for i in range(n):
y_hat[i] = y_hat[i] + self.learn_rate * residual[i]
residual = self.get_residuals(Y, y_hat)
self.Trees.append(wl)
return self.Trees
def predict(self, X):
M = self.n_estimators
y_ = 0
for m in range(M):
y_ += [self.init_value] * len(X) + self.learn_rate * self.Trees[m].predict(X)
return y_
def error(self, Y, y_predict):
error = np.square(Y - y_predict).sum() / len(Y)
return error
if __name__ == "__main__":
print("-----------------------------------1.load data------------------------------------------")
X, Y = load_data("./sine.txt")
X_train = X[0:150]
Y_train = Y[0:150]
X_test = X[150: 200]
Y_test = Y[150: 200]
print("-----------------------------------2. Parameters Setting--------------------------------")
n_estimators = 4
learn_rate = 0.5
min_sample = 30
min_err = 0.3
print("----------------------------------3.build GBDT--------------------------------------------")
Trees_reg = GBDT()
Trees = Trees_reg.fit(X_train, Y_train, n_estimators, learn_rate, min_sample, min_err)
print("----------------------------------4.Predict Result-----------------------------------------")
y_predict = Trees_reg.predict(X_test)
print("Y_test: ", np.mat(Y_test))
print("predict_results: ", y_predict)
print("----------------------------------5.Predict error-------------------------------------------")
error = Trees_reg.error(Y_test, y_predict)
print("The error is ", error)
## 2. 数据集 sine.txt
0.190350 0.878049
0.306657 -0.109413
0.017568 0.030917
0.122328 0.951109
0.076274 0.774632
0.614127 -0.250042
0.220722 0.807741
0.089430 0.840491
0.278817 0.342210
0.520287 -0.950301
0.726976 0.852224
0.180485 1.141859
0.801524 1.012061
0.474273 -1.311226
0.345116 -0.319911
0.981951 -0.374203
0.127349 1.039361
0.757120 1.040152
0.345419 -0.429760
0.314532 -0.075762
0.250828 0.657169
0.431255 -0.905443
0.386669 -0.508875
0.143794 0.844105
0.470839 -0.951757
0.093065 0.785034
0.205377 0.715400
0.083329 0.853025
0.243475 0.699252
0.062389 0.567589
0.764116 0.834931
0.018287 0.199875
0.973603 -0.359748
0.458826 -1.113178
0.511200 -1.082561
0.712587 0.615108
0.464745 -0.835752
0.984328 -0.332495
0.414291 -0.808822
0.799551 1.072052
0.499037 -0.924499
0.966757 -0.191643
0.756594 0.991844
0.444938 -0.969528
0.410167 -0.773426
0.532335 -0.631770
0.343909 -0.313313
0.854302 0.719307
0.846882 0.916509
0.740758 1.009525
0.150668 0.832433
0.177606 0.893017
0.445289 -0.898242
0.734653 0.787282
0.559488 -0.663482
0.232311 0.499122
0.934435 -0.121533
0.219089 0.823206
0.636525 0.053113
0.307605 0.027500
0.713198 0.693978
0.116343 1.242458
0.680737 0.368910
0.484730 -0.891940
0.929408 0.234913
0.008507 0.103505
0.872161 0.816191
0.755530 0.985723
0.620671 0.026417
0.472260 -0.967451
0.257488 0.630100
0.130654 1.025693
0.512333 -0.884296
0.747710 0.849468
0.669948 0.413745
0.644856 0.253455
0.894206 0.482933
0.820471 0.899981
0.790796 0.922645
0.010729 0.032106
0.846777 0.768675
0.349175 -0.322929
0.453662 -0.957712
0.624017 -0.169913
0.211074 0.869840
0.062555 0.607180
0.739709 0.859793
0.985896 -0.433632
0.782088 0.976380
0.642561 0.147023
0.779007 0.913765
0.185631 1.021408
0.525250 -0.706217
0.236802 0.564723
0.440958 -0.993781
0.397580 -0.708189
0.823146 0.860086
0.370173 -0.649231
0.791675 1.162927
0.456647 -0.956843
0.113350 0.850107
0.351074 -0.306095
0.182684 0.825728
0.914034 0.305636
0.751486 0.898875
0.216572 0.974637
0.013273 0.062439
0.469726 -1.226188
0.060676 0.599451
0.776310 0.902315
0.061648 0.464446
0.714077 0.947507
0.559264 -0.715111
0.121876 0.791703
0.330586 -0.165819
0.662909 0.379236
0.785142 0.967030
0.161352 0.979553
0.985215 -0.317699
0.457734 -0.890725
0.171574 0.963749
0.334277 -0.266228
0.501065 -0.910313
0.988736 -0.476222
0.659242 0.218365
0.359861 -0.338734
0.790434 0.843387
0.462458 -0.911647
0.823012 0.813427
0.594668 -0.603016
0.498207 -0.878847
0.574882 -0.419598
0.570048 -0.442087
0.331570 -0.347567
0.195407 0.822284
0.814327 0.974355
0.641925 0.073217
0.238778 0.657767
0.400138 -0.715598
0.670479 0.469662
0.069076 0.680958
0.294373 0.145767
0.025628 0.179822
0.697772 0.506253
0.729626 0.786519
0.293071 0.259997
0.531802 -1.095833
0.487338 -1.034481
0.215780 0.933506
0.625818 0.103845
0.179389 0.892237
0.192552 0.915516
0.671661 0.330361
0.952391 -0.060263
0.795133 0.945157
0.950494 -0.071855
0.194894 1.000860
0.351460 -0.227946
0.863456 0.648456
0.945221 -0.045667
0.779840 0.979954
0.996606 -0.450501
0.632184 -0.036506
0.790898 0.994890
0.022503 0.386394
0.318983 -0.152749
0.369633 -0.423960
0.157300 0.962858
0.153223 0.882873
0.360068 -0.653742
0.433917 -0.872498
0.133461 0.879002
0.757252 1.123667
0.309391 -0.102064
0.195586 0.925339
0.240259 0.689117
0.340591 -0.455040
0.243436 0.415760
0.612755 -0.180844
0.089407 0.723702
0.469695 -0.987859
0.943560 -0.097303
0.177241 0.918082
0.317756 -0.222902
0.515337 -0.733668
0.344773 -0.256893
0.537029 -0.797272
0.626878 0.048719
0.208940 0.836531
0.470697 -1.080283
0.054448 0.624676
0.109230 0.816921
0.158325 1.044485
0.976650 -0.309060
0.643441 0.267336
0.215841 1.018817
0.905337 0.409871
0.154354 0.920009
0.947922 -0.112378
0.201391 0.768894
## 3. 实验结果(精确度不高,下次考虑更深的树)
![在这里插入图片描述](http://img.e-com-net.com/image/info8/c092e66c27644b4e94f8c5873c84f60c.jpg)