Weka算法Classifier-tree-J48源码分析(三)ModelSelection


ModelSelection主要是用于选择合适的列对数据集进行分割,结合上一篇J48的主流程,发现用到的ModelSelection有 C45ModelSelection以及BinC45ModelSelection,先来分析C45ModelSelection。


一、C45ModelSelection

首先作为一个ModelSelection接口,实现的主要方法有两个,分别是selectModel(Instances)和selectionModel(Instances,Instances)。C45ModelSelection的后一个方法如下:

  public final ClassifierSplitModel selectModel(Instances train, Instances test) {

    return selectModel(train);
  }
可以看到就是忽略了test测试集直接调用selectModel方法而已,因此主要分词selectModel方法。

先放出整段代码,然后对该段代码进行分析:

public final ClassifierSplitModel selectModel(Instances data){

    double minResult;
    double currentResult;
    C45Split [] currentModel;
    C45Split bestModel = null;
    NoSplit noSplitModel = null;
    double averageInfoGain = 0;
    int validModels = 0;
    boolean multiVal = true;
    Distribution checkDistribution;
    Attribute attribute;
    double sumOfWeights;
    int i;
    
    try{

      // Check if all Instances belong to one class or if not
      // enough Instances to split.
      checkDistribution = new Distribution(data);
      noSplitModel = new NoSplit(checkDistribution);
      if (Utils.sm(checkDistribution.total(),2*m_minNoObj) ||
	  Utils.eq(checkDistribution.total(),
		   checkDistribution.perClass(checkDistribution.maxClass())))
	return noSplitModel;

      // Check if all attributes are nominal and have a 
      // lot of values.
      if (m_allData != null) {
	Enumeration enu = data.enumerateAttributes();
	while (enu.hasMoreElements()) {
	  attribute = (Attribute) enu.nextElement();
	  if ((attribute.isNumeric()) ||
	      (Utils.sm((double)attribute.numValues(),
			(0.3*(double)m_allData.numInstances())))){
	    multiVal = false;
	    break;
	  }
	}
      } 

      currentModel = new C45Split[data.numAttributes()];
      sumOfWeights = data.sumOfWeights();

      // For each attribute.
      for (i = 0; i < data.numAttributes(); i++){
	
	// Apart from class attribute.
	if (i != (data).classIndex()){
	  
	  // Get models for current attribute.
	  currentModel[i] = new C45Split(i,m_minNoObj,sumOfWeights);
	  currentModel[i].buildClassifier(data);
	  
	  // Check if useful split for current attribute
	  // exists and check for enumerated attributes with 
	  // a lot of values.
	  if (currentModel[i].checkModel())
	    if (m_allData != null) {
	      if ((data.attribute(i).isNumeric()) ||
		  (multiVal || Utils.sm((double)data.attribute(i).numValues(),
					(0.3*(double)m_allData.numInstances())))){
		averageInfoGain = averageInfoGain+currentModel[i].infoGain();
		validModels++;
	      } 
	    } else {
	      averageInfoGain = averageInfoGain+currentModel[i].infoGain();
	      validModels++;
	    }
	}else
	  currentModel[i] = null;
      }
      
      // Check if any useful split was found.
      if (validModels == 0)
	return noSplitModel;
      averageInfoGain = averageInfoGain/(double)validModels;

      // Find "best" attribute to split on.
      minResult = 0;
      for (i=0;i<data.numAttributes();i++){
	if ((i != (data).classIndex()) &&
	    (currentModel[i].checkModel()))
	  
	  // Use 1E-3 here to get a closer approximation to the original
	  // implementation.
	  if ((currentModel[i].infoGain() >= (averageInfoGain-1E-3)) &&
	      Utils.gr(currentModel[i].gainRatio(),minResult)){ 
	    bestModel = currentModel[i];
	    minResult = currentModel[i].gainRatio();
	  } 
      }

      // Check if useful split was found.
      if (Utils.eq(minResult,0))
	return noSplitModel;
      
      // Add all Instances with unknown values for the corresponding
      // attribute to the distribution for the model, so that
      // the complete distribution is stored with the model. 
      bestModel.distribution().
	  addInstWithUnknown(data,bestModel.attIndex());
      
      // Set the split point analogue to C45 if attribute numeric.
      if (m_allData != null)
	bestModel.setSplitPoint(m_allData);
      return bestModel;
    }catch(Exception e){
      e.printStackTrace();
    }
    return null;
  }
