This time I'll walk through the J48 source code. Analyzing J48's source really does seem worthwhile: a classmate of mine modified J48 to implement VFDT, and I myself used J48 for feature selection (admittedly, with little success).
J48's buildClassifier function:
public void buildClassifier(Instances instances) throws Exception {
    ModelSelection modSelection;

    if (m_binarySplits)
        modSelection = new BinC45ModelSelection(m_minNumObj, instances);
    else
        modSelection = new C45ModelSelection(m_minNumObj, instances);
    if (!m_reducedErrorPruning)
        m_root = new C45PruneableClassifierTree(modSelection,
                !m_unpruned, m_CF, m_subtreeRaising, !m_noCleanup);
    else
        m_root = new PruneableClassifierTree(modSelection, !m_unpruned,
                m_numFolds, !m_noCleanup, m_Seed);
    m_root.buildClassifier(instances);
    if (m_binarySplits) {
        ((BinC45ModelSelection) modSelection).cleanup();
    } else {
        ((C45ModelSelection) modSelection).cleanup();
    }
}
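Before dissecting the internals, here is a minimal usage sketch showing where the member variables above come from. It assumes Weka 3.x on the classpath, and the ARFF path is only illustrative:

import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class J48Demo {
    public static void main(String[] args) throws Exception {
        // load a dataset; "iris.arff" is a placeholder path
        Instances data = DataSource.read("iris.arff");
        data.setClassIndex(data.numAttributes() - 1);

        J48 tree = new J48();
        tree.setBinarySplits(false);        // m_binarySplits: multiway nominal splits
        tree.setReducedErrorPruning(false); // m_reducedErrorPruning: use C4.5 pruning
        tree.setMinNumObj(2);               // m_minNumObj: min instances per leaf
        tree.setConfidenceFactor(0.25f);    // m_CF: pruning confidence

        tree.buildClassifier(data);         // enters the method listed above
        System.out.println(tree);
    }
}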
As already covered in the NBTree post, ModelSelection is the class that decides the model used at each node of the decision tree. Of the two if statements at the top, the first decides whether continuous (and nominal) attributes should always split into exactly two child nodes, and the second decides whether to use reduced-error pruning (PruneableClassifierTree) or standard C4.5 pruning (C45PruneableClassifierTree). m_root is a ClassifierTree object, and its buildClassifier function is called next. Here is that function:
public void buildClassifier(Instances data) throws Exception {
    // can classifier tree handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    buildTree(data, false);
}
With the comments in place there is not much to add here; let's go straight to the last call, buildTree:
public void buildTree(Instances data, boolean keepData) throws Exception {
    Instances[] localInstances;

    if (keepData) {
        m_train = data;
    }
    m_test = null;
    m_isLeaf = false;
    m_isEmpty = false;
    m_sons = null;
    m_localModel = m_toSelectModel.selectModel(data);
    if (m_localModel.numSubsets() > 1) {
        localInstances = m_localModel.split(data);
        data = null;
        m_sons = new ClassifierTree[m_localModel.numSubsets()];
        for (int i = 0; i < m_sons.length; i++) {
            m_sons[i] = getNewTree(localInstances[i]);
            localInstances[i] = null;
        }
    } else {
        m_isLeaf = true;
        if (Utils.eq(data.sumOfWeights(), 0))
            m_isEmpty = true;
        data = null;
    }
}
If you have read the NBTree post, the selectModel function here should look familiar. In short, selectModel returns a NoSplit if the conditions for splitting are not met; otherwise it selects the bestModel from the currentModel array and returns it.
The most important point to note is that selectModel does more than just decide which attribute to split on: exactly how to split has already been computed inside this function.
I will break selectModel apart and walk through it piece by piece.
// Check if all Instances belong to one class or if not
// enough Instances to split.
checkDistribution = new Distribution(data);
noSplitModel = new NoSplit(checkDistribution);
if (Utils.sm(checkDistribution.total(), 2 * m_minNoObj)
        || Utils.eq(checkDistribution.total(),
                checkDistribution.perClass(checkDistribution.maxClass())))
    return noSplitModel;
2 * m_minNoObj means a node can only split if it has at least that many instances. The reason is simple: a node must produce at least two child nodes, each with at least m_minNoObj instances. The second half of the || condition checks whether all instances at this node already belong to the same class, i.e. whether the node's total weight equals the weight of its most frequent class.
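To make the two tests concrete, here is a standalone sketch of the same logic in plain Java (my own illustration; Weka's Utils.sm and Utils.eq also apply a small numeric tolerance, omitted here). classWeights is a hypothetical array holding the total weight of each class at the node:

static boolean cannotSplit(double[] classWeights, double minNoObj) {
    double total = 0, max = 0;
    for (double w : classWeights) {
        total += w;
        max = Math.max(max, w);
    }
    // fewer than 2 * minNoObj instances, or the node is already pure
    return total < 2 * minNoObj || total == max;
}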
// Check if all attributes are nominal and have a lot of values.
if (m_allData != null) {
    Enumeration enu = data.enumerateAttributes();
    while (enu.hasMoreElements()) {
        attribute = (Attribute) enu.nextElement();
        if ((attribute.isNumeric())
                || (Utils.sm((double) attribute.numValues(),
                        (0.3 * (double) m_allData.numInstances())))) {
            multiVal = false;
            break;
        }
    }
}
This checks whether the data has many distinct attribute values. The criterion: if any attribute is numeric, or has fewer distinct values than 0.3 times the total number of instances, then multiVal is set to false (it starts out true).
currentModel = new C45Split[data.numAttributes()];
sumOfWeights = data.sumOfWeights();

// For each attribute.
for (i = 0; i < data.numAttributes(); i++) {

    // Apart from class attribute.
    if (i != (data).classIndex()) {

        // Get models for current attribute.
        currentModel[i] = new C45Split(i, m_minNoObj, sumOfWeights);
        currentModel[i].buildClassifier(data);

        // Check if useful split for current attribute
        // exists and check for enumerated attributes with
        // a lot of values.
        if (currentModel[i].checkModel())
            if (m_allData != null) {
                if ((data.attribute(i).isNumeric())
                        || (multiVal || Utils.sm(
                                (double) data.attribute(i).numValues(),
                                (0.3 * (double) m_allData.numInstances())))) {
                    averageInfoGain = averageInfoGain
                            + currentModel[i].infoGain();
                    validModels++;
                }
            } else {
                averageInfoGain = averageInfoGain
                        + currentModel[i].infoGain();
                validModels++;
            }
    } else
        currentModel[i] = null;
}
The two important lines inside are:
// Get models for current attribute.
currentModel[i] = new C45Split(i, m_minNoObj, sumOfWeights);
currentModel[i].buildClassifier(data);
The rest is straightforward: it accumulates averageInfoGain and counts validModels, which are used further down in selectModel when picking the best model. checkModel() is true when the attribute can actually produce child nodes.
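For context, the tail of selectModel then looks roughly like this (a fragment paraphrased from memory of C45ModelSelection, so treat it as a sketch rather than verbatim Weka code): only candidates whose information gain reaches roughly the average may compete, and among those the highest gain ratio wins.

// ... sketch of the end of selectModel, after the loop above ...
double avgInfoGain = averageInfoGain / validModels;
C45Split bestModel = null;
for (int j = 0; j < currentModel.length; j++) {
    if (j == data.classIndex() || currentModel[j] == null
            || !currentModel[j].checkModel())
        continue;
    // candidates need at least (about) average info gain; then the
    // highest gain ratio wins
    if (Utils.grOrEq(currentModel[j].infoGain(), avgInfoGain - 1E-3)
            && (bestModel == null
                    || currentModel[j].gainRatio() > bestModel.gainRatio()))
        bestModel = currentModel[j];
}
if (bestModel == null)
    return noSplitModel;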
Back to the two key lines above: what they invoke is C45Split's member function buildClassifier. Here is its code:
public void buildClassifier(Instances trainInstances) throws Exception {

    // Initialize the remaining instance variables.
    m_numSubsets = 0;
    m_splitPoint = Double.MAX_VALUE;
    m_infoGain = 0;
    m_gainRatio = 0;

    // Different treatment for enumerated and numeric
    // attributes.
    if (trainInstances.attribute(m_attIndex).isNominal()) {
        m_complexityIndex = trainInstances.attribute(m_attIndex).numValues();
        m_index = m_complexityIndex;
        handleEnumeratedAttribute(trainInstances);
    } else {
        m_complexityIndex = 2;
        m_index = 0;
        trainInstances.sort(trainInstances.attribute(m_attIndex));
        handleNumericAttribute(trainInstances);
    }
}
handleEnumeratedAttribute and handleNumericAttribute are the functions that work out, for the attribute m_attIndex (fixed in the constructor), whether and how it splits and how many child nodes result (m_numSubsets). m_complexityIndex is the number of child nodes a split can produce; for a continuous attribute it is always 2. Let's look at handleEnumeratedAttribute first:
private void handleEnumeratedAttribute(Instances trainInstances)
        throws Exception {
    Instance instance;

    m_distribution = new Distribution(m_complexityIndex,
            trainInstances.numClasses());

    // Only Instances with known values are relevant.
    Enumeration enu = trainInstances.enumerateInstances();
    while (enu.hasMoreElements()) {
        instance = (Instance) enu.nextElement();
        if (!instance.isMissing(m_attIndex))
            m_distribution.add((int) instance.value(m_attIndex), instance);
    }

    // Check if minimum number of Instances in at least two
    // subsets.
    if (m_distribution.check(m_minNoObj)) {
        m_numSubsets = m_complexityIndex;
        m_infoGain = infoGainCrit.splitCritValue(m_distribution,
                m_sumOfWeights);
        m_gainRatio = gainRatioCrit.splitCritValue(m_distribution,
                m_sumOfWeights, m_infoGain);
    }
}
Next is handleNumericAttribute, which I will also go through in pieces. Its first part:

// Current attribute is a numeric attribute.
m_distribution = new Distribution(2, trainInstances.numClasses());

// Only Instances with known values are relevant.
Enumeration enu = trainInstances.enumerateInstances();
i = 0;
while (enu.hasMoreElements()) {
    instance = (Instance) enu.nextElement();
    if (instance.isMissing(m_attIndex))
        break;
    m_distribution.add(1, instance);
    i++;
}
firstMiss = i;
As mentioned, a continuous attribute splits into two child nodes, which is the first argument to Distribution. The code then enumerates all instances; because the dataset was sorted on m_attIndex before handleNumericAttribute was called, instances with missing values all come last. So firstMiss ends up being the number of instances with a known value on m_attIndex, which is also the index of the first instance with a missing value. Inside the while loop, every instance is initially placed into bag 1 (add(1, instance)). Let me list that function too:
public final void add(int bagIndex, Instance instance) throws Exception {
    int classIndex;
    double weight;

    classIndex = (int) instance.classValue();
    weight = instance.weight();
    m_perClassPerBag[bagIndex][classIndex] =
            m_perClassPerBag[bagIndex][classIndex] + weight;
    m_perBag[bagIndex] = m_perBag[bagIndex] + weight;
    m_perClass[classIndex] = m_perClass[classIndex] + weight;
    totaL = totaL + weight;
}
In other words, given the bag index bagIndex and the instance's class value classIndex, this function adds the instance's weight to the three members m_perBag, m_perClass and m_perClassPerBag (and to the running total totaL).
// Compute minimum number of Instances required in each subset.
minSplit = 0.1 * (m_distribution.total())
        / ((double) trainInstances.numClasses());
if (Utils.smOrEq(minSplit, m_minNoObj))
    minSplit = m_minNoObj;
else if (Utils.gr(minSplit, 25))
    minSplit = 25;

// Enough Instances with known values?
if (Utils.sm((double) firstMiss, 2 * minSplit))
    return;
This computes the minimum number of instances required in each subset. The constants involved are not mentioned in Quinlan's papers and may be somewhat arbitrary: minSplit starts as one tenth of the total weight divided by the number of classes; if that is no more than m_minNoObj it becomes m_minNoObj, and if it exceeds 25 it is capped at 25.
If firstMiss is smaller than 2 * minSplit, the attribute cannot be split at all (for the reason explained earlier: each of the two subsets needs at least minSplit instances), so the function returns.
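A standalone version of the clamping, with a worked example: at total weight 1000 with 2 classes and m_minNoObj = 2, 0.1 * 1000 / 2 = 50, which gets capped to 25 (the helper below is my own illustration):

static double minSplit(double totalWeight, int numClasses, double minNoObj) {
    double minSplit = 0.1 * totalWeight / numClasses;
    if (minSplit <= minNoObj)
        return minNoObj;   // never below the per-leaf minimum
    if (minSplit > 25)
        return 25;         // hard cap
    return minSplit;
}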
// Compute values of criteria for all possible split indices.
defaultEnt = infoGainCrit.oldEnt(m_distribution);
while (next < firstMiss) {
    if (trainInstances.instance(next - 1).value(m_attIndex)
            + 1e-5 < trainInstances.instance(next).value(m_attIndex)) {

        // Move class values for all Instances up to next
        // possible split point.
        m_distribution.shiftRange(1, 0, trainInstances, last, next);

        // Check if enough Instances in each subset and compute
        // values for criteria.
        if (Utils.grOrEq(m_distribution.perBag(0), minSplit)
                && Utils.grOrEq(m_distribution.perBag(1), minSplit)) {
            currentInfoGain = infoGainCrit.splitCritValue(
                    m_distribution, m_sumOfWeights, defaultEnt);
            if (Utils.gr(currentInfoGain, m_infoGain)) {
                m_infoGain = currentInfoGain;
                splitIndex = next - 1;
            }
            m_index++;
        }
        last = next;
    }
    next++;
}
oldEnt computes the entropy before splitting, giving defaultEnt. Remember that all instances were just placed into bag 1. The code then loops over all instances with known values. The first if skips adjacent values that lie within 1e-5 of each other: a split point between two nearly identical values would make no practical difference. shiftRange moves the instances with indices last through next - 1 from bag 1 to bag 0. Here is shiftRange:
public final void shiftRange(int from, int to, Instances source,
        int startIndex, int lastPlusOne) throws Exception {
    int classIndex;
    double weight;
    Instance instance;
    int i;

    for (i = startIndex; i < lastPlusOne; i++) {
        instance = (Instance) source.instance(i);
        classIndex = (int) instance.classValue();
        weight = instance.weight();
        m_perClassPerBag[from][classIndex] -= weight;
        m_perClassPerBag[to][classIndex] += weight;
        m_perBag[from] -= weight;
        m_perBag[to] += weight;
    }
}
Quite simple: it subtracts each affected instance's weight from the from bag and adds it to the to bag.
Back in the loop: if bag 0 and bag 1 both meet the minimum split size, the information gain at the current split point is computed. If it beats the best split point found so far, the current gain is recorded as the best gain m_infoGain and the current position as splitIndex.
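For intuition, here is a simplified standalone version of the same sweep (my own sketch, not Weka code): values holds the sorted attribute values, labels the corresponding class indices, and every instance weight is taken as 1.

static double bestSplitPoint(double[] values, int[] labels, int numClasses) {
    int n = values.length;
    int[] left = new int[numClasses], right = new int[numClasses];
    for (int c : labels)
        right[c]++;                       // start with everything in the right bag
    double baseEnt = entropy(right, n);   // entropy before splitting
    double bestGain = 0, bestPoint = Double.MAX_VALUE;
    for (int i = 1; i < n; i++) {
        // move instance i-1 from the right bag to the left bag
        left[labels[i - 1]]++;
        right[labels[i - 1]]--;
        if (values[i - 1] + 1e-5 < values[i]) { // only between distinct values
            double gain = baseEnt
                    - (double) i / n * entropy(left, i)
                    - (double) (n - i) / n * entropy(right, n - i);
            if (gain > bestGain) {
                bestGain = gain;
                bestPoint = (values[i - 1] + values[i]) / 2;
            }
        }
    }
    return bestPoint;
}

static double entropy(int[] counts, int total) {
    double ent = 0;
    for (int c : counts)
        if (c > 0)
            ent -= (double) c / total * (Math.log((double) c / total) / Math.log(2));
    return ent;
}

Back to the Weka code: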
// Was there any useful split?
if (m_index == 0)
    return;

// Compute modified information gain for best split.
m_infoGain = m_infoGain - (Utils.log2(m_index) / m_sumOfWeights);
if (Utils.smOrEq(m_infoGain, 0))
    return;

// Set instance variables' values to values for best split.
m_numSubsets = 2;
m_splitPoint = (trainInstances.instance(splitIndex + 1).value(m_attIndex)
        + trainInstances.instance(splitIndex).value(m_attIndex)) / 2;
If no usable split point was found (m_index == 0), return. The correction applied to m_infoGain, subtracting log2(m_index) / m_sumOfWeights, is explained in the second paragraph of page 4 of J.R. Quinlan's paper Improved Use of Continuous Attributes in C4.5: it penalizes the gain by the cost of having chosen among m_index candidate split points. Finally the number of subsets is set to 2, and the split point is placed at the midpoint between the best split value found and the next attribute value.
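As a worked illustration of that correction: with m_index = 8 evaluated split points and total weight 100, the penalty is log2(8) / 100 = 0.03 bits subtracted from the raw gain (the helper below is my own sketch):

static double correctedGain(double rawGain, int numSplitPoints, double sumOfWeights) {
    // subtract the cost of choosing among the candidate thresholds
    return rawGain - (Math.log(numSplitPoints) / Math.log(2)) / sumOfWeights;
}

Continuing with the Weka code: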
// In case we have a numerical precision problem we need to choose the
// smaller value
if (m_splitPoint == trainInstances.instance(splitIndex + 1).value(m_attIndex)) {
    m_splitPoint = trainInstances.instance(splitIndex).value(m_attIndex);
}

// Restore distribution for best split.
m_distribution = new Distribution(2, trainInstances.numClasses());
m_distribution.addRange(0, trainInstances, 0, splitIndex + 1);
m_distribution.addRange(1, trainInstances, splitIndex + 1, firstMiss);

// Compute modified gain ratio for best split.
m_gainRatio = gainRatioCrit.splitCritValue(m_distribution,
        m_sumOfWeights, m_infoGain);
The if handles a numerical precision detail. Then m_distribution is recomputed for the best split via addRange, and finally the gain ratio is computed.
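For reference, the gain ratio follows the standard C4.5 definition. Here is a sketch of it (my own, ignoring Weka's extra handling of missing values), where perBag holds the weight falling into each subset:

static double gainRatio(double infoGain, double[] perBag) {
    double total = 0;
    for (double w : perBag)
        total += w;
    double splitInfo = 0; // -sum_i p_i * log2(p_i), the "split information"
    for (double w : perBag)
        if (w > 0)
            splitInfo -= w / total * (Math.log(w / total) / Math.log(2));
    return splitInfo > 0 ? infoGain / splitInfo : 0;
}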
Here we have met another class, Distribution, which deserves a short explanation. Distribution is organized around "bags": the number of bags is the number of possible child nodes. This can be seen from the constructor below; when it was called above, the first argument passed was m_complexityIndex.
public Distribution(int numBags, int numClasses) {
    int i;

    m_perClassPerBag = new double[numBags][0];
    m_perBag = new double[numBags];
    m_perClass = new double[numClasses];
    for (i = 0; i < numBags; i++)
        m_perClassPerBag[i] = new double[numClasses];
    totaL = 0;
}
Distribution's add function, shown earlier, simply accumulates statistics in the slots for the given bag and class; it is simple enough to skip.
Returning to the buildTree function from before: if numSubsets() returns 1, the current node stops splitting and becomes a leaf. If it is greater than 1, split() is called; split() simply distributes the node's instances among the child nodes, using the number of subsets determined earlier and the return value of whichSubset for each instance. A new subtree is then trained on each child. By this point, things look very similar to the ID3 implementation discussed previously.
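For completeness, C45Split.whichSubset behaves roughly as follows (a paraphrase from memory of the Weka source, not verbatim): a nominal value maps directly to a child index, a numeric value goes to the left or right of m_splitPoint, and a missing value yields -1, in which case the instance is distributed across all children in proportion to their weights.

public final int whichSubset(Instance instance) throws Exception {
    if (instance.isMissing(m_attIndex))
        return -1;                                   // split by weight instead
    if (instance.attribute(m_attIndex).isNominal())
        return (int) instance.value(m_attIndex);     // one child per value
    return Utils.smOrEq(instance.value(m_attIndex), m_splitPoint) ? 0 : 1;
}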
You may mostly be interested in the theory and feel puzzled after reading this far without seeing any of it. It is actually easy to locate: naturally it sits under handleEnumeratedAttribute and handleNumericAttribute, in the two classes InfoGainSplitCrit and GainRatioSplitCrit.
Classifying an instance works much as in NBTree, so I will not repeat it here.