水平有限,随便翻译,能看就行
原址:http://weka.wikispaces.com/Generating+cross-validation+folds+%28Java+approach%29
文章主要讲述了如何直接使用 weka API 产生交叉验证的训练/测试划分(train/test splits)
使用如下变量:
Instances data = ...; // contains the full dataset we wann create train/test sets from int seed = ...; // the seed for randomizing the data int folds = ...; // the number of folds to generate, >=2
首先随机化你的数据
// Randomize (and, for nominal classes, stratify) a copy of the data.
Random rand = new Random(seed);              // create seeded number generator
Instances randData = new Instances(data);    // create copy of original data (was missing the "Instances" declaration)
randData.randomize(rand);                    // randomize data with number generator
// distribute the class values evenly across the folds
// (only meaningful for a nominal class attribute — see the full listings below)
randData.stratify(folds);
产生折叠
单步运行
接下来我们要做的事情就是创建训练和测试集
// Create the train/test pair for every fold.
for (int n = 0; n < folds; n++) {
  Instances train = randData.trainCV(folds, n);
  Instances test = randData.testCV(folds, n);
  // further processing, classification, etc.
  ...
}
// Variant taking a Random — this is the call used by the Explorer/Experimenter:
Instances train = randData.trainCV(folds, n, rand);
上面的例子演示了单步运行交叉验证,如果你想运行10次10折交叉验证使用下面的循环:
Instances data = ...;  // our dataset again, obtained from somewhere
int runs = 10;
for (int i = 0; i < runs; i++) {
  seed = i+1;  // every run gets a new, but defined seed value
  // see above: randomize the data
  ...
  // see above: generate the folds
  ...
}
CrossValidationSingleRun.java
import weka.core.Instances; import weka.core.converters.ConverterUtils.DataSource; import weka.core.Utils; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import java.util.Random; /** * Performs a single run of cross-validation. * * Command-line parameters: * <ul> * <li>-t filename - the dataset to use</li> * <li>-x int - the number of folds to use</li> * <li>-s int - the seed for the random number generator</li> * <li>-c int - the class index, "first" and "last" are accepted as well; * "last" is used by default</li> * <li>-W classifier - classname and options, enclosed by double quotes; * the classifier to cross-validate</li> * </ul> * * Example command-line: * <pre> * java CrossValidationSingleRun -t anneal.arff -c last -x 10 -s 1 -W "weka.classifiers.trees.J48 -C 0.25" * </pre> * * @author FracPete (fracpete at waikato dot ac dot nz) */ public class CrossValidationSingleRun { /** * Performs the cross-validation. See Javadoc of class for information * on command-line parameters. 
* * @param args the command-line parameters * @throws Excecption if something goes wrong */ public static void main(String[] args) throws Exception { // loads data and set class index Instances data = DataSource.read(Utils.getOption("t", args)); String clsIndex = Utils.getOption("c", args); if (clsIndex.length() == 0) clsIndex = "last"; if (clsIndex.equals("first")) data.setClassIndex(0); else if (clsIndex.equals("last")) data.setClassIndex(data.numAttributes() - 1); else data.setClassIndex(Integer.parseInt(clsIndex) - 1); // classifier String[] tmpOptions; String classname; tmpOptions = Utils.splitOptions(Utils.getOption("W", args)); classname = tmpOptions[0]; tmpOptions[0] = ""; Classifier cls = (Classifier) Utils.forName(Classifier.class, classname, tmpOptions); // other options int seed = Integer.parseInt(Utils.getOption("s", args)); int folds = Integer.parseInt(Utils.getOption("x", args)); // randomize data Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) randData.stratify(folds); // perform cross-validation Evaluation eval = new Evaluation(randData); for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n); Instances test = randData.testCV(folds, n); // the above code is used by the StratifiedRemoveFolds filter, the // code below by the Explorer/Experimenter: // Instances train = randData.trainCV(folds, n, rand); // build and evaluate classifier Classifier clsCopy = Classifier.makeCopy(cls); clsCopy.buildClassifier(train); eval.evaluateModel(clsCopy, test); } // output evaluation System.out.println(); System.out.println("=== Setup ==="); System.out.println("Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions())); System.out.println("Dataset: " + data.relationName()); System.out.println("Folds: " + folds); System.out.println("Seed: " + seed); System.out.println(); System.out.println(eval.toSummaryString("=== " + 
folds + "-fold Cross-validation ===", false)); } }
import weka.core.Instances; import weka.core.converters.ConverterUtils.DataSource; import weka.core.Utils; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import java.util.Random; /** * Performs a single run of cross-validation. Outputs the Confusion matrices * for each single fold. * * Command-line parameters: * <ul> * <li>-t filename - the dataset to use</li> * <li>-x int - the number of folds to use</li> * <li>-s int - the seed for the random number generator</li> * <li>-c int - the class index, "first" and "last" are accepted as well; * "last" is used by default</li> * <li>-W classifier - classname and options, enclosed by double quotes; * the classifier to cross-validate</li> * </ul> * * Example command-line: * <pre> * java CrossValidationSingleRun -t anneal.arff -c last -x 10 -s 1 -W "weka.classifiers.trees.J48 -C 0.25" * </pre> * * @author FracPete (fracpete at waikato dot ac dot nz) */ public class CrossValidationSingleRunVariant { /** * Performs the cross-validation. See Javadoc of class for information * on command-line parameters. 
* * @param args the command-line parameters * @throws Excecption if something goes wrong */ public static void main(String[] args) throws Exception { // loads data and set class index Instances data = DataSource.read(Utils.getOption("t", args)); String clsIndex = Utils.getOption("c", args); if (clsIndex.length() == 0) clsIndex = "last"; if (clsIndex.equals("first")) data.setClassIndex(0); else if (clsIndex.equals("last")) data.setClassIndex(data.numAttributes() - 1); else data.setClassIndex(Integer.parseInt(clsIndex) - 1); // classifier String[] tmpOptions; String classname; tmpOptions = Utils.splitOptions(Utils.getOption("W", args)); classname = tmpOptions[0]; tmpOptions[0] = ""; Classifier cls = (Classifier) Utils.forName(Classifier.class, classname, tmpOptions); // other options int seed = Integer.parseInt(Utils.getOption("s", args)); int folds = Integer.parseInt(Utils.getOption("x", args)); // randomize data Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) randData.stratify(folds); // perform cross-validation System.out.println(); System.out.println("=== Setup ==="); System.out.println("Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions())); System.out.println("Dataset: " + data.relationName()); System.out.println("Folds: " + folds); System.out.println("Seed: " + seed); System.out.println(); Evaluation evalAll = new Evaluation(randData); for (int n = 0; n < folds; n++) { Evaluation eval = new Evaluation(randData); Instances train = randData.trainCV(folds, n); Instances test = randData.testCV(folds, n); // the above code is used by the StratifiedRemoveFolds filter, the // code below by the Explorer/Experimenter: // Instances train = randData.trainCV(folds, n, rand); // build and evaluate classifier Classifier clsCopy = Classifier.makeCopy(cls); clsCopy.buildClassifier(train); eval.evaluateModel(clsCopy, test); 
evalAll.evaluateModel(clsCopy, test); // output evaluation System.out.println(); System.out.println(eval.toMatrixString("=== Confusion matrix for fold " + (n+1) + "/" + folds + " ===\n")); } // output evaluation System.out.println(); System.out.println(evalAll.toSummaryString("=== " + folds + "-fold Cross-validation ===", false)); } }
import weka.core.Instances; import weka.core.converters.ConverterUtils.DataSource; import weka.core.Utils; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import java.util.Random; /** * Performs multiple runs of cross-validation. * * Command-line parameters: * <ul> * <li>-t filename - the dataset to use</li> * <li>-x int - the number of folds to use</li> * <li>-r int - the number of runs to perform</li> * <li>-c int - the class index, "first" and "last" are accepted as well; * "last" is used by default</li> * <li>-W classifier - classname and options, enclosed by double quotes; * the classifier to cross-validate</li> * </ul> * * Example command-line: * <pre> * java CrossValidationMultipleRuns -t labor.arff -c last -x 10 -r 10 -W "weka.classifiers.trees.J48 -C 0.25" * </pre> * * @author FracPete (fracpete at waikato dot ac dot nz) */ public class CrossValidationMultipleRuns { /** * Performs the cross-validation. See Javadoc of class for information * on command-line parameters. 
* * @param args the command-line parameters * @throws Excecption if something goes wrong */ public static void main(String[] args) throws Exception { // loads data and set class index Instances data = DataSource.read(Utils.getOption("t", args)); String clsIndex = Utils.getOption("c", args); if (clsIndex.length() == 0) clsIndex = "last"; if (clsIndex.equals("first")) data.setClassIndex(0); else if (clsIndex.equals("last")) data.setClassIndex(data.numAttributes() - 1); else data.setClassIndex(Integer.parseInt(clsIndex) - 1); // classifier String[] tmpOptions; String classname; tmpOptions = Utils.splitOptions(Utils.getOption("W", args)); classname = tmpOptions[0]; tmpOptions[0] = ""; Classifier cls = (Classifier) Utils.forName(Classifier.class, classname, tmpOptions); // other options int runs = Integer.parseInt(Utils.getOption("r", args)); int folds = Integer.parseInt(Utils.getOption("x", args)); // perform cross-validation for (int i = 0; i < runs; i++) { // randomize data int seed = i + 1; Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) randData.stratify(folds); Evaluation eval = new Evaluation(randData); for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n); Instances test = randData.testCV(folds, n); // the above code is used by the StratifiedRemoveFolds filter, the // code below by the Explorer/Experimenter: // Instances train = randData.trainCV(folds, n, rand); // build and evaluate classifier Classifier clsCopy = Classifier.makeCopy(cls); clsCopy.buildClassifier(train); eval.evaluateModel(clsCopy, test); } // output evaluation System.out.println(); System.out.println("=== Setup run " + (i+1) + " ==="); System.out.println("Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions())); System.out.println("Dataset: " + data.relationName()); System.out.println("Folds: " + folds); System.out.println("Seed: " + seed); 
System.out.println(); System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation run " + (i+1) + "===", false)); } } }
import weka.core.Instances; import weka.core.converters.ConverterUtils.DataSource; import weka.core.converters.ConverterUtils.DataSink; import weka.core.Utils; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import weka.filters.Filter; import weka.filters.supervised.attribute.AddClassification; import java.util.Random; /** * Performs a single run of cross-validation and adds the prediction on the * test set to the dataset. * * Command-line parameters: * <ul> * <li>-t filename - the dataset to use</li> * <li>-o filename - the output file to store dataset with the predictions * in</li> * <li>-x int - the number of folds to use</li> * <li>-s int - the seed for the random number generator</li> * <li>-c int - the class index, "first" and "last" are accepted as well; * "last" is used by default</li> * <li>-W classifier - classname and options, enclosed by double quotes; * the classifier to cross-validate</li> * </ul> * * Example command-line: * <pre> * java CrossValidationAddPrediction -t anneal.arff -c last -o predictions.arff -x 10 -s 1 -W "weka.classifiers.trees.J48 -C 0.25" * </pre> * * @author FracPete (fracpete at waikato dot ac dot nz) */ public class CrossValidationAddPrediction { /** * Performs the cross-validation. See Javadoc of class for information * on command-line parameters. 
* * @param args the command-line parameters * @throws Excecption if something goes wrong */ public static void main(String[] args) throws Exception { // loads data and set class index Instances data = DataSource.read(Utils.getOption("t", args)); String clsIndex = Utils.getOption("c", args); if (clsIndex.length() == 0) clsIndex = "last"; if (clsIndex.equals("first")) data.setClassIndex(0); else if (clsIndex.equals("last")) data.setClassIndex(data.numAttributes() - 1); else data.setClassIndex(Integer.parseInt(clsIndex) - 1); // classifier String[] tmpOptions; String classname; tmpOptions = Utils.splitOptions(Utils.getOption("W", args)); classname = tmpOptions[0]; tmpOptions[0] = ""; Classifier cls = (Classifier) Utils.forName(Classifier.class, classname, tmpOptions); // other options int seed = Integer.parseInt(Utils.getOption("s", args)); int folds = Integer.parseInt(Utils.getOption("x", args)); // randomize data Random rand = new Random(seed); Instances randData = new Instances(data); randData.randomize(rand); if (randData.classAttribute().isNominal()) randData.stratify(folds); // perform cross-validation and add predictions Instances predictedData = null; Evaluation eval = new Evaluation(randData); for (int n = 0; n < folds; n++) { Instances train = randData.trainCV(folds, n); Instances test = randData.testCV(folds, n); // the above code is used by the StratifiedRemoveFolds filter, the // code below by the Explorer/Experimenter: // Instances train = randData.trainCV(folds, n, rand); // build and evaluate classifier Classifier clsCopy = Classifier.makeCopy(cls); clsCopy.buildClassifier(train); eval.evaluateModel(clsCopy, test); // add predictions AddClassification filter = new AddClassification(); filter.setClassifier(cls); filter.setOutputClassification(true); filter.setOutputDistribution(true); filter.setOutputErrorFlag(true); filter.setInputFormat(train); Filter.useFilter(train, filter); // trains the classifier Instances pred = Filter.useFilter(test, filter); 
// perform predictions on test set if (predictedData == null) predictedData = new Instances(pred, 0); for (int j = 0; j < pred.numInstances(); j++) predictedData.add(pred.instance(j)); } // output evaluation System.out.println(); System.out.println("=== Setup ==="); System.out.println("Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions())); System.out.println("Dataset: " + data.relationName()); System.out.println("Folds: " + folds); System.out.println("Seed: " + seed); System.out.println(); System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false)); // output "enriched" dataset DataSink.write(Utils.getOption("o", args), predictedData); } }