Continuing the Encog quickstart series, this time the goal is to read data from a CSV file and use a feed-forward neural network for a classification problem. The data set is the Iris set at https://archive.ics.uci.edu/ml/datasets/Iris, in the following format:
The data set has no header row, values are separated by commas, and four numeric attributes are used to predict which species of Iris each sample belongs to. Let's look at the code first:
import java.io.File;
import java.util.Arrays;
import org.encog.ConsoleStatusReportable;
import org.encog.Encog;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLData;
import org.encog.ml.data.versatile.NormalizationHelper;
import org.encog.ml.data.versatile.VersatileMLDataSet;
import org.encog.ml.data.versatile.columns.ColumnDefinition;
import org.encog.ml.data.versatile.columns.ColumnType;
import org.encog.ml.data.versatile.sources.CSVDataSource;
import org.encog.ml.data.versatile.sources.VersatileDataSource;
import org.encog.ml.factory.MLMethodFactory;
import org.encog.ml.model.EncogModel;
import org.encog.util.csv.CSVFormat;
import org.encog.util.csv.ReadCSV;
import org.encog.util.simple.EncogUtility;
public class IrisClassify {
    public static void main(String[] args) {
        // 1. Map the input file
        String irisFile = "input/iris.csv";
        File file = new File(irisFile);
        VersatileDataSource source = new CSVDataSource(file, false, CSVFormat.DECIMAL_POINT);
        VersatileMLDataSet data = new VersatileMLDataSet(source);
        data.defineSourceColumn("sepal-length", 0, ColumnType.continuous);
        data.defineSourceColumn("sepal-width", 1, ColumnType.continuous);
        data.defineSourceColumn("petal-length", 2, ColumnType.continuous);
        data.defineSourceColumn("petal-width", 3, ColumnType.continuous);
        ColumnDefinition outputColumn = data.defineSourceColumn("species", 4, ColumnType.nominal);
        data.analyze();
        // 2. Specify the model and normalize
        data.defineSingleOutputOthersInput(outputColumn);
        EncogModel model = new EncogModel(data);
        // Use a feed-forward neural network; change the last argument to switch algorithms
        model.selectMethod(data, MLMethodFactory.TYPE_FEEDFORWARD);
        model.setReport(new ConsoleStatusReportable());
        data.normalize();
        // 3. Fit the model
        /**
         * Before we fit the model we hold back part of the data as a validation set.
         * We choose to hold back 30% and shuffle the data set with a fixed seed value.
         * The fixed seed ensures that we get the same training and validation sets
         * each run. This is a matter of preference: if you want a random sample each
         * time, pass in the current time as the seed instead.
         */
        model.holdBackValidation(0.3, true, 1001);
        model.selectTrainingType(data);
        // Finally, fit the model with a 5-fold cross-validation
        MLRegression bestMethod = (MLRegression) model.crossvalidate(5, true);
        // 4. Display the results
        System.out.println("Training error: "
                + EncogUtility.calculateRegressionError(bestMethod, model.getTrainingDataset()));
        System.out.println("Validation error: "
                + EncogUtility.calculateRegressionError(bestMethod, model.getValidationDataset()));
        NormalizationHelper helper = data.getNormHelper();
        System.out.println(helper.toString());
        System.out.println("Final model: " + bestMethod);
        // 5. Use the model and denormalize
        ReadCSV csv = new ReadCSV(file, false, CSVFormat.DECIMAL_POINT);
        String[] line = new String[4];
        MLData input = helper.allocateInputVector();
        int count = 0, total = 0;
        while (csv.next()) {
            StringBuilder result = new StringBuilder();
            line[0] = csv.get(0);
            line[1] = csv.get(1);
            line[2] = csv.get(2);
            line[3] = csv.get(3);
            String correct = csv.get(4);
            helper.normalizeInputVector(line, input.getData(), false);
            MLData output = bestMethod.compute(input);
            String irisChosen = helper.denormalizeOutputVectorToString(output)[0];
            if (irisChosen.equals(correct)) {
                count++;
            }
            total++;
            result.append(Arrays.toString(line));
            result.append(" -> predicted: ");
            result.append(irisChosen);
            result.append(" (correct: ");
            result.append(correct);
            result.append(")");
            System.out.println(result.toString());
        }
        System.out.println("Accuracy: " + (double) count / (double) total);
        Encog.getInstance().shutdown();
    }
}
The first step is to load the input file: pass a File object to the CSVDataSource constructor, then create the data set object and declare each column. ColumnType.continuous marks a continuous numeric value, while ColumnType.nominal marks a categorical label (such as a string).
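Conceptually, normalizing a continuous column means rescaling it into a fixed range before training. The sketch below shows simple min-max range scaling into [-1, 1]; the exact per-column strategy Encog's NormalizationHelper applies is configurable and may differ, and the Iris value range used here is only illustrative.

```java
// Minimal sketch of min-max range normalization, one common way a
// continuous column is rescaled before training. This is an
// illustration only; Encog's NormalizationHelper handles the real work.
public class RangeNormalizeSketch {
    // Map x from [min, max] linearly into [lo, hi]
    static double normalize(double x, double min, double max,
                            double lo, double hi) {
        return (x - min) / (max - min) * (hi - lo) + lo;
    }

    public static void main(String[] args) {
        // Iris sepal lengths roughly span 4.3 .. 7.9 (illustrative bounds)
        System.out.println(normalize(6.1, 4.3, 7.9, -1.0, 1.0)); // prints 0.0
    }
}
```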
The third step is selecting a suitable model. The model.holdBackValidation() method takes three important parameters: validationPercent (the fraction of data held out for validation), shuffle (whether to shuffle the data), and seed (the random seed). The data set is then attached to the model, and model.crossvalidate() is called to obtain the best trained model. Here the training data is split five times, each split producing a different training and evaluation set. Note that the evaluation data in each split is not the hold-out set configured earlier in model.holdBackValidation(); it is partitioned internally by crossvalidate(), and the intermediate results are printed to the console.
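The idea behind the 5-fold split can be sketched with plain index arithmetic: each fold serves once as the evaluation set while the remaining folds form the training set. This is only an illustration of the partitioning scheme; Encog performs the actual (shuffled) splitting internally in crossvalidate().

```java
// Hypothetical sketch of how k-fold cross-validation partitions
// n samples: fold i evaluates on indices [i*n/k, (i+1)*n/k) and
// trains on the rest. Encog does the real split (with shuffling).
public class KFoldSketch {
    // Returns {start, end} (end exclusive) of the evaluation fold
    static int[] foldBounds(int n, int k, int fold) {
        return new int[] { fold * n / k, (fold + 1) * n / k };
    }

    public static void main(String[] args) {
        int n = 150, k = 5; // e.g. 150 Iris samples, 5 folds
        for (int fold = 0; fold < k; fold++) {
            int[] b = foldBounds(n, k, fold);
            System.out.println("fold " + fold + ": evaluate [" + b[0] + ", " + b[1]
                    + "), train on the remaining " + (n - (b[1] - b[0])) + " samples");
        }
    }
}
```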
The fourth step prints the results of the best model. Printing the NormalizationHelper shows the statistics of each column (minimum, maximum, mean, standard deviation). The best model selected here is a BasicNetwork with three layers. The Validation error is the one computed on the hold-out data we configured in holdBackValidation().
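In step 5, denormalizeOutputVectorToString() maps the network's raw output back to a species label. One common way this works, sketched below with hypothetical activation values, is to take the index of the largest output and look up the corresponding class; Encog's actual nominal encoding (e.g. equilateral) may differ.

```java
// Hypothetical sketch of denormalizing a nominal output: pick the
// index of the largest activation and map it back to a class label.
// Encog's NormalizationHelper does the equivalent internally.
public class ArgmaxDenormalizeSketch {
    static String denormalize(double[] output, String[] classes) {
        int best = 0;
        for (int i = 1; i < output.length; i++) {
            if (output[i] > output[best]) best = i;
        }
        return classes[best];
    }

    public static void main(String[] args) {
        String[] species = { "Iris-setosa", "Iris-versicolor", "Iris-virginica" };
        double[] output = { 0.1, 0.7, 0.2 }; // made-up network activations
        System.out.println(denormalize(output, species)); // prints Iris-versicolor
    }
}
```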