Notes and Questions from Reading the DL4J Source Code (Spark Part)

//NOTE: these imports assume the older DL4J/DataVec APIs this example was written against (~0.4-0.6 era)
import java.io.File;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.spark.datavec.RecordReaderFunction;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class IrisLocal {

    public static void main(String[] args) throws Exception {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[*]");
        sparkConf.setAppName("Iris");

        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        //Load the data from the local (driver) classpath into a JavaRDD, for training
        //CSVRecordReader converts CSV data (as a String) into a usable format for network training
        RecordReader recordReader = new CSVRecordReader(0, ",");
        File f = new File("src/main/resources/iris_shuffled_normalized_csv.txt");
        JavaRDD<String> irisDataLines = sc.textFile(f.getAbsolutePath());
        //labelIndex is the index of the label (target) within each record
        int labelIndex = 4;
        int numOutputClasses = 3;
        //Create a feature vector and a one-hot target vector for each record; the target vector's length is
        //numOutputClasses, with a 1 at the position given by the record's label: e.g., with numOutputClasses = 3,
        //the second of the three classes yields the target vector <0,1,0>
        JavaRDD<DataSet> trainingData = irisDataLines.map(new RecordReaderFunction(recordReader, labelIndex, numOutputClasses));


        //First: Create and initialize multi-layer network. Configuration is the same as in normal (non-distributed) DL4J training
        final int numInputs = 4;
        int outputNum = 3;
        int iterations = 1;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .iterations(iterations)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(0.5)
                .regularization(true).l2(1e-4)
                .activation("tanh")
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3).build())
                .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation("softmax")
                        .nIn(2).nOut(outputNum).build())
                .backprop(true).pretrain(false)
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();



        //Second: Set up the Spark training.
        //Set up the TrainingMaster. The TrainingMaster controls how learning is actually executed on Spark
        //Here, we are using standard parameter averaging
        int examplesPerDataSetObject = 1;
        ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(examplesPerDataSetObject)
                .workerPrefetchNumBatches(2)    //Asynchronously prefetch up to 2 batches
                .saveUpdater(true)
                .averagingFrequency(1)  //See comments on averaging frequency in LSTM example. Averaging every 1 iteration is inefficient in practical problems
                .batchSizePerWorker(8)  //Number of examples that each worker gets, per fit operation
                .build();
        SparkDl4jMultiLayer sparkNetwork = new SparkDl4jMultiLayer(sc, net, tm);

        int nEpochs = 100;
        for (int i = 0; i < nEpochs; i++) {
            sparkNetwork.fit(trainingData);     //One epoch of distributed training per call
        }

        //Evaluate the trained network on the training data (discussed below)
        Evaluation evaluation = sparkNetwork.evaluate(trainingData);
        System.out.println(evaluation.stats());
    }
}

The above is the example program; what it does is train a neural network in a Spark environment.

Evaluation evaluation = sparkNetwork.evaluate(trainingData);
Stepping into the evaluate method of SparkDl4jMultiLayer, the arguments passed down are trainingData, null, and 64.
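
A hedged sketch of how the one-argument call presumably chains down to the three-argument overload (inferred from those observed arguments, not copied from the DL4J source):

//Assumed delegation chain, inferred from the observed call evaluate(trainingData, null, 64)
public Evaluation evaluate(JavaRDD<DataSet> data) {
    return evaluate(data, null);            //no label names supplied
}

public Evaluation evaluate(JavaRDD<DataSet> data, List<String> labelsList) {
    return evaluate(data, labelsList, 64);  //64: default evaluation batch size
}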
 
  
public Evaluation evaluate(JavaRDD<DataSet> data, List<String> labelsList, int evalBatchSize) {
    Broadcast<List<String>> listBroadcast = (labelsList == null) ? null : this.sc.broadcast(labelsList);
    JavaRDD<Evaluation> evaluations = data.mapPartitions(new EvaluateFlatMapFunction(
            this.sc.broadcast(this.conf.toJson()), this.sc.broadcast(this.network.params()),
            evalBatchSize, listBroadcast));
    return evaluations.reduce(new EvaluationReduceFunction());
}
data.mapPartitions() expects a FlatMapFunction<Iterator<DataSet>, Evaluation>, which is instantiated here with the subclass EvaluateFlatMapFunction. Its constructor arguments are: (1) a broadcast of the network configuration serialized to JSON via this.conf.toJson() — the same configuration used to build MultiLayerNetwork net = new MultiLayerNetwork(conf); (2) a broadcast of the network's parameters, i.e. the MultiLayerNetwork's flattenedParams field, a single flattened vector holding all of the network's weights and biases; (3) the evaluation batch size, 64; and (4) the labels-list broadcast, null here.
EvaluateFlatMapFunction implements the call method of the FlatMapFunction<Iterator<DataSet>, Evaluation> interface, and that method does the actual evaluation of the trained network:
public Iterable<Evaluation> call(Iterator<DataSet> dataSetIterator) throws Exception {
        if(!dataSetIterator.hasNext()) {
            return Collections.emptyList();
        } else {
            MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String)this.json.getValue()));
            network.init();
            INDArray val = (INDArray)this.params.value();
            if(val.length() != network.numParams(false)) {
                throw new IllegalStateException("Network did not have same number of parameters as the broadcasted set parameters");
            } else {
                network.setParameters(val);
                Evaluation evaluation;
                if(this.labels != null) {
                    evaluation = new Evaluation((List)this.labels.getValue());
                } else {
                    evaluation = new Evaluation();
                }

                List<DataSet> collect = new ArrayList<>();
                int totalCount = 0;

                while(dataSetIterator.hasNext()) {
                    collect.clear();
                    int nExamples = 0;

                    DataSet data;
                    while(dataSetIterator.hasNext() && nExamples < this.evalBatchSize) {
                        data = dataSetIterator.next();
                        nExamples += data.numExamples();
                        collect.add(data);
                    }

                    totalCount += nExamples;
                    data = DataSet.merge(collect, false);
                    INDArray out;
                    if(data.hasMaskArrays()) {
                        out = network.output(data.getFeatureMatrix(), false, data.getFeaturesMaskArray(), data.getLabelsMaskArray());
                    } else {
                        out = network.output(data.getFeatureMatrix(), false);
                    }

                    if(data.getLabels().rank() == 3) {
                        if(data.getLabelsMaskArray() == null) {
                            evaluation.evalTimeSeries(data.getLabels(), out);
                        } else {
                            evaluation.evalTimeSeries(data.getLabels(), out, data.getLabelsMaskArray());
                        }
                    } else {
                        evaluation.eval(data.getLabels(), out);
                    }
                }

                if(log.isDebugEnabled()) {
                    log.debug("Evaluated {} examples ", Integer.valueOf(totalCount));
                }

                return Collections.singletonList(evaluation);
            }
        }
    }
I am still not entirely clear on every detail of this method.
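That said, here is my rough reading, boiled down to a sketch (a paraphrase of the decompiled code above, not DL4J API; nextChunk is a hypothetical helper I made up to keep the sketch short):

//Paraphrased skeleton of EvaluateFlatMapFunction.call:
//1. Rebuild the network on the worker from the two broadcast values
MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json.getValue()));
network.init();
network.setParameters(params.value());   //the flattened weights + biases

//2. Drain this partition's iterator in chunks of up to evalBatchSize examples,
//   merge each chunk into a single DataSet, and score it in one forward pass
Evaluation evaluation = new Evaluation();
while (dataSetIterator.hasNext()) {
    List<DataSet> chunk = nextChunk(dataSetIterator, evalBatchSize);  //hypothetical helper
    DataSet batch = DataSet.merge(chunk, false);
    evaluation.eval(batch.getLabels(), network.output(batch.getFeatureMatrix(), false));
}

//3. Return exactly one Evaluation for the whole partition
return Collections.singletonList(evaluation);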
 
  
(Evaluation)evaluations.reduce(new EvaluationReduceFunction());
This reduce call merges the per-partition Evaluation results into a single final result.
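
The reduce function itself should be tiny; a hedged sketch consistent with the merge logic shown below (the exact Function2 form is my assumption):

import org.apache.spark.api.java.function.Function2;
import org.deeplearning4j.eval.Evaluation;

//Presumably, EvaluationReduceFunction just merges pairs of per-partition Evaluations
public class EvaluationReduceFunction implements Function2<Evaluation, Evaluation, Evaluation> {
    @Override
    public Evaluation call(Evaluation eval1, Evaluation eval2) throws Exception {
        eval1.merge(eval2);   //accumulate eval2's counts into eval1
        return eval1;
    }
}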
 
  
    public void merge(Evaluation other) {
        if(other != null) {
            this.truePositives.incrementAll(other.truePositives);
            this.falsePositives.incrementAll(other.falsePositives);
            this.trueNegatives.incrementAll(other.trueNegatives);
            this.falseNegatives.incrementAll(other.falseNegatives);
            if(this.confusion == null) {
                if(other.confusion != null) {
                    this.confusion = new ConfusionMatrix(other.confusion);
                }
            } else if(other.confusion != null) {
                this.confusion.add(other.confusion);
            }

            this.numRowCounter += other.numRowCounter;
            if(this.labelsList.isEmpty()) {
                this.labelsList.addAll(other.labelsList);
            }

        }
    }
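
Since merge simply adds the per-class counters and confusion matrices element-wise, the operation is associative and commutative, which is exactly what Spark's reduce requires; the merged Evaluation should therefore match what a single-machine evaluation over all the data would produce.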





