博主最近迷上了打怪物猎人,这片文章拖了很久才开始动笔
一、算法
AdditiveRegression,换个更出名一点的叫法可以称作GBDT(Grandient Boosting Decision Tree)梯度下降分类树,或者GBRT(Grandient Boosting Regression Tree)梯度下降回归树,是一种多分类器组合的算法,更确切的说,是属于Boosting算法。
谈到Boosting算法,就不能不提AdaBoost,参见之前我写的博客,可以看到AdaBoost的核心是级联分类器,使后一级分类器更加“关注”较为容易分错的数据,即后一级的分类器更有在易出错的数据集上进行训练。。
而GBDT作为Boosting算法,也是将多分类器进行级联训练,后一级的分类器则更多关注前面所有分类器预测结果与实际结果的残差,在这个残差上训练新的分类器,最终预测时将残差级联相加。
关于GBDT相关算法的公式推导可参考:
http://en.wikipedia.org/wiki/Gradient_boosting#Gradient_tree_boosting
http://www.360doc.com/content/12/0428/15/5874309_207282768.shtml
扯了这么多,下面简单说一下算法训练流程。
(1)输入训练集Data和基分类器的数量N
(2)使用训练集Data训练第1个基分类器
(3)for (int i=2;i<N;i++)
(4)使用前i-1个分类器进行预测,计算预测结果和训练数据的残差
(5)如果残差小于某个阈值,则退出循环。
(5)使用此残差训练第i个分类器
(6)转(3)
预测流程:
(1)根据输入数据,计算N个分类器的预测结果。
(2)将预测结果相加并返回。
可以看到,GBDT从原理上来讲并不复杂,“残差”的概念就用梯度来进行标示,抓住这一个线索看懂Wiki中的推导公式也并不是难事。复杂的是“如何证明其有效性”,这远超过本文可论证的范畴。
二、源码实现
就像之前所有的分类器一样,依然从buildClassifier入手。
(1)buildClassifier
public void buildClassifier(Instances data) throws Exception { super.buildClassifier(data); //additiveRegerssion只支持数值型数据。 getCapabilities().testWithFail(data); //如果训练数据的class列为空,则去掉 Instances newData = new Instances(data); newData.deleteWithMissingClass(); double sum = 0; double temp_sum = 0; // 第一个分类器使用ZeroR,也就是预测的值是训练值的均值,没有使用基分类器(默认的基分类器为weka.classifiers.trees.DecisionStump(),也就是单层决策树(决策桩) m_zeroR = new ZeroR(); m_zeroR.buildClassifier(newData); // 如果只有一列,则没法训练 if (newData.numAttributes() == 1) { System.err.println( "Cannot build model (only class attribute present in data!), " + "using ZeroR model instead!"); m_SuitableData = false; return; } else { m_SuitableData = true; } //这个residualReplace函数会将数据集用某个分类器进行分类后,再将其class列替换为残差,这个稍后详细分析一下。 newData = residualReplace(newData, m_zeroR, false); for (int i = 0; i < newData.numInstances(); i++) { sum += newData.instance(i).weight() * newData.instance(i).classValue() * newData.instance(i).classValue();//这里计算了加权的残差平方和 } if (m_Debug) { System.err.println("Sum of squared residuals " +"(predicting the mean) : " + sum); } m_NumIterationsPerformed = 0; do { temp_sum = sum; // Build the classifier m_Classifiers[m_NumIterationsPerformed].buildClassifier(newData);//在新的数据集上训练,注意新的数据集的class已经替换为残差了,体现了gradient boosting思想 newData = residualReplace(newData, m_Classifiers[m_NumIterationsPerformed], true);//再重新替换为残差 sum = 0; for (int i = 0; i < newData.numInstances(); i++) { sum += newData.instance(i).weight() * newData.instance(i).classValue() * newData.instance(i).classValue();//重新计算残差平方和 } if (m_Debug) { System.err.println("Sum of squared residuals : "+sum); } m_NumIterationsPerformed++; } while (((temp_sum - sum) > Utils.SMALL) && (m_NumIterationsPerformed < m_Classifiers.length));//退出条件有2个,第一个是两次迭代残差平方没有明显变化,第二个是已训练完所有分类器。 }
算法思想很简单,代码也很直观。
下面分析一下residualReplace函数。
(2)residualReplace
private Instances residualReplace(Instances data, Classifier c, boolean useShrinkage) throws Exception { double pred,residual; Instances newInst = new Instances(data); for (int i = 0; i < newInst.numInstances(); i++) { pred = c.classifyInstance(newInst.instance(i)); //进行预测 if (useShrinkage) { pred *= getShrinkage();//使用shrinkage来防止过拟合 } residual = newInst.instance(i).classValue() - pred;//算出残差 newInst.instance(i).setClassValue(residual);//原始数据的class用残差替换 } // System.err.print(newInst); return newInst; }
shrinkage(缩减)的思想认为,每次走一小步逐渐逼近结果的效果,要比每次迈一大步很快逼近结果的方式更容易避免过拟合。即它不完全信任每一个棵残差树,它认为每棵树只学到了真理的一小部分,累加的时候只累加一小部分,通过多学几棵树弥补不足。(转自http://blog.csdn.net/w28971023/article/details/8240756)
可以看到,残差本身可以理解成“希望分类器结果前进的向量”,也就是梯度的含义,即包含了方向(分类器往哪个方向调整),也包含了长度(调整多少)。而shrinkage就是缩小这个长度到一定的比值,如10%,这样每次在这个向量方向上前进10%,以此来防止过拟合。
为什么shrinkage能防止过拟合?这又是一个看上去就复杂的不得了的问题啊。。。。
(3)classifyInstance
public double classifyInstance(Instance inst) throws Exception { double prediction = m_zeroR.classifyInstance(inst); if (!m_SuitableData) { return prediction; } for (int i = 0; i < m_NumIterationsPerformed; i++) { double toAdd = m_Classifiers[i].classifyInstance(inst); toAdd *= getShrinkage(); prediction += toAdd; } return prediction; }
四、总结
如果非要写个什么总结的话,那么我希望是以下几点:
(1)gbdt思想简单,实现起来也简单,效果非常理想。
(2)weka的additiveRegression是一个gbrt的简单实现,只能处理数值型数据。
(3)其实现的核心逻辑是用残差替换原有数据集的class列。
(4)可以选择性的使用shrinkage来防止过拟合。