deeplearning4j之卷积神经网络实现

卷积神经网络跟普通的机器学习模型不太一样：输入一般为一个高维矩阵，然后依次经过卷积、池化、卷积、池化……再到全连接层（把矩阵展平为一个向量）、softmax 输出，最后通过反向传播调整权值。

目前能实现 CNN 的深度学习框架很多，下面使用 deeplearning4j 实现，主要参考其 GitHub 项目中提供的示例（LenetMnistExample）：


package com.meituan.deeplearning4j;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * LeNet-style convolutional network trained on MNIST with deeplearning4j.
 *
 * <p>Network: conv(5x5, 20 filters) -> max-pool(2x2) -> conv(5x5, 50 filters) -> max-pool(2x2)
 * -> dense(500, ReLU) -> softmax output over {@code outputNum} classes.
 * After every training epoch the model is evaluated on the MNIST test set and the
 * evaluation statistics are printed to stdout.
 */
public class LenetMnistExample {

	/**
	 * Loads MNIST, builds and initialises the network, then trains it for {@code nEpochs}
	 * epochs, evaluating on the test set after each one.
	 *
	 * @param args unused
	 * @throws IOException if the MNIST data cannot be loaded
	 */
	public static void main(String[] args) throws IOException {

		int nChannels = 1;   // MNIST images have a single (greyscale) channel
		int outputNum = 10;  // one class per digit 0-9
		int batchSize = 64;
		int nEpochs = 4;     // number of full passes over the training set
		int iterations = 1;
		int seed = 123;

		System.out.println("load data");
		// The third argument is an RNG seed for the iterator, fixed for reproducible runs.
		DataSetIterator mnisTrain = new MnistDataSetIterator(batchSize, true, 12345);
		DataSetIterator mnisTest = new MnistDataSetIterator(batchSize, false, 12345);

		System.out.println("Builder model..");
		System.out.println("build model....");
		MultiLayerConfiguration conf = buildConfiguration(seed, iterations, nChannels, outputNum);
		MultiLayerNetwork model = new MultiLayerNetwork(conf);
		model.init();

		System.out.println("train model is start....");
		for (int i = 0; i < nEpochs; i++) {
			model.fit(mnisTrain);
			// Rewind explicitly so the next epoch again sees the whole training set,
			// mirroring how the test iterator is rewound after evaluation.
			mnisTrain.reset();
			System.out.println(" Completed epoch is :" + i);

			System.out.println("Evaluate model....");
			Evaluation eval = evaluateOnTestSet(model, mnisTest, outputNum);
			System.out.println(eval.stats());
		}
		System.out.println("model finish");
	}

	/**
	 * Builds the LeNet configuration.
	 *
	 * @param seed       RNG seed for weight initialisation and training
	 * @param iterations optimisation iterations per minibatch
	 * @param nChannels  number of channels of the input images
	 * @param outputNum  number of output classes
	 * @return the finished multi-layer configuration
	 */
	private static MultiLayerConfiguration buildConfiguration(int seed, int iterations,
			int nChannels, int outputNum) {
		return new NeuralNetConfiguration.Builder()
				.seed(seed)
				.iterations(iterations)
				.regularization(true)
				.l2(0.0005)
				.learningRate(0.01)
				.weightInit(WeightInit.XAVIER)
				.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
				.updater(Updater.NESTEROVS)
				.momentum(0.9)
				.list()
				.layer(0, new ConvolutionLayer.Builder(5, 5)
						.nIn(nChannels)
						.stride(1, 1)
						.nOut(20)
						.activation(Activation.IDENTITY)
						.build())
				.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
						.kernelSize(2, 2)
						.stride(2, 2)
						.build())
				// nIn is omitted from here on: setInputType below lets DL4J infer it
				// from the size of the previous layer.
				.layer(2, new ConvolutionLayer.Builder(5, 5)
						.stride(1, 1)
						.nOut(50)
						.activation(Activation.IDENTITY)
						.build())
				.layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
						.kernelSize(2, 2)
						.stride(2, 2)
						.build())
				.layer(4, new DenseLayer.Builder()
						.activation(Activation.RELU)
						.nOut(500)
						.build())
				.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
						.nOut(outputNum)
						.activation(Activation.SOFTMAX)
						.build())
				// Declares the input as flattened 28x28x1 images. DL4J uses this to add the
				// preprocessors between layer types (e.g. subsampling -> dense), to validate
				// the configuration, and to fill in nIn for layers that do not set it.
				.setInputType(InputType.convolutionalFlat(28, 28, 1))
				.backprop(true)
				.pretrain(false)
				.build();
	}

	/**
	 * Runs the model over every batch of {@code testData} and accumulates the results,
	 * then rewinds the iterator so it can be reused after the next epoch.
	 *
	 * @param model      the network to evaluate
	 * @param testData   iterator over the test set
	 * @param numClasses number of output classes
	 * @return the accumulated evaluation
	 */
	private static Evaluation evaluateOnTestSet(MultiLayerNetwork model,
			DataSetIterator testData, int numClasses) {
		Evaluation eval = new Evaluation(numClasses);
		while (testData.hasNext()) {
			DataSet ds = testData.next();
			// train = false: run the network in inference (test) mode.
			INDArray output = model.output(ds.getFeatureMatrix(), false);
			eval.eval(ds.getLabels(), output);
		}
		testData.reset();
		return eval;
	}

}



