下面以函数拟合为例,说明DL4J的程序结构。参考源代码:
org.deeplearning4j.examples.feedforward.regression.RegressionMathFunctions
//生成nSamples×1的列向量:linspace在区间[-10, 10]上等间隔取nSamples个值,再reshape为nSamples行1列
//nSamples 为样本数量,官方例子中默认1000
final INDArray x = Nd4j.linspace(-10,10,nSamples).reshape(nSamples, 1);
//计算sin(x),fn=sin(x)
final DataSetIterator iterator = getTrainingData(x,fn,batchSize,rng)
//函数getTrainingData()定义如下:
/**
 * Creates a DataSetIterator for training from the inputs {@code x} and the
 * values of {@code function} evaluated at {@code x}.
 *
 * @param x X values (the network inputs)
 * @param function Function to evaluate; its values at {@code x} are used as the targets
 * @param batchSize Batch size (number of examples for every call of DataSetIterator.next())
 * @param rng Random number generator used for shuffling (for repeatability)
 * @return an iterator over the shuffled examples, {@code batchSize} examples at a time
 */
private static DataSetIterator getTrainingData(final INDArray x, final MathFunction function, final int batchSize, final Random rng) {
    final INDArray y = function.getFunctionValues(x);
    final DataSet allData = new DataSet(x, y);
    // Typed list instead of a raw List: avoids unchecked calls in shuffle() and the iterator constructor.
    final List<DataSet> list = allData.asList();
    // Shuffle with the caller's RNG so the example order is reproducible for a given seed.
    Collections.shuffle(list, rng);
    return new ListDataSetIterator(list, batchSize);
}
//在主函数中创建多层神经网络(conf为网络配置,由下文的getDeepDenseLayerNetworkConfiguration()生成)
// Create the network from the configuration object conf.
// NOTE(review): conf is not defined in the visible text — presumably the result of
// getDeepDenseLayerNetworkConfiguration(); confirm where it is assigned.
final MultiLayerNetwork net = new MultiLayerNetwork(conf);
// Initialize the network before it is trained or queried.
net.init();
// Attach a score listener constructed with 1 (presumably: report the score every iteration — confirm).
net.setListeners(new ScoreIterationListener(1));
/**
 * Builds the configuration of the regression network: two dense TANH layers of
 * 50 units each, followed by an identity-activated output layer with MSE loss.
 *
 * @return the multi-layer configuration (pretraining off, backprop on)
 */
private static MultiLayerConfiguration getDeepDenseLayerNetworkConfiguration() {
    final int hiddenWidth = 50;

    // Layer 0: numInputs -> hiddenWidth
    final DenseLayer firstHidden = new DenseLayer.Builder()
        .nIn(numInputs)
        .nOut(hiddenWidth)
        .activation(Activation.TANH)
        .build();

    // Layer 1: hiddenWidth -> hiddenWidth
    final DenseLayer secondHidden = new DenseLayer.Builder()
        .nIn(hiddenWidth)
        .nOut(hiddenWidth)
        .activation(Activation.TANH)
        .build();

    // Layer 2: hiddenWidth -> numOutputs, identity activation with mean-squared-error loss
    final OutputLayer regressionOutput = new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
        .activation(Activation.IDENTITY)
        .nIn(hiddenWidth)
        .nOut(numOutputs)
        .build();

    return new NeuralNetConfiguration.Builder()
        .seed(seed)
        .iterations(iterations)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .learningRate(learningRate)
        .weightInit(WeightInit.XAVIER)
        .updater(Updater.NESTEROVS)
        .momentum(0.9)
        .list()
        .layer(0, firstHidden)
        .layer(1, secondHidden)
        .layer(2, regressionOutput)
        .pretrain(false)
        .backprop(true)
        .build();
}
//在主函数中按epoch循环(共nEpochs轮),每隔plotFrequency轮保存一次网络对x的预测结果
// One slot per snapshot: a prediction is stored once every plotFrequency epochs.
final INDArray[] networkPredictions = new INDArray[nEpochs / plotFrequency];
for (int i = 0; i < nEpochs; i++) {
    // Rewind the iterator so every epoch passes over the full training set, then train for one pass.
    iterator.reset();
    net.fit(iterator);
    if ((i + 1) % plotFrequency == 0) {
        // Record the network's output on x (second argument false = not in training mode).
        networkPredictions[i / plotFrequency] = net.output(x, false);
    }
}
//在主函数中调用plot(),把目标函数值fn(x)与保存下来的各次网络预测画在同一张图上
// Plot the target values fn(x) together with every stored prediction in networkPredictions.
plot(fn,x,fn.getFunctionValues(x),networkPredictions);
//定义作图函数plot()
private static void plot(final MathFunction function, final INDArray x, final INDArray y, final INDArray... predicted) {
final XYSeriesCollection dataSet = new XYSeriesCollection();
addSeries(dataSet,x,y,"True Function (Labels)");
for( int i=0; ifinal JFreeChart chart = ChartFactory.createXYLineChart(
"Regression Example - " + function.getName(), // chart title
"X", // x axis label
function.getName() + "(X)", // y axis label
dataSet, // data
PlotOrientation.VERTICAL,
true, // include legend
true, // tooltips
false // urls
);
final ChartPanel panel = new ChartPanel(chart);
final JFrame f = new JFrame();
f.add(panel);
f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
f.pack();
f.setVisible(true);
}
以下问题将在后续文章中逐一讲清楚:
1. Nd4j框架下的矩阵计算,向量化
2. 多层神经网络结构及参数
3. 训练、预测的策略