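This example is fairly simple. I am writing this post mainly to keep a few notes, so that I do not waste time when I run into it again later.
The example reads CSV data, normalizes it, builds a simple neural network, trains it, and then makes predictions.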
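To make the CSV-reading step concrete, here is a minimal sketch in plain Java of what turning one row into features and a label amounts to. It is not DL4J code: the class name `CsvRowSketch`, its two methods, and the sample row are my own, for illustration only. It assumes the layout the example describes (four feature values followed by an integer class index in column 4, three classes) and that the label is expanded to a one-hot vector, which is my understanding of how the iterator treats classification labels.

```java
import java.util.Arrays;
import java.util.regex.Pattern;

public class CsvRowSketch {

    /** Returns every column of the row except the label column, parsed as doubles. */
    public static double[] features(String line, String delimiter, int labelIndex) {
        String[] parts = line.split(Pattern.quote(delimiter));
        double[] out = new double[parts.length - 1];
        int j = 0;
        for (int i = 0; i < parts.length; i++) {
            if (i != labelIndex) {
                out[j++] = Double.parseDouble(parts[i].trim());
            }
        }
        return out;
    }

    /** Expands the integer class index in the label column into a one-hot vector. */
    public static double[] oneHotLabel(String line, String delimiter, int labelIndex, int numClasses) {
        String[] parts = line.split(Pattern.quote(delimiter));
        int classIndex = Integer.parseInt(parts[labelIndex].trim());
        double[] out = new double[numClasses];
        out[classIndex] = 1.0;
        return out;
    }

    public static void main(String[] args) {
        // Illustrative row: 4 features, then the class index (assumed layout)
        String row = "5.1,3.5,1.4,0.2,0";
        System.out.println(Arrays.toString(features(row, ",", 4)));       // [5.1, 3.5, 1.4, 0.2]
        System.out.println(Arrays.toString(oneHotLabel(row, ",", 4, 3))); // [1.0, 0.0, 0.0]
    }
}
```

In the real example, `CSVRecordReader` and `RecordReaderDataSetIterator` do this work for every row; the full listing follows.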
package org.deeplearning4j.examples.dataExamples;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* @author Adam Gibson
*/
public class CSVExample {
private static Logger log = LoggerFactory.getLogger(CSVExample.class); // create a logger so messages can be printed to the log
public static void main(String[] args) throws Exception {
//First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
int numLinesToSkip = 0; // number of lines to skip when reading the file; some files have a header row, some do not
String delimiter = ","; // separator between values
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter); // the file reader
recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile())); // read the file from disk
//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
//5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
int labelIndex = 4; // index of the label column
//3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
int numClasses = 3; // number of classes
//Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)
int batchSize = 150; // this is the batch size, not the total row count; it is set to 150 here so that one next() call loads the whole data set
// Wrap the reader in an iterator. Arguments, in order: record reader, batch size, label index, number of classes
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSet allData = iterator.next(); // convert the data to a DataSet
allData.shuffle(); // shuffle the examples
// Split into a training set and a test set
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); //Use 65% of data for training
DataSet trainingData = testAndTrain.getTrain(); // the training set
DataSet testData = testAndTrain.getTest(); // the test set
System.out.println("allData = "+allData.numExamples()+" train = "+trainingData.numExamples());
//We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance):
DataNormalization normalizer = new NormalizerStandardize(); // normalizes the data
//Collect the statistics (mean/stdev) from the training data. This does not modify the input data
normalizer.fit(trainingData); // compute the mean and standard deviation of the training set
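As a note to myself on what "mean 0, unit variance" means here, the following is a minimal plain-Java sketch of standardization. It is not the `NormalizerStandardize` implementation: the class `StandardizeSketch` and its methods are hypothetical, the numbers are made up, and I assume the population standard deviation (dividing by n), which may differ from what the library does internally. The point it illustrates is the one in the comment above: the statistics come from the training data only and are then applied to any other data.

```java
import java.util.Arrays;

public class StandardizeSketch {

    /** Per-column mean of a row-major matrix. */
    public static double[] columnMeans(double[][] data) {
        double[] means = new double[data[0].length];
        for (double[] row : data) {
            for (int c = 0; c < row.length; c++) {
                means[c] += row[c] / data.length;
            }
        }
        return means;
    }

    /** Per-column population standard deviation (assumption: divide by n). */
    public static double[] columnStds(double[][] data, double[] means) {
        double[] stds = new double[means.length];
        for (double[] row : data) {
            for (int c = 0; c < row.length; c++) {
                double d = row[c] - means[c];
                stds[c] += d * d / data.length;
            }
        }
        for (int c = 0; c < stds.length; c++) {
            stds[c] = Math.sqrt(stds[c]);
        }
        return stds;
    }

    /** Returns a standardized copy; a zero std is treated as 1 to avoid dividing by zero. */
    public static double[][] transform(double[][] data, double[] means, double[] stds) {
        double[][] out = new double[data.length][];
        for (int r = 0; r < data.length; r++) {
            out[r] = new double[data[r].length];
            for (int c = 0; c < data[r].length; c++) {
                double s = stds[c] == 0.0 ? 1.0 : stds[c];
                out[r][c] = (data[r][c] - means[c]) / s;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] train = {{1, 10}, {3, 30}}; // made-up training rows
        double[][] test = {{5, 20}};           // made-up test row
        double[] means = columnMeans(train);   // statistics from the training data only
        double[] stds = columnStds(train, means);
        System.out.println(Arrays.deepToString(transform(train, means, stds))); // [[-1.0, -1.0], [1.0, 1.0]]
        System.out.println(Arrays.deepToString(transform(test, means, stds)));  // [[3.0, 0.0]]
    }
}
```

Note that the test row is scaled with the training statistics, so its values are not themselves guaranteed to have mean 0.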
The results are as follows:
If you spot any problems, please point them out. Thanks.