第一部分,主要是对局部变量的一些定义。

    double minResult;//最小的信息增益率
    double currentResult;//当前信息增益率
    C45Split [] currentModel;//存放所有未分类属性产生的模型
    C45Split bestModel = null;//目前为止的最好模型
    NoSplit noSplitModel = null;//代表不用分的模型
    double averageInfoGain = 0;//各模型(currentModel)的平均信息增益
    int validModels = 0;//是否存在有效模型
    boolean multiVal = true;//是否多值
    Distribution checkDistribution;//训练数据集的分布
    Attribute attribute;//属性列集合
    double sumOfWeights;//训练数据集的weight的和
    int i;//循环变量

第二部分,递归出口。

 checkDistribution = new Distribution(data);
      noSplitModel = new NoSplit(checkDistribution);
      if (Utils.sm(checkDistribution.total(),2*m_minNoObj) ||
	  Utils.eq(checkDistribution.total(),
		   checkDistribution.perClass(checkDistribution.maxClass())))
	return noSplitModel;
可以看到,如果当前数据集数量小于2*m_minNoObj(这个值默认是2),或者当前数据集已经全在同一个分类中,就返回noSplitModel代表不用分,这就是整个C45分类树节点停止分裂的条件。

第三部分,判断是否是多值:

      if (m_allData != null) {
	Enumeration enu = data.enumerateAttributes();
	while (enu.hasMoreElements()) {
	  attribute = (Attribute) enu.nextElement();
	  if ((attribute.isNumeric()) ||
	      (Utils.sm((double)attribute.numValues(),
			(0.3*(double)m_allData.numInstances())))){
	    multiVal = false;
	    break;
	  }
	}
      } 
如果属性中,任意一列是数值型,或者其取值的数量小于训练集数量*0.3,则不是多值,否则按多值处理。是否是多值影响到后面某些逻辑。

第四部分,对于每一列属性构造Spliter。

    for (i = 0; i < data.numAttributes(); i++){
	
	// Apart from class attribute.
	if (i != (data).classIndex()){
	  
	  // Get models for current attribute.
	  currentModel[i] = new C45Split(i,m_minNoObj,sumOfWeights);
	  currentModel[i].buildClassifier(data);
	  
	  // Check if useful split for current attribute
	  // exists and check for enumerated attributes with 
	  // a lot of values.
	  if (currentModel[i].checkModel())
	    if (m_allData != null) {
	      if ((data.attribute(i).isNumeric()) ||
		  (multiVal || Utils.sm((double)data.attribute(i).numValues(),
					(0.3*(double)m_allData.numInstances())))){
		averageInfoGain = averageInfoGain+currentModel[i].infoGain();
		validModels++;
	      } 
	    } else {
	      averageInfoGain = averageInfoGain+currentModel[i].infoGain();
	      validModels++;
	    }
	}else
	  currentModel[i] = null;
      }

对于每一列属性,如果不是存放分类的值得话,则构造C45Split对象,在该对象上进行分类,然后算出信息增益,相加到averageInfoGain上。对于C45Split的构造,稍后再看。

