Encog入门学习(二)

继续的Encog-quickstart学习,这一次要实现一个从CSV 文件中读取数据,调用feed forward方式的神经网络进行分类的问题。数据集在https://archive.ics.uci.edu/ml/datasets/Iris 中,数据格式如下:


数据集中并没有列名,数据之间使用逗号隔开,通过4个数值型的属性来预测当前样本究竟是哪一种Iris。让我们先来看看代码:

import java.io.File;
import java.util.Arrays;

import org.encog.ConsoleStatusReportable;
import org.encog.Encog;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLData;
import org.encog.ml.data.versatile.NormalizationHelper;
import org.encog.ml.data.versatile.VersatileMLDataSet;
import org.encog.ml.data.versatile.columns.ColumnDefinition;
import org.encog.ml.data.versatile.columns.ColumnType;
import org.encog.ml.data.versatile.sources.CSVDataSource;
import org.encog.ml.data.versatile.sources.VersatileDataSource;
import org.encog.ml.factory.MLMethodFactory;
import org.encog.ml.model.EncogModel;
import org.encog.util.csv.CSVFormat;
import org.encog.util.csv.ReadCSV;
import org.encog.util.simple.EncogUtility;

public class IrisClassify {

	public static void main(String[] args) {
		// TODO Auto-generated method stub
		
		//1.Mapping the input file
		String irisFile = "input/iris.csv";
		File file = new File(irisFile);
		VersatileDataSource source = new CSVDataSource(file, false, CSVFormat.DECIMAL_POINT);
		
		VersatileMLDataSet data = new VersatileMLDataSet(source);
		data.defineSourceColumn("sepal-length", 0, ColumnType.continuous);
		data.defineSourceColumn("sepal-width", 1, ColumnType.continuous);
		data.defineSourceColumn("petal-length", 2, ColumnType.continuous);
		data.defineSourceColumn("petal-width", 3, ColumnType.continuous);
		ColumnDefinition outputColumn = data.defineSourceColumn("species", 4, ColumnType.nominal);
		data.analyze();
		
		//2.specifying the model and normalizing
		data.defineSingleOutputOthersInput(outputColumn);
		EncogModel model = new EncogModel(data);
		//使用feed forward的神经网络,可以通过调整最后一个参数来实现算法的切换
		model.selectMethod(data, MLMethodFactory.TYPE_FEEDFORWARD);
		model.setReport(new ConsoleStatusReportable());
		data.normalize();
		
		//3.fitting the model
		/**Before we fit the model we hold back part of the data for a validation set. 
		 * We choose to hold back 30%. We chose to randomize the data set with a fixed seed value.
		 *  This fixed seed ensures that we get the same training and validation sets each time. 
		 *  This is a matter of preference. If you want a random sample each time then pass in the 
		 *  current time for the seed.
		 * */		
		model.holdBackValidation(0.3, true, 1001);
		model.selectTrainingType(data);
		// Finally, we fit the model with a k-fold cross-validation of size 5
		MLRegression bestMethod = (MLRegression)model.crossvalidate(5, true);
		
		//4.displaying the results
		System.out.println("Training error: "
		+ EncogUtility.calculateRegressionError(bestMethod, model.getTrainingDataset()));
		System.out.println("Validation error: "
		+ EncogUtility.calculateRegressionError(bestMethod, model.getValidationDataset()));
		
		NormalizationHelper helper = data.getNormHelper();
		System.out.println(helper.toString());
		System.out.println("Final model: "+bestMethod);
		
		//5.using the model and denormalizing
		ReadCSV csv = new ReadCSV(file, false, CSVFormat.DECIMAL_POINT);
		String [] line = new String[4];
		MLData input = helper.allocateInputVector();
		int count=0, total=0;
		while(csv.next()){
			StringBuilder result = new StringBuilder();
			line[0] = csv.get(0);
			line[1] = csv.get(1);
			line[2] = csv.get(2);
			line[3] = csv.get(3);
			String correct = csv.get(4);
			helper.normalizeInputVector(line, input.getData(), false);
			MLData output = bestMethod.compute(input);
			String irisChosen = helper.denormalizeOutputVectorToString(output)[0];
			if(irisChosen.equals(correct)){
				count++;
			}
			total++;
			result.append(Arrays.toString(line));
			result.append(" -> preducted: ");
			result.append(irisChosen);
			result.append("(correct: ");
			result.append(correct);
			result.append(")");
			System.out.println(result.toString());
		}
		System.out.println("AUC :"+(double)count/(double)total);
		//file.delete();
		Encog.getInstance().shutdown();
	}

}
第一步我们要获取输入文件,对CSVDataSource的构造函数传入一个File对象,接下来创建数据集对象,对数据的每一列进行申明,其中continuous表示连续型数值,而nominal表示如同名义标注的类型(如一个字符串)。
第二步需要选择模型,并进行格式的规范化。首先规定需要进行判断的列,这个例子中对最后一列“品种”进行预测;接着建立一个模型对象;然后选择算法,需要注意的是这个例子中采用的静态工厂的方法进行算法的选择,通过调整methodType参数可以很轻松的选择不同的算法类型,并且当调用data.normalize()方法后,Encog会自动根据所选算法对数据进行规范化处理。

第三步就是选择合适的模型了。model.holdBackValidation()方法中会设置三个很重要的参数validationPercent(评估数据占比), shuffle(是否打乱), seed(随机种子)。接着,把数据集加入到model中;之后调用model.Crossvalidate()方法获取一个最佳的训练模型。这里将训练数据进行5次划分,每一次都会生成一个不同的训练集和评估集。需要注意的是,这里每次划分的评估数据并不是之前在model.holdBackValidation()设置的,而是调用Crossvalidate()函数后内部对它进行的划分,中间结果会在控制台进行输出。


第四步是打印最优模型的结果,打印NormalizationHelper可以看到各列数据的情况(最小值、最大值、平均值、标准差)。最后选择出来的最优化模型是BasicNetwork,建立了3个神经层。这里的Validation error才是使用我们在holdBackValidation方法中设置的评估数据进行评估的结果。


最后一步就是使用模型进行分类预测了,这里还是使用的刚才的训练数据集进行预测,当然也可以使用新的数据集。可以看到,对整个数据集进行分类的准确率是0.98。




你可能感兴趣的:(Encog)