CNN 如果不用 GPU 训练，速度确实很慢。

训练 4 轮（epoch）后，在测试集上的分类统计如下（例如：标签为 0 且被正确分类为 0 的样本有 974 个，被误分类为 1 的有 1 个……），最终准确率为 0.9884：


Examples labeled as 0 classified by model as 0: 974 times

Examples labeled as 0 classified by model as 1: 1 times

Examples labeled as 0 classified by model as 6: 1 times

Examples labeled as 0 classified by model as 7: 2 times

Examples labeled as 0 classified by model as 8: 1 times

Examples labeled as 0 classified by model as 9: 1 times

Examples labeled as 1 classified by model as 1: 1124 times

Examples labeled as 1 classified by model as 2: 4 times

Examples labeled as 1 classified by model as 3: 2 times

Examples labeled as 1 classified by model as 5: 1 times

Examples labeled as 1 classified by model as 6: 2 times

Examples labeled as 1 classified by model as 7: 1 times

Examples labeled as 1 classified by model as 8: 1 times

Examples labeled as 2 classified by model as 0: 2 times

Examples labeled as 2 classified by model as 2: 1027 times

Examples labeled as 2 classified by model as 6: 1 times

Examples labeled as 2 classified by model as 7: 2 times

Examples labeled as 3 classified by model as 0: 1 times

Examples labeled as 3 classified by model as 2: 2 times

Examples labeled as 3 classified by model as 3: 999 times

Examples labeled as 3 classified by model as 5: 3 times

Examples labeled as 3 classified by model as 7: 2 times

Examples labeled as 3 classified by model as 8: 3 times

Examples labeled as 4 classified by model as 2: 1 times

Examples labeled as 4 classified by model as 4: 975 times

Examples labeled as 4 classified by model as 6: 2 times

Examples labeled as 4 classified by model as 9: 4 times

Examples labeled as 5 classified by model as 0: 2 times

Examples labeled as 5 classified by model as 3: 5 times

Examples labeled as 5 classified by model as 5: 878 times

Examples labeled as 5 classified by model as 6: 2 times

Examples labeled as 5 classified by model as 7: 1 times

Examples labeled as 5 classified by model as 8: 3 times

Examples labeled as 5 classified by model as 9: 1 times

Examples labeled as 6 classified by model as 0: 4 times

Examples labeled as 6 classified by model as 1: 2 times

Examples labeled as 6 classified by model as 4: 1 times

Examples labeled as 6 classified by model as 5: 5 times

Examples labeled as 6 classified by model as 6: 944 times

Examples labeled as 6 classified by model as 8: 2 times

Examples labeled as 7 classified by model as 1: 4 times

Examples labeled as 7 classified by model as 2: 8 times

Examples labeled as 7 classified by model as 3: 1 times

Examples labeled as 7 classified by model as 7: 1010 times

Examples labeled as 7 classified by model as 8: 1 times

Examples labeled as 7 classified by model as 9: 4 times

Examples labeled as 8 classified by model as 0: 4 times

Examples labeled as 8 classified by model as 2: 3 times

Examples labeled as 8 classified by model as 3: 1 times

Examples labeled as 8 classified by model as 5: 1 times

Examples labeled as 8 classified by model as 7: 2 times

Examples labeled as 8 classified by model as 8: 959 times

Examples labeled as 8 classified by model as 9: 4 times

Examples labeled as 9 classified by model as 1: 2 times

Examples labeled as 9 classified by model as 2: 1 times

Examples labeled as 9 classified by model as 3: 2 times

Examples labeled as 9 classified by model as 4: 1 times

Examples labeled as 9 classified by model as 5: 4 times

Examples labeled as 9 classified by model as 7: 3 times

Examples labeled as 9 classified by model as 8: 2 times

Examples labeled as 9 classified by model as 9: 994 times



==========================Scores========================================

 Accuracy:        0.9884

 Precision:       0.9884

 Recall:          0.9883

 F1 Score:        0.9883

========================================================================

model finish




你可能感兴趣的:(java,机器学习)