weka进行十折交叉验证

十折交叉验证英文名叫做10-fold cross-validation,用来测试算法准确性。是常用的测试方法。将数据集分成十分,轮流将其中9份作为训练数据,1份作为测试数据,进行试验。每次试验都会得出相应的正确率(或差错率)。10次的结果的正确率(或差错率)的平均值作为对算法精度的估计,一般还需要进行多次10折交叉验证(例如10次10折交叉验证),再求其均值,作为对算法准确性的估计。之所以选择将数据集分为10份,是因为通过利用大量数据集、使用不同学习技术进行的大量试验,表明10折是获得最好误差估计的恰当选择,而且也有一些理论根据可以证明这一点。但这并非最终诊断,争议仍然存在。而且似乎5折或者20折与10折所得出的结果也相差无几。

以上来自百度百科http://baike.baidu.com/view/1213887.htm

 

其实做法很简单,就是将一份数据集首先进行随机打散,然后均分成10份,再把其中的每一份拿出来作为测试样本,其余的9份作为训练样本,这样就会得到10个训练样本和10个测试样本。

代码:

 

 
  
  1. Instances data = new Instances(instances);  
  2.         data.randomize(new Random());  
  3.         if (data.classAttribute().isNominal()) {  
  4.               data.stratify(numFolds);  
  5.           } 

首先创建了一份新的instances data,然后调用randomize进行数据集的打散,同时要判断下类属性是不是nominal类型的,想知道为什么要加这个判断,你可以去试试看如果违反了这个条件会有什么样的结果。stratify函数是用来分层的,numFolds就代表着几折,numfolds=10就意味着将数据分成10份。

 

 
  
  1. Instances train = data.trainCV(numFolds, i);  
  2.          Instances test = data.testCV(numFolds, i); 

trainCV方法用来获取训练样本

testCV方法用来获取测试样本

你只要对for i=0->10这么一个循环就能获取10个配对的训练集和测试集

 

下面的内容是来自weka官网的API文档,其中trainCV还能够传递一个random用来对获取的训练样本进行随机打散。

 

trainCV

public Instances trainCV(int numFolds,
                         int numFold)
Creates the training set for one fold of a cross-validation on the dataset.

 

Parameters:
numFolds - the number of folds in the cross-validation. Must be greater than 1.
numFold - 0 for the first fold, 1 for the second, ...
Returns:
the training set
Throws:
java.lang.IllegalArgumentException - if the number of folds is less than 2 or greater than the number of instances.

trainCV

public Instances trainCV(int numFolds,
                         int numFold,
                         java.util.Random random)
Creates the training set for one fold of a cross-validation on the dataset. The data is subsequently randomized based on the given random number generator.

 

Parameters:
numFolds - the number of folds in the cross-validation. Must be greater than 1.
numFold - 0 for the first fold, 1 for the second, ...
random - the random number generator
Returns:
the training set
Throws:
java.lang.IllegalArgumentException - if the number of folds is less than 2 or greater than the number of instances.

testCV

public Instances testCV(int numFolds,
                        int numFold)
Creates the test set for one fold of a cross-validation on the dataset.

 

Parameters:
numFolds - the number of folds in the cross-validation. Must be greater than 1.
numFold - 0 for the first fold, 1 for the second, ...
Returns:
the test set as a set of weighted instances
Throws:
java.lang.IllegalArgumentException - if the number of folds is less than 2 or greater than the number of instances.

本文出自 “Never Stop Sharing” 博客,请务必保留此出处http://loma1990.blog.51cto.com/6082839/1060949

你可能感兴趣的:(Data,Mining)