卷积神经网络跟普通的机器学习模型不太一样：输入一般是一个高维矩阵，然后依次经过卷积、池化、卷积、池化……，再到全连接层（把矩阵展平成一个向量）和 softmax 输出，最后通过反向传播调整权值。
目前能实现 CNN 的深度学习框架有很多，下面使用 deeplearning4j 包来实现，主要参考了该项目在 GitHub 上提供的示例：
package com.meituan.deeplearning4j;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
/**
 * Trains a LeNet-style convolutional network on MNIST with deeplearning4j and prints the
 * test-set evaluation statistics after every training epoch.
 *
 * <p>Network: conv(5x5, 20) -> max-pool(2x2) -> conv(5x5, 50) -> max-pool(2x2) -> dense(500,
 * ReLU) -> softmax output (10 classes, negative log-likelihood loss).
 */
public class LenetMnistExample {

    /** MNIST images are single-channel (grayscale). */
    private static final int NUM_CHANNELS = 1;
    /** MNIST image height in pixels. */
    private static final int IMAGE_HEIGHT = 28;
    /** MNIST image width in pixels. */
    private static final int IMAGE_WIDTH = 28;
    /** Digits 0-9. */
    private static final int NUM_CLASSES = 10;
    private static final int BATCH_SIZE = 64;
    /** Number of full passes over the training set (the results below come from 4 epochs). */
    private static final int NUM_EPOCHS = 4;
    private static final int ITERATIONS = 1;
    /** Seed for weight initialisation. */
    private static final int SEED = 123;
    /** Seed passed to the MNIST iterators. */
    private static final int DATA_SEED = 12345;

