This post takes a fairly shallow look at how deeplearning4j builds a multi-layer neural network. My understanding is limited, so I hope readers will point out anything I get wrong.
Below is the handwritten-digit (MNIST) example from the official site. I have already added comments to the code, so I will simply list it here. Training took about four hours on my machine, which I have to admit is far too slow.
package org.deeplearning4j.examples.convolution;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** Created by agibsonccc on 9/16/15. */
public class LenetMnistExample {
private static final Logger log = LoggerFactory.getLogger(LenetMnistExample.class);
public static void main(String[] args) throws Exception {
int nChannels = 1; // number of input channels
int outputNum = 10; // number of output classes
int batchSize = 1000; // minibatch size
int nEpochs = 10; // number of training epochs
int iterations = 1; // optimization iterations per minibatch
int seed = 123; // random seed, for reproducibility
log.info("Load data....");
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize,true,12345);
/* A DataSetIterator has to be created before training; for your own data you would typically implement one yourself. */
/*
System.out.println("Total examples in the iterator : " + mnistTrain.totalExamples()); // 60,000 examples
System.out.println("Input columns for the dataset " + mnistTrain.inputColumns()); // 28*28 pixels per image
System.out.println("The number of labels for the dataset : " + mnistTrain.totalOutcomes()); // 10 classes in total
System.out.println("Batch size: "+mnistTrain.batch()); // number of examples fed in per training batch
*/
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize,false,12345);
log.info("Build model....");
// Configure the network
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
.seed(seed)
.iterations(iterations)
.regularization(true).l2(0.0005) // use L2 regularization
.learningRate(0.01) // learning rate
.weightInit(WeightInit.XAVIER) // weight initialization scheme
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) // stochastic gradient descent
.updater(Updater.NESTEROVS).momentum(0.9) // weight update rule: Nesterov momentum
.list(4) // number of layers, not counting the input layer
.layer(0, new ConvolutionLayer.Builder(5, 5) // convolution layer with a 5x5 kernel (you can pick the size yourself)
.nIn(nChannels) // number of input channels
.stride(1, 1) // convolution stride
.nOut(20).dropOut(0.5) // 20 feature maps; drop-out rate 0.5 during training
.activation("relu") // activation function
.build())
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) // subsampling (pooling) layer
.kernelSize(2,2) // pooling kernel size 2x2
.stride(2,2)
.build())
.layer(2, new DenseLayer.Builder().activation("relu") // fully connected (dense) layer
.nOut(500).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum)
.activation("softmax") //激活函数
.build())
.backprop(true) // use backpropagation
.pretrain(false); // no layer-wise pretraining
new ConvolutionLayerSetup(builder,28,28,1); // declare the input as a 28x28, single-channel image so layer input/output sizes can be inferred
MultiLayerConfiguration conf = builder.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init(); // initialize the model
log.info("Train model....");
model.setListeners(new ScoreIterationListener(1));
for( int i=0; i<nEpochs; i++ ) {
model.fit(mnistTrain);
log.info("*** Completed epoch {} ***", i);
log.info("Evaluate model....");
Evaluation eval = new Evaluation(outputNum);
while(mnistTest.hasNext()){
DataSet ds = mnistTest.next();
INDArray output = model.output(ds.getFeatureMatrix());
eval.eval(ds.getLabels(), output);
}
log.info(eval.stats());
mnistTest.reset();
// the trained network ...
}
log.info("****************Example finished********************");
}
}
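Before moving on to the layer classes, it helps to check by hand what ConvolutionLayerSetup(builder, 28, 28, 1) has to infer for this network. The following back-of-the-envelope sketch is mine, assuming the 5x5 kernel used above:

// Spatial sizes flowing through the LeNet-style network above (28x28 grayscale input).
int in = 28, kernel = 5, stride = 1, pad = 0;
int convOut = (in - kernel + 2 * pad) / stride + 1; // "valid" convolution: 24 -> 24x24 feature maps
int poolOut = convOut / 2;                          // 2x2 max pooling with stride 2: 12 -> 12x12
int denseIn = poolOut * poolOut * 20;               // 20 channels flattened: 2880 inputs to the dense layer
System.out.println(convOut + ", " + poolOut + ", " + denseIn); // prints 24, 12, 2880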
If you are not familiar with how neural networks are structured, the code above reads like gibberish; if you are, you can mostly work out what each call does just from its name. Let's start with the abstract class Layer. Reading it, the builder-heavy style feels a lot like Scala code. A few places I did not know the purpose of, so I left them uncommented:
package org.deeplearning4j.nn.conf.layers;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.JsonTypeInfo.As;
import com.fasterxml.jackson.annotation.JsonTypeInfo.Id;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.weights.WeightInit;
/** A neural network layer. */
@JsonTypeInfo(use=Id.NAME, include=As.WRAPPER_OBJECT)
// the layer types that can appear in a network configuration
@JsonSubTypes(value={
@JsonSubTypes.Type(value = AutoEncoder.class, name = "autoEncoder"),
@JsonSubTypes.Type(value = ConvolutionLayer.class, name = "convolution"),
@JsonSubTypes.Type(value = ImageLSTM.class, name = "imageLSTM"),
@JsonSubTypes.Type(value = GravesLSTM.class, name = "gravesLSTM"),
@JsonSubTypes.Type(value = GRU.class, name = "gru"),
@JsonSubTypes.Type(value = OutputLayer.class, name = "output"),
@JsonSubTypes.Type(value = RnnOutputLayer.class, name = "rnnoutput"),
@JsonSubTypes.Type(value = RBM.class, name = "RBM"),
@JsonSubTypes.Type(value = DenseLayer.class, name = "dense"),
@JsonSubTypes.Type(value = RecursiveAutoEncoder.class, name = "recursiveAutoEncoder"),
@JsonSubTypes.Type(value = SubsamplingLayer.class, name = "subsampling"),
@JsonSubTypes.Type(value = LocalResponseNormalization.class, name = "localResponseNormalization"),
})
@Data
@NoArgsConstructor
public abstract class Layer implements Serializable, Cloneable {
protected String layerName; // layer name
protected String activationFunction; // activation function
protected WeightInit weightInit; // weight initialization scheme
protected double biasInit; // bias initialization value
protected Distribution dist; // distribution to sample initial weights from
protected double learningRate; // learning rate
//learning rate after n iterations
protected Map<Integer,Double> learningRateAfter; // learning-rate schedule, keyed by iteration
protected double lrScoreBasedDecay;
protected double momentum; // momentum, used when the weights are updated
//momentum after n iterations
protected Map<Integer,Double> momentumAfter;
protected double l1; // L1 regularization coefficient
protected double l2; // L2 regularization coefficient
protected double dropOut; // drop-out rate
protected Updater updater; // how the weights are updated
//adadelta - weight for how much to consider previous history
protected double rho; // AdaDelta coefficient
protected double rmsDecay;
protected double adamMeanDecay = 0.9;
protected double adamVarDecay = 0.999;
// gradient normalization
protected GradientNormalization gradientNormalization = GradientNormalization.None; //Clipping, rescale based on l2 norm, etc
protected double gradientNormalizationThreshold = 1.0; //Threshold for l2 and element-wise gradient clipping
public Layer(Builder builder) {
this.layerName = builder.layerName;
this.activationFunction = builder.activationFunction;
this.weightInit = builder.weightInit;
this.biasInit = builder.biasInit;
this.dist = builder.dist;
this.learningRate = builder.learningRate;
this.learningRateAfter = builder.learningRateAfter;
this.lrScoreBasedDecay = builder.lrScoreBasedDecay;
this.momentum = builder.momentum;
this.momentumAfter = builder.momentumAfter;
this.l1 = builder.l1;
this.l2 = builder.l2;
this.dropOut = builder.dropOut;
this.updater = builder.updater;
this.rho = builder.rho;
this.rmsDecay = builder.rmsDecay;
this.adamMeanDecay = builder.adamMeanDecay;
this.adamVarDecay = builder.adamVarDecay;
this.gradientNormalization = builder.gradientNormalization;
this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold;
}
// as the name says, this makes a copy of the layer configuration
@Override
public Layer clone() {
try {
Layer clone = (Layer) super.clone();
if(clone.dist != null) clone.dist = clone.dist.clone();
if(clone.learningRateAfter != null) clone.learningRateAfter = new HashMap<>(clone.learningRateAfter);
if(clone.momentumAfter != null) clone.momentumAfter = new HashMap<>(clone.momentumAfter);
return clone;
} catch (CloneNotSupportedException e) {
throw new RuntimeException(e);
}
}
// abstract builder used to construct a layer
public abstract static class Builder<T extends Builder<T>> {
protected String layerName = "genisys"; // default layer name
protected String activationFunction = null; // activation function
protected WeightInit weightInit = null; // weight initialization scheme
protected double biasInit = Double.NaN; // bias initialization value
protected Distribution dist = null; // weight distribution
protected double learningRate = Double.NaN; // learning rate
protected Map<Integer,Double> learningRateAfter = null;
protected double lrScoreBasedDecay = Double.NaN;
protected double momentum = Double.NaN; // momentum
protected Map<Integer,Double> momentumAfter = null;
protected double l1 = Double.NaN;
protected double l2 = Double.NaN;
protected double dropOut = Double.NaN;
protected Updater updater = null;
protected double rho = Double.NaN;
protected double rmsDecay = Double.NaN;
protected double adamMeanDecay = Double.NaN;
protected double adamVarDecay = Double.NaN;
protected GradientNormalization gradientNormalization = null;
protected double gradientNormalizationThreshold = Double.NaN;
/** Layer name assigns layer string name. Allows easier differentiation between layers. */
public T name(String layerName) {
this.layerName = layerName;
return (T) this;
}
/** Layer activation function. Typical values include:
 * "relu" (rectified linear), "tanh", "sigmoid", "softmax", "hardtanh", "leakyrelu", "maxout", "softsign", "softplus" */
public T activation(String activationFunction) {
this.activationFunction = activationFunction;
return (T) this;
}
/** Weight initialization scheme.
 * @see org.deeplearning4j.nn.weights.WeightInit */
public T weightInit(WeightInit weightInit) {
this.weightInit = weightInit;
return (T) this;
}
public T biasInit(double biasInit) {
this.biasInit = biasInit;
return (T) this;
}
/** Distribution to sample initial weights from. Used in conjunction with .weightInit(WeightInit.DISTRIBUTION). */
public T dist(Distribution dist){
this.dist = dist;
return (T) this;
}
/** Learning rate. Defaults to 1e-1*/
public T learningRate(double learningRate){
this.learningRate = learningRate;
return (T)this;
}
/** Learning rate schedule. Map of the iteration to the learning rate to apply at that iteration. */
public T learningRateAfter(Map<Integer, Double> learningRateAfter) {
this.learningRateAfter = learningRateAfter;
return (T) this;
}
/** Rate to decrease learningRate by when the score stops improving.
 * Learning rate is multiplied by this rate so ideally keep between 0 and 1. */
public T learningRateScoreBasedDecayRate(double lrScoreBasedDecay) {
this.lrScoreBasedDecay = lrScoreBasedDecay;
return (T) this;
}
/** L1 regularization coefficient.*/
public T l1(double l1){
this.l1 = l1;
return (T)this;
}
/** L2 regularization coefficient. */
public T l2(double l2){
this.l2 = l2;
return (T)this;
}
public T dropOut(double dropOut) {
this.dropOut = dropOut;
return (T) this;
}
/** Momentum rate. */
public T momentum(double momentum) {
this.momentum = momentum;
return (T)this;
}
/** Momentum schedule. Map of the iteration to the momentum rate to apply at that iteration. */
public T momentumAfter(Map<Integer, Double> momentumAfter) {
this.momentumAfter = momentumAfter;
return (T) this;
}
/** Gradient updater. For example, SGD for standard stochastic gradient descent, NESTEROVS for Nesterov momentum, RMSPROP for RMSProp, etc.
 * @see org.deeplearning4j.nn.conf.Updater */
public T updater(Updater updater){
this.updater = updater;
return (T) this;
}
/** AdaDelta coefficient.
 * @param rho */
public T rho(double rho) {
this.rho = rho;
return (T) this;
}
/** Decay rate for RMSProp. Only applies if using .updater(Updater.RMSPROP) */
public T rmsDecay(double rmsDecay) {
this.rmsDecay = rmsDecay;
return (T) this;
}
/** Mean decay rate for Adam updater. Only applies if using .updater(Updater.ADAM) */
public T adamMeanDecay(double adamMeanDecay) {
this.adamMeanDecay = adamMeanDecay;
return (T) this;
}
/** Variance decay rate for Adam updater. Only applies if using .updater(Updater.ADAM) */
public T adamVarDecay(double adamVarDecay) {
this.adamVarDecay = adamVarDecay;
return (T) this;
}
/** Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc.
 * @param gradientNormalization Type of normalization to use. Defaults to None.
 * @see org.deeplearning4j.nn.conf.GradientNormalization */
public T gradientNormalization(GradientNormalization gradientNormalization ){
this.gradientNormalization = gradientNormalization;
return (T) this;
}
/** Threshold for gradient normalization. Only used for GradientNormalization.ClipL2PerLayer,
 * GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue; not used otherwise.
 * L2 threshold for the first two types of clipping, or absolute value threshold for the last type. */
public T gradientNormalizationThreshold(double threshold){
this.gradientNormalizationThreshold = threshold;
return (T) this;
}
public abstract <E extends Layer> E build();
}
}
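One thing worth calling out is the Builder&lt;T extends Builder&lt;T&gt;&gt; signature and the "return (T) this;" in every setter. This is the self-typed ("curiously recurring") builder trick: methods declared in the base builder still return the concrete subclass type, so calls can be chained in any order. A minimal standalone sketch of the same pattern (the class names here are made up for illustration):

abstract class BaseBuilder<T extends BaseBuilder<T>> {
    String name;
    @SuppressWarnings("unchecked")
    public T name(String name) { this.name = name; return (T) this; }
}
class ConcreteBuilder extends BaseBuilder<ConcreteBuilder> {
    int size;
    public ConcreteBuilder size(int size) { this.size = size; return this; }
}
// new ConcreteBuilder().name("demo").size(3) compiles because name(...) returns ConcreteBuilder, not BaseBuilder.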
Next comes the abstract class FeedForwardLayer, which mainly adds the number of inputs and the number of outputs of a layer.
package org.deeplearning4j.nn.conf.layers;
import lombok.*;
/** Created by jeffreytang on 7/21/15. */
@Data @NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public abstract class FeedForwardLayer extends Layer {
protected int nIn;
protected int nOut;
public FeedForwardLayer(Builder builder) {
super(builder);
this.nIn = builder.nIn;
this.nOut = builder.nOut;
}
public abstract static class Builder<T extends Builder<T>> extends Layer.Builder<T> {
protected int nIn = 0;
protected int nOut = 0;
public T nIn(int nIn) {
this.nIn = nIn;
return (T) this;
}
public T nOut(int nOut) {
this.nOut = nOut;
return (T) this;
}
}
}
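Putting Layer.Builder and FeedForwardLayer.Builder together, a layer configuration is built fluently. A usage sketch of mine (the values are arbitrary, and DenseLayer is only shown further below; it is used here just as a concrete subclass):

DenseLayer hidden = new DenseLayer.Builder()
        .name("hidden1")               // from Layer.Builder
        .activation("relu")            // from Layer.Builder
        .weightInit(WeightInit.XAVIER) // from Layer.Builder
        .dropOut(0.5)                  // from Layer.Builder
        .nIn(784).nOut(500)            // from FeedForwardLayer.Builder
        .build();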
The convolution layer ConvolutionLayer additionally defines the convolution kernel, the convolution stride, and the padding:
package org.deeplearning4j.nn.conf.layers;
import lombok.*;
import org.nd4j.linalg.convolution.Convolution;
/** @author Adam Gibson */
@Data @NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class ConvolutionLayer extends FeedForwardLayer {
protected Convolution.Type convolutionType; // convolution type: "full", "same" or "valid", just like in MATLAB
protected int[] kernelSize; // size of the (square) convolution kernel
protected int[] stride; // convolution stride (the builder defaults to {1,1})
protected int[] padding; // zero padding
private ConvolutionLayer(Builder builder) {
super(builder);
this.convolutionType = builder.convolutionType;
if(builder.kernelSize.length != 2)
throw new IllegalArgumentException("Kernel size of should be rows x columns (a 2d array)");
this.kernelSize = builder.kernelSize;
if(builder.stride.length != 2)
throw new IllegalArgumentException("Invalid stride, must be length 2");
this.stride = builder.stride;
this.padding = builder.padding;
}
@Override
public ConvolutionLayer clone() {
ConvolutionLayer clone = (ConvolutionLayer) super.clone();
if(clone.kernelSize != null) clone.kernelSize = clone.kernelSize.clone();
if(clone.stride != null) clone.stride = clone.stride.clone();
if(clone.padding != null) clone.padding = clone.padding.clone();
return clone;
}
@AllArgsConstructor
public static class Builder extends FeedForwardLayer.Builder<Builder> {
private Convolution.Type convolutionType = Convolution.Type.VALID;
private int[] kernelSize = new int[] {5, 5};
private int[] stride = new int[] {1,1};
private int[] padding = new int[] {0, 0};
public Builder(int[] kernelSize, int[] stride, int[] padding) {
this.kernelSize = kernelSize;
this.stride = stride;
this.padding = padding;
}
public Builder(int[] kernelSize, int[] stride) {
this.kernelSize = kernelSize;
this.stride = stride;
}
public Builder(int... kernelSize) {
this.kernelSize = kernelSize;
}
public Builder() {}
public Builder convolutionType(Convolution.Type convolutionType) {
this.convolutionType = convolutionType;
return this;
}
/** Size of the convolution (rows/columns).
 * @param kernelSize the height and width of the kernel
 * @return */
public Builder kernelSize(int... kernelSize){
this.kernelSize = kernelSize;
return this;
}
public Builder stride(int... stride){
this.stride = stride;
return this;
}
public Builder padding(int... padding){
this.padding = padding;
return this;
}
@Override
@SuppressWarnings("unchecked")
public ConvolutionLayer build() {
return new ConvolutionLayer(this);
}
}
}
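As a usage sketch (mine, not from the article), this is how the convolution layer of the MNIST example could be written with the padding spelled out. For a "valid" convolution the output side length is (in - kernel + 2*padding) / stride + 1, so a 5x5 kernel with stride 1 turns the 28x28 input into 24x24 feature maps:

ConvolutionLayer conv = new ConvolutionLayer.Builder(5, 5) // kernelSize via the varargs constructor
        .nIn(1)            // one input channel (grayscale)
        .nOut(20)          // 20 feature maps
        .stride(1, 1)
        .padding(0, 0)
        .activation("relu")
        .build();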
The subsampling layer likewise has structure of its own, such as the down-sampling (pooling) method:
package org.deeplearning4j.nn.conf.layers;
import lombok.*;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.weights.WeightInit;
/** Subsampling layer, also referred to as pooling in convolutional neural nets.
 * Supports the following pooling types: MAX, AVG, SUM, NONE.
 * @author Adam Gibson */
@Data @NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class SubsamplingLayer extends Layer {
protected PoolingType poolingType; // pooling type: MAX, AVG, SUM or NONE
protected int[] kernelSize; // Same as filter size from the last conv layer. Pooling kernel size
protected int[] stride; // Default is 2. Down-sample by a factor of 2. Pooling stride
protected int[] padding; // zero padding
public enum PoolingType {
MAX, AVG, SUM, NONE
}
private SubsamplingLayer(Builder builder) {
super(builder);
this.poolingType = builder.poolingType;
if(builder.kernelSize.length != 2)
throw new IllegalArgumentException("Kernel size of should be rows x columns (a 2d array)");
this.kernelSize = builder.kernelSize;
if(builder.stride.length != 2)
throw new IllegalArgumentException("Invalid stride, must be length 2");
this.stride = builder.stride;
this.padding = builder.padding;
}
@Override
public SubsamplingLayer clone() {
SubsamplingLayer clone = (SubsamplingLayer) super.clone();
if(clone.kernelSize != null) clone.kernelSize = clone.kernelSize.clone();
if(clone.stride != null) clone.stride = clone.stride.clone();
if(clone.padding != null) clone.padding = clone.padding.clone();
return clone;
}
@AllArgsConstructor
public static class Builder extends Layer.Builder<Builder> {
private PoolingType poolingType = PoolingType.MAX;
private int[] kernelSize = new int[] {1, 1}; // Same as filter size from the last conv layer
private int[] stride = new int[] {2, 2}; // Default is 2. Down-sample by a factor of 2
private int[] padding = new int[] {0, 0};
public Builder(PoolingType poolingType, int[] kernelSize, int[] stride) {
this.poolingType = poolingType;
this.kernelSize = kernelSize;
this.stride = stride;
}
public Builder(PoolingType poolingType, int[] kernelSize) {
this.poolingType = poolingType;
this.kernelSize = kernelSize;
}
public Builder(int[] kernelSize, int[] stride, int[] padding) {
this.kernelSize = kernelSize;
this.stride = stride;
this.padding = padding;
}
public Builder(int[] kernelSize, int[] stride) {
this.kernelSize = kernelSize;
this.stride = stride;
}
public Builder(int... kernelSize) {
this.kernelSize = kernelSize;
}
public Builder(PoolingType poolingType) {
this.poolingType = poolingType;
}
public Builder() {}
@Override
@SuppressWarnings("unchecked")
public SubsamplingLayer build() {
return new SubsamplingLayer(this);
}
public Builder poolingType(PoolingType poolingType){
this.poolingType = poolingType;
return this;
}
public Builder kernelSize(int... kernelSize){
this.kernelSize = kernelSize;
return this;
}
public Builder stride(int... stride){
this.stride = stride;
return this;
}
public Builder padding(int... padding){
this.padding = padding;
return this;
}
}
}
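A usage sketch for the pooling layer of the example: a 2x2 max pooling with stride 2, which halves each spatial dimension (24x24 feature maps become 12x12). Note that SubsamplingLayer extends Layer directly rather than FeedForwardLayer, so there is no nIn/nOut to set; pooling has no trainable weights:

SubsamplingLayer pool = new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
        .kernelSize(2, 2) // pooling window
        .stride(2, 2)     // non-overlapping windows
        .build();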
A dense (fully connected) layer really just has an input side and an output side, with no hidden layer of its own; the activation function and the other hyperparameters are defined by its parent classes:
package org.deeplearning4j.nn.conf.layers;
import lombok.*;
/**Dense layer: fully connected feed forward layer trainable by backprop. */
@Data @NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class DenseLayer extends FeedForwardLayer {
private DenseLayer(Builder builder) {
super(builder);
}
@AllArgsConstructor
public static class Builder extends FeedForwardLayer.Builder<Builder> {
@Override
@SuppressWarnings("unchecked")
public DenseLayer build() {
return new DenseLayer(this);
}
}
}
The output layer is more special than the layers above: a loss function has to be specified:
package org.deeplearning4j.nn.conf.layers;
import lombok.*;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
/** Output layer with different objective co-occurrences for different objectives. This includes classification as well as prediction. */
@Data @NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class OutputLayer extends BaseOutputLayer {
protected OutputLayer(Builder builder) {
super(builder);
}
@NoArgsConstructor
public static class Builder extends BaseOutputLayer.Builder<Builder> {
public Builder(LossFunction lossFunction) {
this.lossFunction = lossFunction;
}
@Override
@SuppressWarnings("unchecked")
public OutputLayer build() {
return new OutputLayer(this);
}
}
}
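A usage sketch of mine for a 10-way classifier head: the loss function goes into the builder constructor, while nIn/nOut, the activation and so on come from the builders shown earlier (in the MNIST example nIn is left out because ConvolutionLayerSetup fills it in):

OutputLayer out = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .nIn(500)              // outputs of the preceding dense layer
        .nOut(10)              // one neuron per class
        .activation("softmax")
        .build();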
deeplearning4j comes with quite a few loss functions; they are listed below:
package org.nd4j.linalg.lossfunctions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;
import static org.nd4j.linalg.ops.transforms.Transforms.*;
/** Central class for loss functions.
 * @author Adam Gibson */
public class LossFunctions {
/** Generic scoring function.
 * Note that an IllegalArgumentException is thrown if the given loss function is custom.
 * An alternative mechanism for scoring (preferably with a function name and the op factory) should be used instead.
 * @param labels the labels to score
 * @param lossFunction the loss function to use
 * @param z the output function
 * @param l2 the l2 regularization term (0.5 * l2Coeff * sum w^2)
 * @param l1 the l1 regularization term (l1Coeff * sum |w|)
 * @param useRegularization whether to use regularization
 * @return the score for the given parameters */
public static double score(INDArray labels, LossFunction lossFunction, INDArray z, double l2, double l1,boolean useRegularization) {
return LossCalculation.builder()
.l1(l1).lossFunction(lossFunction)
.l2(l2).labels(labels)
.z(z)
.useRegularization(useRegularization)
.build().score();
}
/** Generic scoring function.
 * Note that an IllegalArgumentException is thrown if the given loss function is custom.
 * An alternative mechanism for scoring (preferably with a function name and the op factory) should be used instead.
 * @param labels the labels to score
 * @param lossFunction the loss function to use
 * @param z the output function
 * @param l2 the l2 coefficient
 * @param useRegularization whether to use regularization
 * @return the score for the given parameters */
public static double score(INDArray labels, LossFunction lossFunction, INDArray z, double l2, boolean useRegularization) {
double ret = 0.0;
double reg = 0.5 * l2;
if (!Arrays.equals(labels.shape(), z.shape()))
throw new IllegalArgumentException("Output and labels must be same length");
boolean oldEnforce = Nd4j.ENFORCE_NUMERICAL_STABILITY;
Nd4j.ENFORCE_NUMERICAL_STABILITY = true;
switch (lossFunction) {
case CUSTOM: throw new IllegalStateException("Unable to score custom operation. Please define an alternative mechanism");
case RECONSTRUCTION_CROSSENTROPY:
INDArray xEntLogZ2 = log(z);
INDArray xEntOneMinusLabelsOut2 = labels.rsub(1);
INDArray xEntOneMinusLogOneMinusZ2 = log(z).rsubi(1);
ret = labels.mul(xEntLogZ2).add(xEntOneMinusLabelsOut2).muli(xEntOneMinusLogOneMinusZ2).sum(1).meanNumber().doubleValue();
break;
case MCXENT:
INDArray sums = log(z);
INDArray columnSums = labels.mul(sums);
ret = -columnSums.sumNumber().doubleValue();
break;
case XENT:
INDArray xEntLogZ = log(z);
INDArray xEntOneMinusLabelsOut = labels.rsub(1);
INDArray xEntOneMinusLogOneMinusZ = log(z).rsubi(1);
ret = labels.mul(xEntLogZ).add(xEntOneMinusLabelsOut).muli(xEntOneMinusLogOneMinusZ).sum(1).sumNumber().doubleValue();
break;
case RMSE_XENT:
INDArray rmseXentDiff = labels.sub(z);
INDArray squaredrmseXentDiff = pow(rmseXentDiff, 2.0);
INDArray sqrt = sqrt(squaredrmseXentDiff);
ret = sqrt.sum(1).sumNumber().doubleValue();
break;
case MSE:
INDArray mseDelta = labels.sub(z);
ret = 0.5 * pow(mseDelta, 2).sum(1).sumNumber().doubleValue();
break;
case EXPLL:
INDArray expLLLogZ = log(z);
ret = z.sub(labels.mul(expLLLogZ)).sum(1).sumNumber().doubleValue();
break;
case SQUARED_LOSS:
ret = pow(labels.sub(z), 2).sum(1).sumNumber().doubleValue();
break;
case NEGATIVELOGLIKELIHOOD:
INDArray sums2 = log(z);
INDArray columnSums2 = labels.mul(sums2);
ret = -columnSums2.sumNumber().doubleValue();
break;
}
if (useRegularization)
ret += reg;
ret /= (double) labels.size(0);
Nd4j.ENFORCE_NUMERICAL_STABILITY = oldEnforce;
return ret;
}
/** MSE: Mean Squared Error: Linear Regression
 * EXPLL: Exponential log likelihood: Poisson Regression
 * XENT: Cross Entropy: Binary Classification
 * MCXENT: Multiclass Cross Entropy
 * RMSE_XENT: RMSE Cross Entropy
 * SQUARED_LOSS: Squared Loss
 * RECONSTRUCTION_CROSSENTROPY: Reconstruction Cross Entropy
 * NEGATIVELOGLIKELIHOOD: Negative Log Likelihood
 * CUSTOM: Define your own loss function */
public enum LossFunction {
MSE,
EXPLL,
XENT,
MCXENT,
RMSE_XENT,
SQUARED_LOSS,
RECONSTRUCTION_CROSSENTROPY,
NEGATIVELOGLIKELIHOOD,
CUSTOM
}
}
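To make a couple of the formulas above concrete, here is a tiny hand computation of mine in plain Java (no ND4J) for a single example with one-hot label {0, 1, 0} and softmax output z = {0.2, 0.7, 0.1}; it mirrors the MSE and MCXENT branches of score():

double[] labels = {0, 1, 0};
double[] z      = {0.2, 0.7, 0.1};
double mse = 0, mcxent = 0;
for (int i = 0; i < labels.length; i++) {
    mse    += 0.5 * Math.pow(labels[i] - z[i], 2); // MSE branch: 0.5 * sum (y - z)^2
    mcxent -= labels[i] * Math.log(z[i]);          // MCXENT branch: -sum y * log(z)
}
System.out.printf("MSE = %.4f, MCXENT = %.4f%n", mse, mcxent); // MSE = 0.0700, MCXENT = 0.3567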