Deep Learning with deeplearning4j (Part 2)

This post takes a shallow look at how deeplearning4j builds a multi-layer neural network. My own understanding is limited, so I hope readers will point out anything I get wrong.
Below is the handwritten-digit (MNIST) example from the official site. I have already annotated the code, so I will simply list it here. Training took roughly four hours on my machine, which is admittedly far too slow.

package org.deeplearning4j.examples.convolution;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Created by agibsonccc on 9/16/15. */
public class LenetMnistExample {
    private static final Logger log = LoggerFactory.getLogger(LenetMnistExample.class);

    public static void main(String[] args) throws Exception {
        int nChannels = 1;    // number of input channels
        int outputNum = 10;   // number of output-layer neurons (classes)
        int batchSize = 1000; // mini-batch size
        int nEpochs = 10;     // number of epochs
        int iterations = 1;   // iterations per mini-batch
        int seed = 123;       // random seed

        log.info("Load data....");
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize,true,12345);
        /*
         * A DataSetIterator has to be created beforehand; typically you would
         * implement or extend one yourself for your own data.
         *
         * Some quick facts about the iterator:
         * System.out.println("Total examples in the iterator : " + mnistTrain.totalExamples()); // 60000 examples
         * System.out.println("Input columns for the dataset " + mnistTrain.inputColumns()); // 28*28 per image
         * System.out.println("The number of labels for the dataset : " + mnistTrain.totalOutcomes()); // 10 classes
         * System.out.println("Batch size: " + mnistTrain.batch()); // number of input samples per training batch
         */
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize,false,12345); 
        log.info("Build model....");
        //配置神经网络
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations) 
                .regularization(true).l2(0.0005) // use L2 regularization
                .learningRate(0.01) // learning rate
                .weightInit(WeightInit.XAVIER) // weight initialization
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) // stochastic gradient descent
                .updater(Updater.NESTEROVS).momentum(0.9) // weight update rule: Nesterov momentum
                .list(4) // number of layers, not counting the input layer
                .layer(0, new ConvolutionLayer.Builder(5, 5) // convolution layer with a 5*5 kernel (adjust as you like)
                        .nIn(nChannels) // number of input channels
                        .stride(1, 1) // stride of the convolution
                        .nOut(20).dropOut(0.5) // 20 feature maps, with dropout of 0.5
                        .activation("relu") // activation function
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) // subsampling (pooling) layer
                        .kernelSize(2,2) // 2*2 pooling kernel
                        .stride(2,2)
                        .build())
                .layer(2, new DenseLayer.Builder().activation("relu") // fully connected (dense) layer
                        .nOut(500).build())
                .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation("softmax") //激活函数
                        .build())
                .backprop(true) // use backpropagation
                .pretrain(false); // no unsupervised pretraining
        new ConvolutionLayerSetup(builder,28,28,1); // set up input shape (28x28, 1 channel) and layer preprocessors

        MultiLayerConfiguration conf = builder.build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf); 
        model.init(); // initialize the model

        log.info("Train model....");
        model.setListeners(new ScoreIterationListener(1));
        for( int i=0; i<nEpochs; i++ ) {
            model.fit(mnistTrain);
            log.info("*** Completed epoch {} ***", i);
            log.info("Evaluate model....");
            Evaluation eval = new Evaluation(outputNum);
            while(mnistTest.hasNext()){
                DataSet ds = mnistTest.next();
                INDArray output = model.output(ds.getFeatureMatrix());
                eval.eval(ds.getLabels(), output);
            }
            log.info(eval.stats());
            mnistTest.reset();
            // the trained neural network ...
        }
        log.info("****************Example finished********************");
    }
}

If you are not familiar with how neural networks are structured, the code above reads like gibberish; if you are, the names alone are enough to guess what each call does. Let us start with the abstract class Layer. Reading it, this builder-based style looks a lot like Scala code; a few parts whose purpose I could not figure out are left without comments. The short sketch below shows the builder pattern in isolation before we look at the class itself.
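A minimal, purely illustrative sketch of the pattern (it only uses setters that appear in the classes quoted in this post; the concrete values are assumptions, not recommendations):

// Fluent Builder pattern as used by deeplearning4j's layer classes:
// every setter returns the builder itself, so calls can be chained,
// and build() produces the finished layer configuration.
DenseLayer layer = new DenseLayer.Builder()
        .nIn(784)                      // 28*28 = 784 inputs (illustrative)
        .nOut(500)                     // 500 hidden units (illustrative)
        .activation("relu")            // inherited from Layer.Builder
        .weightInit(WeightInit.XAVIER) // inherited from Layer.Builder
        .dropOut(0.5)                  // inherited from Layer.Builder
        .build();

With that pattern in mind, here is Layer itself: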

package org.deeplearning4j.nn.conf.layers;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.JsonTypeInfo.As;
import com.fasterxml.jackson.annotation.JsonTypeInfo.Id;

import lombok.Data;
import lombok.NoArgsConstructor;

import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.weights.WeightInit;

/** A neural network layer. */
@JsonTypeInfo(use=Id.NAME, include=As.WRAPPER_OBJECT)
// the concrete layer types that can be (de)serialized
@JsonSubTypes(value={
        @JsonSubTypes.Type(value = AutoEncoder.class, name = "autoEncoder"),
        @JsonSubTypes.Type(value = ConvolutionLayer.class, name = "convolution"),
        @JsonSubTypes.Type(value = ImageLSTM.class, name = "imageLSTM"),
        @JsonSubTypes.Type(value = GravesLSTM.class, name = "gravesLSTM"),
        @JsonSubTypes.Type(value = GRU.class, name = "gru"),
        @JsonSubTypes.Type(value = OutputLayer.class, name = "output"),
        @JsonSubTypes.Type(value = RnnOutputLayer.class, name = "rnnoutput"),
        @JsonSubTypes.Type(value = RBM.class, name = "RBM"),
        @JsonSubTypes.Type(value = DenseLayer.class, name = "dense"),
        @JsonSubTypes.Type(value = RecursiveAutoEncoder.class, name = "recursiveAutoEncoder"),
        @JsonSubTypes.Type(value = SubsamplingLayer.class, name = "subsampling"),
        @JsonSubTypes.Type(value = LocalResponseNormalization.class, name = "localResponseNormalization"),
        })
@Data
@NoArgsConstructor
public abstract class Layer implements Serializable, Cloneable {
    protected String layerName;          // layer name
    protected String activationFunction; // activation function
    protected WeightInit weightInit;     // weight initialization scheme
    protected double biasInit;           // bias initialization
    protected Distribution dist;         // distribution to sample initial weights from
    protected double learningRate;       // learning rate
    // learning rate after n iterations
    protected Map<Integer,Double> learningRateAfter;
    protected double lrScoreBasedDecay;
    protected double momentum;           // momentum, used when optimizing the weights
    // momentum after n iterations
    protected Map<Integer,Double> momentumAfter;
    protected double l1;                 // L1 regularization
    protected double l2;                 // L2 regularization
    protected double dropOut;            // dropout rate
    protected Updater updater;           // weight update rule
    // adadelta - weight for how much to consider previous history
    protected double rho;
    protected double rmsDecay;
    protected double adamMeanDecay = 0.9;
    protected double adamVarDecay = 0.999;
    protected GradientNormalization gradientNormalization = GradientNormalization.None; // clipping, rescale based on L2 norm, etc.
    protected double gradientNormalizationThreshold = 1.0;   // threshold for L2 and element-wise gradient clipping

    public Layer(Builder builder) {
        this.layerName = builder.layerName;
        this.activationFunction = builder.activationFunction;
        this.weightInit = builder.weightInit;
        this.biasInit = builder.biasInit;
        this.dist = builder.dist;
        this.learningRate = builder.learningRate;
        this.learningRateAfter = builder.learningRateAfter;
        this.lrScoreBasedDecay = builder.lrScoreBasedDecay;
        this.momentum = builder.momentum;
        this.momentumAfter = builder.momentumAfter;
        this.l1 = builder.l1;
        this.l2 = builder.l2;
        this.dropOut = builder.dropOut;
        this.updater = builder.updater;
        this.rho = builder.rho;
        this.rmsDecay = builder.rmsDecay;
        this.adamMeanDecay = builder.adamMeanDecay;
        this.adamVarDecay = builder.adamVarDecay;
        this.gradientNormalization = builder.gradientNormalization;
        this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold;
    }

    // as the name suggests: make a copy of this layer configuration
    @Override
    public Layer clone() {
        try {
            Layer clone = (Layer) super.clone();
            if(clone.dist != null) clone.dist = clone.dist.clone();
            if(clone.learningRateAfter != null) clone.learningRateAfter = new HashMap<>(clone.learningRateAfter);
            if(clone.momentumAfter != null) clone.momentumAfter = new HashMap<>(clone.momentumAfter);
            return clone;
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    // abstract builder used to construct a layer
    public abstract static class Builder<T extends Builder<T>> {
        protected String layerName = "genisys";//名称
        protected String activationFunction = null;//激活函数
        protected WeightInit weightInit = null;//权值
        protected double biasInit = Double.NaN;//偏置
        protected Distribution dist = null;//分布
        protected double learningRate = Double.NaN;//学习率
        protected Map<Integer,Double> learningRateAfter = null;
        protected double lrScoreBasedDecay = Double.NaN;
        protected double momentum = Double.NaN;//运动惯量
        protected Map<Integer,Double> momentumAfter = null;//
        protected double l1 = Double.NaN;//
        protected double l2 = Double.NaN;//
        protected double dropOut = Double.NaN;//
        protected Updater updater = null;//
        protected double rho = Double.NaN;//
        protected double rmsDecay = Double.NaN;//
        protected double adamMeanDecay = Double.NaN;//
        protected double adamVarDecay = Double.NaN;//
        protected GradientNormalization gradientNormalization = null;//
        protected double gradientNormalizationThreshold = Double.NaN;//


        /** Layer name assigns layer string name.
         * Allows easier differentiation between layers. */
        public T name(String layerName) {
            this.layerName = layerName;
            return (T) this;
        }


        /** Layer activation function.
         * Typical values include:<br>
         * "relu" (rectified linear), "tanh", "sigmoid", "softmax",
         * "hardtanh", "leakyrelu", "maxout", "softsign", "softplus" */
        public T activation(String activationFunction) {
            this.activationFunction = activationFunction;
            return (T) this;
        }

        /** Weight initialization scheme.
         * @see org.deeplearning4j.nn.weights.WeightInit */
        public T weightInit(WeightInit weightInit) {
            this.weightInit = weightInit;
            return (T) this;
        }

        public T biasInit(double biasInit) {
            this.biasInit = biasInit;
            return (T) this;
        }

        /** Distribution to sample initial weights from. Used in conjunction with
         * .weightInit(WeightInit.DISTRIBUTION). */
        public T dist(Distribution dist){
            this.dist = dist;
            return (T) this;
        }

        /** Learning rate. Defaults to 1e-1. */
        public T learningRate(double learningRate){
            this.learningRate = learningRate;
            return (T)this;
        }

        /** Learning rate schedule. Map of the iteration to the learning rate to apply at that iteration. */
        public T learningRateAfter(Map<Integer, Double> learningRateAfter) {
            this.learningRateAfter = learningRateAfter;
            return (T) this;
        }

        /** Rate to decrease learningRate by when the score stops improving.
         * Learning rate is multiplied by this rate so ideally keep between 0 and 1. */
        public T learningRateScoreBasedDecayRate(double lrScoreBasedDecay) {
            this.lrScoreBasedDecay = lrScoreBasedDecay;
            return (T) this;
        }

        /** L1 regularization coefficient.*/
        public T l1(double l1){
            this.l1 = l1;
            return (T)this;
        }

        /** L2 regularization coefficient. */
        public T l2(double l2){
            this.l2 = l2;
            return (T)this;
        }

        public T dropOut(double dropOut) {
            this.dropOut = dropOut;
            return (T) this;
        }

        /** Momentum rate. */
        public T momentum(double momentum) {
            this.momentum = momentum;
            return (T)this;
        }

        /** Momentum schedule. Map of the iteration to the momentum rate to apply at that iteration. */
        public T momentumAfter(Map<Integer, Double> momentumAfter) {
            this.momentumAfter = momentumAfter;
            return (T) this;
        }

        /** Gradient updater. For example, SGD for standard stochastic gradient descent, NESTEROVS for Nesterov momentum,
         * RMSPROP for RMSProp, etc.
         * @see org.deeplearning4j.nn.conf.Updater */
        public T updater(Updater updater){
            this.updater = updater;
            return (T) this;
        }

        /** Ada delta coefficient.
         * @param rho */
        public T rho(double rho) {
            this.rho = rho;
            return (T) this;
        }

        /** Decay rate for RMSProp. Only applies if using .updater(Updater.RMSPROP) */
        public T rmsDecay(double rmsDecay) {
            this.rmsDecay = rmsDecay;
            return (T) this;
        }

        /** Mean decay rate for Adam updater. Only applies if using .updater(Updater.ADAM) */
        public T adamMeanDecay(double adamMeanDecay) {
            this.adamMeanDecay = adamMeanDecay;
            return (T) this;
        }

        /** Variance decay rate for Adam updater. Only applies if using .updater(Updater.ADAM) */
        public T adamVarDecay(double adamVarDecay) {
            this.adamVarDecay = adamVarDecay;
            return (T) this;
        }

        /** Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc.
         * @param gradientNormalization Type of normalization to use. Defaults to None.
         * @see org.deeplearning4j.nn.conf.GradientNormalization */
        public T gradientNormalization(GradientNormalization gradientNormalization ){
            this.gradientNormalization = gradientNormalization;
            return (T) this;
        }

        /** Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer,
         * GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue<br>
         * Not used otherwise.<br>
         * L2 threshold for first two types of clipping, or absolute value threshold for last type of clipping. */
        public T gradientNormalizationThreshold(double threshold){
            this.gradientNormalizationThreshold = threshold;
            return (T) this;
        }

        public abstract <E extends Layer> E build();
    }
}

Next is the abstract class FeedForwardLayer, which mainly adds the number of inputs and the number of outputs of a layer:

package org.deeplearning4j.nn.conf.layers;

import lombok.*;

/** Created by jeffreytang on 7/21/15. */
@Data @NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public abstract class FeedForwardLayer extends Layer {
    protected int nIn;
    protected int nOut;

    public FeedForwardLayer(Builder builder) {
        super(builder);
        this.nIn = builder.nIn;
        this.nOut = builder.nOut;
    }

    public abstract static class Builder<T extends Builder<T>> extends Layer.Builder<T> {
        protected int nIn = 0;
        protected int nOut = 0;

        public T nIn(int nIn) {
            this.nIn = nIn;
            return (T) this;
        }

        public T nOut(int nOut) {
            this.nOut = nOut;
            return (T) this;
        }
    }
}
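One thing worth noting (values below are assumed for illustration): nIn and nOut are what chain consecutive layers together, since the nIn of each layer has to equal the nOut of the layer before it. In the LeNet example at the top, ConvolutionLayerSetup appears to fill these sizes in automatically for the convolution/pooling part.

// nIn of a layer must match nOut of the previous layer (illustrative values).
DenseLayer hidden = new DenseLayer.Builder()
        .nIn(784).nOut(500) // 784 flattened MNIST pixels in, 500 units out
        .activation("relu")
        .build();
OutputLayer out = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .nIn(500).nOut(10)  // 500 in (matches the hidden layer), 10 classes out
        .activation("softmax")
        .build();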

The convolution layer ConvolutionLayer additionally defines the convolution kernel, the stride, and the padding:

package org.deeplearning4j.nn.conf.layers;

import lombok.*;

import org.nd4j.linalg.convolution.Convolution;

/** @author Adam Gibson */
@Data @NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class ConvolutionLayer extends FeedForwardLayer {
    protected Convolution.Type convolutionType; // convolution type: "full", "same" or "valid", just as in MATLAB
    protected int[] kernelSize; // size of the (square) filter kernel
    protected int[] stride; // convolution stride
    protected int[] padding; // padding

    private ConvolutionLayer(Builder builder) {
        super(builder);
        this.convolutionType = builder.convolutionType;
        if(builder.kernelSize.length != 2)
            throw new IllegalArgumentException("Kernel size should be rows x columns (a 2d array)");
        this.kernelSize = builder.kernelSize;
        if(builder.stride.length != 2)
            throw new IllegalArgumentException("Invalid stride, must be length 2");
        this.stride = builder.stride;
        this.padding = builder.padding;
    }

    @Override
    public ConvolutionLayer clone() {
        ConvolutionLayer clone = (ConvolutionLayer) super.clone();
        if(clone.kernelSize != null) clone.kernelSize = clone.kernelSize.clone();
        if(clone.stride != null) clone.stride = clone.stride.clone();
        if(clone.padding != null) clone.padding = clone.padding.clone();
        return clone;
    }

    @AllArgsConstructor
    public static class Builder extends FeedForwardLayer.Builder<Builder> {
        private Convolution.Type convolutionType = Convolution.Type.VALID;
        private int[] kernelSize = new int[] {5, 5};
        private int[] stride = new int[] {1,1};
        private int[] padding = new int[] {0, 0};


        public Builder(int[] kernelSize, int[] stride, int[] padding) {
            this.kernelSize = kernelSize;
            this.stride = stride;
            this.padding = padding;
        }

        public Builder(int[] kernelSize, int[] stride) {
            this.kernelSize = kernelSize;
            this.stride = stride;
        }

        public Builder(int... kernelSize) {
            this.kernelSize = kernelSize;
        }

        public Builder() {}

        public Builder convolutionType(Convolution.Type convolutionType) {
            this.convolutionType = convolutionType;
            return this;
        }

        /** Size of the convolution rows/columns.
         * @param kernelSize the height and width of the kernel
         * @return */
        public Builder kernelSize(int... kernelSize){
            this.kernelSize = kernelSize;
            return this;
        }

        public Builder stride(int... stride){
            this.stride = stride;
            return this;
        }

        public Builder padding(int... padding){
            this.padding = padding;
            return this;
        }

        @Override
        @SuppressWarnings("unchecked")
        public ConvolutionLayer build() {
            return new ConvolutionLayer(this);
        }
    }
}
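As a rough usage sketch (illustrative values): with a 5*5 kernel, stride 1, and no padding, the usual formula out = (in - kernel + 2*padding)/stride + 1 gives 24*24 feature maps for a 28*28 MNIST image.

// Illustrative: a 5x5 convolution over a 28x28 single-channel image.
// Spatial output size = (28 - 5 + 2*0)/1 + 1 = 24, so 20 maps of 24x24.
ConvolutionLayer conv = new ConvolutionLayer.Builder(5, 5) // kernelSize via the varargs constructor
        .nIn(1)    // one input channel (grayscale)
        .nOut(20)  // 20 feature maps
        .stride(1, 1)
        .padding(0, 0)
        .activation("relu")
        .build();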

The subsampling (pooling) layer likewise has some structure of its own, such as the pooling method:

package org.deeplearning4j.nn.conf.layers;

import lombok.*;

import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.weights.WeightInit;

/** Subsampling layer, also referred to as pooling in convolutional neural nets.
 *
 * Supports the following pooling types: MAX, AVG, SUM, NONE
 * @author Adam Gibson */

@Data @NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class SubsamplingLayer extends Layer {

    protected PoolingType poolingType; // pooling type: MAX, AVG, SUM or NONE
    protected int[] kernelSize; // same as the filter size from the last conv layer
    protected int[] stride; // default is 2: down-sample by a factor of 2
    protected int[] padding; // padding

    public enum PoolingType {
        MAX, AVG, SUM, NONE
    }

    private SubsamplingLayer(Builder builder) {
        super(builder);
        this.poolingType = builder.poolingType;
        if(builder.kernelSize.length != 2)
            throw new IllegalArgumentException("Kernel size should be rows x columns (a 2d array)");
        this.kernelSize = builder.kernelSize;
        if(builder.stride.length != 2)
            throw new IllegalArgumentException("Invalid stride, must be length 2");
        this.stride = builder.stride;
        this.padding = builder.padding;
    }

    @Override
    public SubsamplingLayer clone() {
        SubsamplingLayer clone = (SubsamplingLayer) super.clone();

        if(clone.kernelSize != null) clone.kernelSize = clone.kernelSize.clone();
        if(clone.stride != null) clone.stride = clone.stride.clone();
        if(clone.padding != null) clone.padding = clone.padding.clone();
        return clone;
    }

    @AllArgsConstructor
    public static class Builder extends Layer.Builder<Builder> {
        private PoolingType poolingType = PoolingType.MAX;
        private int[] kernelSize = new int[] {1, 1}; // Same as filter size from the last conv layer
        private int[] stride = new int[] {2, 2}; // Default is 2. Down-sample by a factor of 2
        private int[] padding = new int[] {0, 0};

        public Builder(PoolingType poolingType, int[] kernelSize, int[] stride) {
            this.poolingType = poolingType;
            this.kernelSize = kernelSize;
            this.stride = stride;
        }

        public Builder(PoolingType poolingType, int[] kernelSize) {
            this.poolingType = poolingType;
            this.kernelSize = kernelSize;
        }

        public Builder(int[] kernelSize, int[] stride, int[] padding) {
            this.kernelSize = kernelSize;
            this.stride = stride;
            this.padding = padding;
        }

        public Builder(int[] kernelSize, int[] stride) {
            this.kernelSize = kernelSize;
            this.stride = stride;
        }

        public Builder(int... kernelSize) {
            this.kernelSize = kernelSize;
        }

        public Builder(PoolingType poolingType) {
            this.poolingType = poolingType;
        }

        public Builder() {}

        @Override
        @SuppressWarnings("unchecked")
        public SubsamplingLayer build() {
            return new SubsamplingLayer(this);
        }

        public Builder poolingType(PoolingType poolingType){
            this.poolingType = poolingType;
            return this;
        }

        public Builder kernelSize(int... kernelSize){
            this.kernelSize = kernelSize;
            return this;
        }

        public Builder stride(int... stride){
            this.stride = stride;
            return this;
        }

        public Builder padding(int... padding){
            this.padding = padding;
            return this;
        }
    }

}
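Continuing the same sketch (again with assumed values): a 2*2 max-pooling with stride 2 halves each spatial dimension, so the 24*24 feature maps from the convolution above become 12*12.

// Illustrative: 2x2 max pooling with stride 2.
// Spatial output size = (24 - 2)/2 + 1 = 12.
SubsamplingLayer pool = new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
        .kernelSize(2, 2)
        .stride(2, 2)
        .build();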

The fully connected (dense) layer is the simplest of all: it just maps its nIn inputs to nOut outputs, with the activation function and everything else defined by its parent classes:

package org.deeplearning4j.nn.conf.layers;

import lombok.*;

/** Dense layer: fully connected feed forward layer trainable by backprop. */
@Data @NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class DenseLayer extends FeedForwardLayer {

    private DenseLayer(Builder builder) {
        super(builder);
    }

    @AllArgsConstructor
    public static class Builder extends FeedForwardLayer.Builder<Builder> {

        @Override
        @SuppressWarnings("unchecked")
        public DenseLayer build() {
            return new DenseLayer(this);
        }
    }

}
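Since DenseLayer.Builder adds nothing of its own except build(), every knob comes from Layer.Builder or FeedForwardLayer.Builder. A hedged sketch (NormalDistribution is assumed to be the implementation in org.deeplearning4j.nn.conf.distribution; the values are arbitrary):

// Illustrative: a dense layer configured entirely through inherited setters.
DenseLayer dense = new DenseLayer.Builder()
        .nIn(800).nOut(500)                    // from FeedForwardLayer.Builder
        .activation("relu")                    // from Layer.Builder
        .weightInit(WeightInit.DISTRIBUTION)   // from Layer.Builder
        .dist(new NormalDistribution(0, 0.01)) // assumed: sample initial weights from N(0, 0.01)
        .build();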

The output layer is more special than any of the layers above, because it also needs a loss (error) function:

package org.deeplearning4j.nn.conf.layers;

import lombok.*;

import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

/** Output layer with different objective co-occurrences for different objectives.
 * This includes classification as well as prediction. */
@Data @NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class OutputLayer extends BaseOutputLayer {

    protected OutputLayer(Builder builder) {
        super(builder);
    }

    @NoArgsConstructor
    public static class Builder extends BaseOutputLayer.Builder<Builder> {

        public Builder(LossFunction lossFunction) {
            this.lossFunction = lossFunction;
        }

        @Override
        @SuppressWarnings("unchecked")
        public OutputLayer build() {
            return new OutputLayer(this);
        }
    }
}
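In practice the loss function goes into the builder's constructor, while everything else comes from the shared builders; for instance (illustrative values):

// Illustrative: a 10-class softmax output layer with multiclass cross entropy.
OutputLayer output = new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
        .nIn(500).nOut(10)
        .activation("softmax")
        .build();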

deeplearning4j comes with quite a few loss functions; specifically:

package org.nd4j.linalg.lossfunctions;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Arrays;

import static org.nd4j.linalg.ops.transforms.Transforms.*;


/** Central class for loss functions.
 * @author Adam Gibson */
public class LossFunctions {

    /** Generic scoring function.
     * Note that an IllegalArgumentException is thrown if the given
     * loss function is custom. An alternative mechanism for scoring
     * (preferably with a function name and the op factory) should be used instead.
     *
     * @param labels the labels to score
     * @param lossFunction the loss function to use
     * @param z the output function
     * @param l2 the l2 regularization term (0.5 * l2Coeff * sum w^2)
     * @param l1 the l1 regularization term (l1Coeff * sum |w|)
     * @param useRegularization whether to use regularization
     * @return the score for the given parameters */
    public static double score(INDArray labels, LossFunction lossFunction, INDArray z, double l2, double l1,boolean useRegularization) {
        return LossCalculation.builder()
                .l1(l1).lossFunction(lossFunction)
                .l2(l2).labels(labels)
                .z(z)
                .useRegularization(useRegularization)
                .build().score();
    }

    /** Generic scoring function.
     * Note that an IllegalArgumentException is thrown if the given
     * loss function is custom. An alternative mechanism for scoring
     * (preferably with a function name and the op factory) should be used instead.
     *
     * @param labels the labels to score
     * @param lossFunction the loss function to use
     * @param z the output function
     * @param l2 the l2 coefficient
     * @param useRegularization whether to use regularization
     * @return the score for the given parameters */
    public static double score(INDArray labels, LossFunction lossFunction, INDArray z, double l2, boolean useRegularization) {
        double ret = 0.0;
        double reg = 0.5 * l2;
        if (!Arrays.equals(labels.shape(), z.shape()))
            throw new IllegalArgumentException("Output and labels must be same length");
        boolean oldEnforce = Nd4j.ENFORCE_NUMERICAL_STABILITY;
        Nd4j.ENFORCE_NUMERICAL_STABILITY = true;
        switch (lossFunction) {
            case CUSTOM: throw new IllegalStateException("Unable to score custom operation. Please define an alternative mechanism");
            case RECONSTRUCTION_CROSSENTROPY:
                INDArray xEntLogZ2 = log(z);
                INDArray xEntOneMinusLabelsOut2 = labels.rsub(1);
                INDArray xEntOneMinusLogOneMinusZ2 = log(z).rsubi(1);
                ret = labels.mul(xEntLogZ2).add(xEntOneMinusLabelsOut2).muli(xEntOneMinusLogOneMinusZ2).sum(1).meanNumber().doubleValue();
                break;
            case MCXENT:
                INDArray sums = log(z);
                INDArray columnSums = labels.mul(sums);
                ret = -columnSums.sumNumber().doubleValue();
                break;
            case XENT:
                INDArray xEntLogZ = log(z);
                INDArray xEntOneMinusLabelsOut = labels.rsub(1);
                INDArray xEntOneMinusLogOneMinusZ = log(z).rsubi(1);
                ret = labels.mul(xEntLogZ).add(xEntOneMinusLabelsOut).muli(xEntOneMinusLogOneMinusZ).sum(1).sumNumber().doubleValue();
                break;
            case RMSE_XENT:
                INDArray rmseXentDiff = labels.sub(z);
                INDArray squaredrmseXentDiff = pow(rmseXentDiff, 2.0);
                INDArray sqrt = sqrt(squaredrmseXentDiff);
                ret = sqrt.sum(1).sumNumber().doubleValue();
                break;
            case MSE:
                INDArray mseDelta = labels.sub(z);
                ret = 0.5 * pow(mseDelta, 2).sum(1).sumNumber().doubleValue();
                break;
            case EXPLL:
                INDArray expLLLogZ = log(z);
                ret = z.sub(labels.mul(expLLLogZ)).sum(1).sumNumber().doubleValue();
                break;
            case SQUARED_LOSS:
                ret = pow(labels.sub(z), 2).sum(1).sumNumber().doubleValue();
                break;
            case NEGATIVELOGLIKELIHOOD:
                INDArray sums2 = log(z);
                INDArray columnSums2 = labels.mul(sums2);
                ret = -columnSums2.sumNumber().doubleValue();
                break;


        }

        if (useRegularization)
            ret += reg;


        ret /= (double) labels.size(0);
        Nd4j.ENFORCE_NUMERICAL_STABILITY = oldEnforce;
        return ret;

    }


    /** MSE: Mean Squared Error: Linear Regression
     * EXPLL: Exponential log likelihood: Poisson Regression
     * XENT: Cross Entropy: Binary Classification
     * MCXENT: Multiclass Cross Entropy
     * RMSE_XENT: RMSE Cross Entropy
     * SQUARED_LOSS: Squared Loss
     * RECONSTRUCTION_CROSSENTROPY: Reconstruction Cross Entropy
     * NEGATIVELOGLIKELIHOOD: Negative Log Likelihood
     * CUSTOM: Define your own loss function */
    public  enum LossFunction {
        MSE,
        EXPLL,
        XENT,
        MCXENT,
        RMSE_XENT,
        SQUARED_LOSS,
        RECONSTRUCTION_CROSSENTROPY,
        NEGATIVELOGLIKELIHOOD,
        CUSTOM
    }


}
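As a quick sanity check of the MSE branch above (a hedged sketch that only uses LossFunctions.score and Nd4j.create as quoted in this post): for labels [1, 0, 0] and output [0.8, 0.1, 0.1], the score is 0.5 * (0.2^2 + 0.1^2 + 0.1^2) / 1 = 0.03.

// Illustrative check of the MSE case: expected output 0.03.
INDArray labels = Nd4j.create(new double[]{1.0, 0.0, 0.0});
INDArray z      = Nd4j.create(new double[]{0.8, 0.1, 0.1});
double mse = LossFunctions.score(labels, LossFunctions.LossFunction.MSE, z, 0.0, false);
System.out.println(mse); // 0.03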