    /**
     * Loads MNIST, builds and initialises the network, then alternates one training epoch with
     * one full evaluation on the test set, printing the stats each time.
     *
     * @param args unused
     * @throws IOException if the MNIST data cannot be loaded by {@code MnistDataSetIterator}
     */
    public static void main(String[] args) throws IOException {
        System.out.println("load data");
        DataSetIterator mnisTrain = new MnistDataSetIterator(BATCH_SIZE, true, DATA_SEED);
        DataSetIterator mnisTest = new MnistDataSetIterator(BATCH_SIZE, false, DATA_SEED);

        System.out.println("build model....");
        MultiLayerNetwork model = new MultiLayerNetwork(buildConfiguration());
        model.init();

        System.out.println("train model is start....");
        // Use the epoch constant instead of a hard-coded loop bound.
        for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
            model.fit(mnisTrain);
            System.out.println(" Completed epoch is :" + epoch);
            System.out.println("Evaluate model....");
            System.out.println(evaluate(model, mnisTest).stats());
        }
        System.out.println("model finish");
    }

    /** Builds the LeNet-style multi-layer configuration used by {@link #main}. */
    private static MultiLayerConfiguration buildConfiguration() {
        return new NeuralNetConfiguration.Builder()
                .seed(SEED)
                .iterations(ITERATIONS)
                .regularization(true)
                .l2(0.0005)
                .learningRate(0.01)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS)
                .momentum(0.9)
                .list()
                // 5x5 convolution, stride 1, 20 feature maps.
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        .nIn(NUM_CHANNELS)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                // 2x2 max pooling, stride 2.
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                // 5x5 convolution, stride 1, 50 feature maps. nIn is inferred from
                // setInputType below, so it does not need to be given for later layers.
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                // 2x2 max pooling, stride 2.
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                // Fully connected layer with 500 ReLU units.
                .layer(4, new DenseLayer.Builder()
                        .activation(Activation.RELU)
                        .nOut(500)
                        .build())
                // Softmax output over the digit classes, trained with negative log-likelihood.
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(NUM_CLASSES)
                        .activation(Activation.SOFTMAX)
                        .build())
                // Declares the input as flattened HxWxC images. DL4J uses this to infer nIn
                // for the later layers and to insert the reshaping pre-processors it needs.
                .setInputType(InputType.convolutionalFlat(IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNELS))
                .backprop(true)
                .pretrain(false)
                .build();
    }

    /**
     * Runs the model over every batch of {@code testData} and accumulates classification
     * statistics.
     *
     * @param model    the trained network
     * @param testData the test iterator; it is rewound first, so each call covers the full set
     * @return the accumulated evaluation
     */
    private static Evaluation evaluate(MultiLayerNetwork model, DataSetIterator testData) {
        // Rewind first so the result does not depend on where a previous pass left the iterator.
        testData.reset();
        Evaluation eval = new Evaluation(NUM_CLASSES);
        while (testData.hasNext()) {
            DataSet ds = testData.next();
            // false = inference mode (not training).
            INDArray output = model.output(ds.getFeatureMatrix(), false);
            eval.eval(ds.getLabels(), output);
        }
        return eval;
    }
}
CNN 这类模型如果不用 GPU 加速，只在 CPU 上跑，训练速度确实很慢。
迭代 4 次（4 个 epoch）后的评估结果如下，最终准确率为 0.9884。下面的混淆矩阵逐项列出：例如标签为 0 且被模型分类为 0 的样本有 974 个，标签为 0 却被分类为 1 的有 1 个，依此类推：
Examples labeled as 0 classified by model as 0: 974 times
Examples labeled as 0 classified by model as 1: 1 times
Examples labeled as 0 classified by model as 6: 1 times
Examples labeled as 0 classified by model as 7: 2 times
Examples labeled as 0 classified by model as 8: 1 times
Examples labeled as 0 classified by model as 9: 1 times
Examples labeled as 1 classified by model as 1: 1124 times
Examples labeled as 1 classified by model as 2: 4 times
Examples labeled as 1 classified by model as 3: 2 times
Examples labeled as 1 classified by model as 5: 1 times
Examples labeled as 1 classified by model as 6: 2 times
Examples labeled as 1 classified by model as 7: 1 times
Examples labeled as 1 classified by model as 8: 1 times
Examples labeled as 2 classified by model as 0: 2 times
Examples labeled as 2 classified by model as 2: 1027 times
Examples labeled as 2 classified by model as 6: 1 times
Examples labeled as 2 classified by model as 7: 2 times
Examples labeled as 3 classified by model as 0: 1 times
Examples labeled as 3 classified by model as 2: 2 times
Examples labeled as 3 classified by model as 3: 999 times
Examples labeled as 3 classified by model as 5: 3 times
Examples labeled as 3 classified by model as 7: 2 times
Examples labeled as 3 classified by model as 8: 3 times
Examples labeled as 4 classified by model as 2: 1 times
Examples labeled as 4 classified by model as 4: 975 times
Examples labeled as 4 classified by model as 6: 2 times
Examples labeled as 4 classified by model as 9: 4 times
Examples labeled as 5 classified by model as 0: 2 times
Examples labeled as 5 classified by model as 3: 5 times
Examples labeled as 5 classified by model as 5: 878 times
Examples labeled as 5 classified by model as 6: 2 times
Examples labeled as 5 classified by model as 7: 1 times
Examples labeled as 5 classified by model as 8: 3 times
Examples labeled as 5 classified by model as 9: 1 times
Examples labeled as 6 classified by model as 0: 4 times
Examples labeled as 6 classified by model as 1: 2 times
Examples labeled as 6 classified by model as 4: 1 times
Examples labeled as 6 classified by model as 5: 5 times
Examples labeled as 6 classified by model as 6: 944 times
Examples labeled as 6 classified by model as 8: 2 times
Examples labeled as 7 classified by model as 1: 4 times
Examples labeled as 7 classified by model as 2: 8 times
Examples labeled as 7 classified by model as 3: 1 times
Examples labeled as 7 classified by model as 7: 1010 times
Examples labeled as 7 classified by model as 8: 1 times
Examples labeled as 7 classified by model as 9: 4 times
Examples labeled as 8 classified by model as 0: 4 times
Examples labeled as 8 classified by model as 2: 3 times
Examples labeled as 8 classified by model as 3: 1 times
Examples labeled as 8 classified by model as 5: 1 times
Examples labeled as 8 classified by model as 7: 2 times
Examples labeled as 8 classified by model as 8: 959 times
Examples labeled as 8 classified by model as 9: 4 times
Examples labeled as 9 classified by model as 1: 2 times
Examples labeled as 9 classified by model as 2: 1 times
Examples labeled as 9 classified by model as 3: 2 times
Examples labeled as 9 classified by model as 4: 1 times
Examples labeled as 9 classified by model as 5: 4 times
Examples labeled as 9 classified by model as 7: 3 times
Examples labeled as 9 classified by model as 8: 2 times
Examples labeled as 9 classified by model as 9: 994 times
==========================Scores========================================
Accuracy: 0.9884
Precision: 0.9884
Recall: 0.9883
F1 Score: 0.9883
========================================================================
model finish