deeplearning4j训练推理案例2023——手写数字识别

文章目录

  • 1.MNIST数据集
  • 2.依赖包
  • 3.手写数字训练与推理
  • 4. 扩展阅读:deeplearning4j 自带的学习案例项目 deeplearning4j-examples

1.MNIST数据集

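下载 PNG 格式的 MNIST 数据集(原文下载链接已丢失):共 6 万张训练图片、1 万张测试图片。解压后目录结构应为 mnist_png/training/<数字> 和 mnist_png/testing/<数字>,每个数字一个文件夹(代码中默认解压到 D:\ 下)。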

2.依赖包

主要依赖 deeplearning4j 和 javacv 的相关包,案例打出的 jar 包约 1.3G。pom 改编自 GitHub 上 deeplearning4j-examples 项目的 dl4j-examples 模块:



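<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.7.9</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>

    <groupId>com.example</groupId>
    <artifactId>demo</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>demo</name>
    <description>demo</description>
    <properties>
        <dl4j-master.version>1.0.0-M2.1</dl4j-master.version>
        <nd4j.backend>nd4j-native</nd4j.backend>
        <java.version>17</java.version>
        <maven-compiler-plugin.version>3.8.1</maven-compiler-plugin.version>
        <maven.minimum.version>3.3.1</maven.minimum.version>
        <exec-maven.version>1.4.0</exec-maven.version>
        <maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
        <jcommon.version>1.0.23</jcommon.version>
        <jfreechart.version>1.0.13</jfreechart.version>
        <logback.version>1.1.7</logback.version>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <junit.version>5.8.0-M1</junit.version>
        <javacv.version>1.5.9</javacv.version>
    </properties>
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.bytedeco</groupId>
                <artifactId>javacv-platform</artifactId>
                <version>${javacv.version}</version>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter</artifactId>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>


        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-image</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-local</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-datasets</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>resources</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-ui</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-parallel-wrapper</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <dependency>
            <groupId>jfree</groupId>
            <artifactId>jfreechart</artifactId>
            <version>${jfreechart.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jfree</groupId>
            <artifactId>jcommon</artifactId>
            <version>${jcommon.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.httpcomponents</groupId>
            <artifactId>httpclient</artifactId>
            <version>4.3.5</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>

        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacv-platform</artifactId>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-api</artifactId>
            <version>1.0.0-M2.1</version>
        </dependency>

    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>17</source>
                    <target>17</target>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>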



3.手写数字训练与推理

1 个 epoch 训练耗时约 100 s,验证准确率达 97%(训练 3 个 epoch 可达 99%),详见代码注释。框架的 API 整体比较好用。

package ai;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.common.io.Assert;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;

import java.io.File;
import java.io.IOException;
import java.util.Random;

@Slf4j
public class LeNetMNISTReLu {
    private static final String DATASET_PATH_BASE = "D:\\";

    public static void main(String[] args) throws Exception {
        int height = 28;
        int width = 28;
        // 黑白图片通道只有一个
        int channels = 1;
        // 0-9十种数字
        int outputNum = 10;
        int batchSize = 64;
        // 这里一个epoch耗时约100s,3次准确率99%
        int nEpochs = 1;


        Assert.isTrue(new File(DATASET_PATH_BASE + "/mnist_png").exists(), "请下载压缩包并解压到" + DATASET_PATH_BASE);
        // 该label生成器会将数据所在父目录名作为label,要求目录名必须为数值,这里mnist数据集正好是放在0-9文件夹的
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        // 归一化(0-1)
        DataNormalization normalization = new ImagePreProcessingScaler();
        Random random = new Random(12345);
        log.info("训练集6W张...");
        File trainData = new File(DATASET_PATH_BASE + "/mnist_png/training");
        FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, random);
        ImageRecordReader trainRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
        trainRecordReader.initialize(trainSplit);
        DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, batchSize, 1, outputNum);
        normalization.fit(trainIter);
        trainIter.setPreProcessor(normalization); // 先像素归一化

        log.info("验证集1W张...");
        File validateData = new File(DATASET_PATH_BASE + "/mnist_png/testing");
        FileSplit validateSplit = new FileSplit(validateData, NativeImageLoader.ALLOWED_FORMATS, random);
        ImageRecordReader validateRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
        validateRecordReader.initialize(validateSplit);
        DataSetIterator validateIter = new RecordReaderDataSetIterator(validateRecordReader, batchSize, 1, outputNum);
        validateIter.setPreProcessor(normalization);

        // 训练集6W数据 每次迭代batchSize=64,故这里大概有1000次迭代
        // 学习率,每200个迭代更新一次学习率(步长),先大一点,还可以每个Epoch更新一次学习率
        MapSchedule mapSchedule = new MapSchedule.Builder(ScheduleType.ITERATION)
                .add(0, 0.06)
                .add(200, 0.05)
                .add(600, 0.028)
                .add(800, 0.006)
                .add(1000, 0.001)
                .build();

        // 超参
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(1)
                .l2(0.0005)
                .updater(new Nesterovs(mapSchedule))
                //.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) //该优化器导致长时间无法拟合
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(new ConvolutionLayer.Builder(5, 5)
                        .nIn(channels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new DenseLayer.Builder().activation(Activation.RELU)
                        .nOut(500)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image
                .build();

        // 神经网络对象构建
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        // 训练监控,每次迭代打印损失函数值
        net.setListeners(new ScoreIterationListener(10));
        // WEB UI监控训练过程
        //UIServer uiServer = UIServer.getInstance();
        //FileStatsStorage statsStorage = new FileStatsStorage(new File("D:\\ai-webui.dat"));
        //uiServer.attach(statsStorage);
        //net.setListeners(new StatsListener(statsStorage));
        log.info("网络参数个数{}", net.numParams());
        long startTime = System.currentTimeMillis();
        // 训练epochs轮
        for (int i = 0; i < nEpochs; i++) {
            log.info("Epoch=" + i);
            net.fit(trainIter);
            Evaluation eval = net.evaluate(validateIter);
            log.info(eval.stats());
            trainIter.reset();
            validateIter.reset();
        }
        log.info("训练耗时{}毫秒", System.currentTimeMillis() - startTime);
        // 保存模型
        File ministModelPath = new File(DATASET_PATH_BASE + "/ministModel.zip");
        ModelSerializer.writeModel(net, ministModelPath, true);
        // 推理逻辑:加载网络(模型)——>加载测试图片——>预测
        MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(new File(DATASET_PATH_BASE + "/ministModel.zip"));
        NativeImageLoader imageLoader = new NativeImageLoader(height, width, channels);
        FileUtils.listFiles(new File("D:\\mnist_png\\testing"), null, true)
                .parallelStream().forEach(file -> {
                    try {
                        INDArray matrix = imageLoader.asMatrix(file);
                        INDArray output = network.output(matrix);
                        // 取最可能的预测结果
                        int predictedValue = Nd4j.argMax(output, 1).getInt(0);
                        // 数字图片按数值放在每个文件夹的,故图片所在文件夹名即为真实值
                        String realValue = file.getParentFile().getName();
                        log.info("真实值:{},预测值:{}", realValue, predictedValue);
                        Assert.isTrue(predictedValue == Integer.parseInt(realValue), file.getAbsolutePath() + "预测错误");
                    } catch (Exception e) {
                        log.warn(e.getMessage(), e);
                    }
                });
    }
}

4. 扩展阅读:deeplearning4j 自带的学习案例项目 deeplearning4j-examples

deeplearning4j-examples:使用方法请参考其 README 文档。

你可能感兴趣的:(java,ai,deep,learning)