第五部分,选出最优模型。

 if (validModels == 0)
	return noSplitModel;
      averageInfoGain = averageInfoGain/(double)validModels;

      // Find "best" attribute to split on.
      minResult = 0;
      for (i=0;i<data.numAttributes();i++){
	if ((i != (data).classIndex()) &&
	    (currentModel[i].checkModel()))
	  
	  // Use 1E-3 here to get a closer approximation to the original
	  // implementation.
	  if ((currentModel[i].infoGain() >= (averageInfoGain-1E-3)) &&
	      Utils.gr(currentModel[i].gainRatio(),minResult)){ 
	    bestModel = currentModel[i];
	    minResult = currentModel[i].gainRatio();
	  } 

如果存在有效模型,则选出有效模型。注意这个选出最优模型的逻辑,并不是单纯的选出gainRatio最大的,而是在基础上必须还要大于平均信息增益,这也是和传统的c45算法不一样的一点。

从上述过程来看,Weka在实现C45的时候做了一个小的变动,并没有从“还没有使用的”属性列中找出最合理的列最为分割属性,而是在“所有的列”中找出最合理的列作为分割属性,虽然这二者在结果上肯定是等价的(之前是有过的属性不和能有很好的信息增益率),但效率上个人对Weka的做法持保留意见。


二、C45Spliter

在ModelSelection中真正根据属性对训练集进行分割、计算信息增益和信息增益率的是C45Spliter,首先也从其buildClassifier方法入手进行分析。

public void buildClassifier(Instances trainInstances) 
       throws Exception {

    // Initialize the remaining instance variables.
    m_numSubsets = 0;
    m_splitPoint = Double.MAX_VALUE;
    m_infoGain = 0;
    m_gainRatio = 0;

    // Different treatment for enumerated and numeric
    // attributes.
    if (trainInstances.attribute(m_attIndex).isNominal()) {
      m_complexityIndex = trainInstances.attribute(m_attIndex).numValues();
      m_index = m_complexityIndex;
      handleEnumeratedAttribute(trainInstances);
    }else{
      m_complexityIndex = 2;
      m_index = 0;
      trainInstances.sort(trainInstances.attribute(m_attIndex));
      handleNumericAttribute(trainInstances);
    }
  }    
可以看到,对于枚举型和数值型的属性是分开处理的,枚举型调用handlEnumeratedAttribute,数值型调用handleNumericAttribute,值得注意的是,在处理数值型之前,按照相应列进行排序,同时设置m_complexityIndex也就是期望分裂的节点数设定为2。

首先来看枚举类型是如何处理的。

private void handleEnumeratedAttribute(Instances trainInstances)
       throws Exception {
    
    Instance instance;

    m_distribution = new Distribution(m_complexityIndex,
			      trainInstances.numClasses());
    
    // Only Instances with known values are relevant.
    Enumeration enu = trainInstances.enumerateInstances();
    while (enu.hasMoreElements()) {
      instance = (Instance) enu.nextElement();
      if (!instance.isMissing(m_attIndex))
	m_distribution.add((int)instance.value(m_attIndex),instance);
    }
    
    // Check if minimum number of Instances in at least two
    // subsets.
    if (m_distribution.check(m_minNoObj)) {
      m_numSubsets = m_complexityIndex;
      m_infoGain = infoGainCrit.
	splitCritValue(m_distribution,m_sumOfWeights);
      m_gainRatio = 
	gainRatioCrit.splitCritValue(m_distribution,m_sumOfWeights,
				     m_infoGain);
    }
  }
大概流程是新建一个分布,遍历所有instance,如果该instance对应的分裂的属性不为空的话,则放到不同的bag里,之后检查一下这个分布是否满足要求,要求就是最多允许有一个bag里的数据数量小于m_minNoObj,如果通过检查,就设置subset的数量,计算信息增益和信息增益率,否则subset默认会是0,上层调用checkModel就会返回false代表这是一个无效模型。

接下来看数值型是如何处理的:

 private void handleNumericAttribute(Instances trainInstances)
       throws Exception {
  
    int firstMiss;//最后一个有效instance的下标
    int next = 1;//下一个instance的index
    int last = 0;//当前instance的index
    int splitIndex = -1;//分裂点
    double currentInfoGain;//当前信息增益
    double defaultEnt;//分割之前的信息熵
    double minSplit;
    Instance instance;
    int i;

//首先新建一个分布,数值型默认处理为2维分布,也就可以理解为小于某个值放到一个Bag里,其余的放到另外一个Bag里
    m_distribution = new Distribution(2,trainInstances.numClasses());
    Enumeration enu = trainInstances.enumerateInstances();
    i = 0;
<pre name="code" class="cpp">//注意instances传入的时候是排好序的,这个排序保证了missingValue放在最后面,所以读到了missingValue其之后肯定都是miss//ingValue,换言之,firstMiss在循环之后代表了最后一个有效的instance的下标。
while (enu.hasMoreElements()) { instance = (Instance) enu.nextElement(); if (instance.isMissing(m_attIndex))break; m_distribution.add(1,instance); i++; } firstMiss = i;//循环结束后,m_distribution里放入了所有的有效instance,并全放入了bag1里。

 
 
//minSplit是最后分类好每个Bag里最小的数据的量,也就是0.1*每个类的均值。
    minSplit =  0.1*(m_distribution.total())/
      ((double)trainInstances.numClasses());
    if (Utils.smOrEq(minSplit,m_minNoObj)) 
      minSplit = m_minNoObj;
    else
      if (Utils.gr(minSplit,25)) 
	minSplit = 25;
	
//如果有效数据总量不到2*minSplit,换言之无论怎么分均不能保证2个bag里的数量大于minSplit,就直接返回。
    if (Utils.sm((double)firstMiss,2*minSplit))
      return;
    
//defaultEnt代表旧的信息熵,也就是对该属性进行分类之前,Indexclass对应的信息熵。
    defaultEnt = infoGainCrit.oldEnt(m_distribution);
    while (next < firstMiss) {
	  
      if (trainInstances.instance(next-1).value(m_attIndex)+1e-5 < 
	  trainInstances.instance(next).value(m_attIndex)) { 
	<pre name="code" class="cpp">//Instances里的记录是升序排列的,加上这个条件默认把值相差很小的Instance就当做同一个instance处理了
//last代表当前,next代表下一个,默认next=1,last=0,所以shiftRange可以理解成把当前记录从bag1移动到bag0中
<span style="font-family: Arial, Helvetica, sans-serif;">//注意一开始初始化时候所有的都是在bag1里面的。	</span>
m_distribution.shiftRange(1,0,trainInstances,last,next);if (Utils.grOrEq(m_distribution.perBag(0),minSplit) && //如果两个bag都满足最小数据集的数量minSplit Utils.grOrEq(m_distribution.perBag(1),minSplit)) { currentInfoGain = infoGainCrit. splitCritValue(m_distribution,m_sumOfWeights, //算一下信息增益 defaultEnt);
 
 
	  if (Utils.gr(currentInfoGain,m_infoGain)) {
	    m_infoGain = currentInfoGain;//如果信息增益比当前最大的要大,则替换当前最大的值,并记录splitIndex
	    splitIndex = next-1;
	  }
	  m_index++;
	}
	last = next;
      }
      next++;
    }
    
    if (m_index == 0)
      return; //执行到这里说明没找到一个合适的分裂点,直接返回。
    
    // 计算最佳信息增益
    m_infoGain = m_infoGain-(Utils.log2(m_index)/m_sumOfWeights);
    if (Utils.smOrEq(m_infoGain,0))
      return; //如果信息增益是0也说明没找到合适的分裂点,直接返回。
    
    //剩下的就是根据分裂点进行属性的划分。
    m_numSubsets = 2;
    m_splitPoint = 
      (trainInstances.instance(splitIndex+1).value(m_attIndex)+
       trainInstances.instance(splitIndex).value(m_attIndex))/2;

    // In case we have a numerical precision problem we need to choose the
    // smaller value
    if (m_splitPoint == trainInstances.instance(splitIndex + 1).value(m_attIndex)) {
      m_splitPoint = trainInstances.instance(splitIndex).value(m_attIndex);
    }

    // Restore distributioN for best split.
    m_distribution = new Distribution(2,trainInstances.numClasses());
    m_distribution.addRange(0,trainInstances,0,splitIndex+1);
    m_distribution.addRange(1,trainInstances,splitIndex+1,firstMiss);

    // Compute modified gain ratio for best split.
    m_gainRatio = gainRatioCrit.
      splitCritValue(m_distribution,m_sumOfWeights,
		     m_infoGain);
  }
这个函数有点复杂,具体逻辑也写到代码注释里了。


三、BinC45ModelSelection

该函数只负责生成二元分类树的模型,selectModel方法和C45ModelSelection几乎一样,不在多说,不同点在于其使用BinC45Spliter而不是C45Spliter。


四、BinC45Spliter

 handleNumericAttribute对于数值类型的属性处理和C45Spliter完全一样。下面只分析一下handleEnumeratedAttribute。

 private void handleEnumeratedAttribute(Instances trainInstances)
       throws Exception {
    
    Distribution newDistribution,secondDistribution;
    int numAttValues;
    double currIG,currGR;
    Instance instance;
    int i;

    numAttValues = trainInstances.attribute(m_attIndex).numValues();
    newDistribution = new Distribution(numAttValues,
				       trainInstances.numClasses());
    
    // Only Instances with known values are relevant.
    Enumeration enu = trainInstances.enumerateInstances();
    while (enu.hasMoreElements()) {
      instance = (Instance) enu.nextElement();
      if (!instance.isMissing(m_attIndex))
	newDistribution.add((int)instance.value(m_attIndex),instance);
    }
    m_distribution = newDistribution;

    // For all values
    for (i = 0; i < numAttValues; i++){

      if (Utils.grOrEq(newDistribution.perBag(i),m_minNoObj)){
	secondDistribution = new Distribution(newDistribution,i);
	
	// Check if minimum number of Instances in the two
	// subsets.
	if (secondDistribution.check(m_minNoObj)){
	  m_numSubsets = 2;
	  currIG = m_infoGainCrit.splitCritValue(secondDistribution,
					       m_sumOfWeights);
	  currGR = m_gainRatioCrit.splitCritValue(secondDistribution,
						m_sumOfWeights,
						currIG);
	  if ((i == 0) || Utils.gr(currGR,m_gainRatio)){
	    m_gainRatio = currGR;
	    m_infoGain = currIG;
	    m_splitPoint = (double)i;
	    m_distribution = secondDistribution;
	  }
	}
      }
    }
可以看出,上一段代码根据该属性的不同的取值,在已有分布基础上,建立一个新的分布secondeDistribution,
secondDistribution = new Distribution(newDistribution,i);
该分布包含两列,属性下标为i的,其余的,在这个分布的基础上计算信息增益和信息增益率,并选出最优的。

换句话说,离散值分类的二元化处理就是选出其中一列当做一个branch,其余的当做另外一个branch。虽然从结构上来讲这肯定不是最优的选择,但简单易用就够了。


到这里基本分析完了J48的两个ModelSelection,下一篇文章将对classifierInstance过程进行分析,并给出一个简单的总结。






你可能感兴趣的:(源码,算法,weka,C4.5,分类器)