数据集下载链接:训练集 6 万张,测试集 1 万张
依赖主要是 deeplearning4j、javacv 的一些包,本案例打出的 jar 包约 1.3G;pom 参考自 GitHub 上 deeplearning4j-examples 项目的 dl4j-examples 模块
4.0.0
org.springframework.boot
spring-boot-starter-parent
2.7.9
com.example
demo
0.0.1-SNAPSHOT
demo
demo
1.0.0-M2.1
nd4j-native
17
3.8.1
3.3.1
1.4.0
2.4.3
1.0.23
1.0.13
1.1.7
UTF-8
5.8.0-M1
1.5.9
org.bytedeco
javacv-platform
${javacv.version}
org.springframework.boot
spring-boot-starter
org.projectlombok
lombok
org.springframework.boot
spring-boot-starter-test
test
org.nd4j
${nd4j.backend}
${dl4j-master.version}
org.datavec
datavec-api
${dl4j-master.version}
org.datavec
datavec-data-image
${dl4j-master.version}
org.datavec
datavec-local
${dl4j-master.version}
org.deeplearning4j
deeplearning4j-datasets
${dl4j-master.version}
org.deeplearning4j
deeplearning4j-core
${dl4j-master.version}
org.deeplearning4j
resources
${dl4j-master.version}
org.deeplearning4j
deeplearning4j-ui
${dl4j-master.version}
org.deeplearning4j
deeplearning4j-zoo
${dl4j-master.version}
org.deeplearning4j
deeplearning4j-parallel-wrapper
${dl4j-master.version}
jfree
jfreechart
${jfreechart.version}
org.jfree
jcommon
${jcommon.version}
org.apache.httpcomponents
httpclient
4.3.5
ch.qos.logback
logback-classic
${logback.version}
org.bytedeco
javacv-platform
org.nd4j
nd4j-api
1.0.0-M2.1
org.springframework.boot
spring-boot-maven-plugin
org.apache.maven.plugins
maven-compiler-plugin
17
训练 1 个 epoch 耗时约 100 秒,准确率可达 97%,详见代码注释;框架的 API 还是比较好用的
package ai;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.common.io.Assert;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import java.io.File;
import java.util.Random;
/**
 * Trains a LeNet-style convolutional network on the MNIST PNG data set, evaluates it,
 * saves the model to disk, reloads it and runs a per-image prediction check over the
 * test images.
 *
 * <p>The data set is expected under {@code DATASET_PATH_BASE/mnist_png}, with the
 * {@code training} and {@code testing} sub-directories each containing one folder per
 * digit ({@code 0}-{@code 9}); the folder name is used as the label.
 */
@Slf4j
public class LeNetMNISTReLu {

    /** Base directory that must contain the extracted {@code mnist_png} folder. */
    private static final String DATASET_PATH_BASE = "D:\\";

    /**
     * Entry point: builds the data iterators, trains the network for {@code nEpochs}
     * epochs (evaluating after each one), saves the model, then reloads it and predicts
     * every test image.
     *
     * @param args unused
     * @throws Exception if the data set cannot be read or the model cannot be written/restored
     */
    public static void main(String[] args) throws Exception {
        int height = 28;
        int width = 28;
        // Grayscale images have a single channel.
        int channels = 1;
        // Ten classes: the digits 0-9.
        int outputNum = 10;
        int batchSize = 64;
        // Author's note: one epoch takes roughly 100s here; 3 epochs reach about 99% accuracy.
        int nEpochs = 1;

        Assert.isTrue(new File(DATASET_PATH_BASE + "/mnist_png").exists(), "请下载压缩包并解压到" + DATASET_PATH_BASE);

        // Uses the name of each image's parent directory as its label; the MNIST PNG
        // layout stores the images in folders named 0-9.
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        // Scales pixel values into the 0-1 range.
        DataNormalization normalization = new ImagePreProcessingScaler();
        Random random = new Random(12345);

        log.info("训练集6W张...");
        File trainData = new File(DATASET_PATH_BASE + "/mnist_png/training");
        FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, random);
        ImageRecordReader trainRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
        trainRecordReader.initialize(trainSplit);
        DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, batchSize, 1, outputNum);
        normalization.fit(trainIter);
        trainIter.setPreProcessor(normalization);

        log.info("验证集1W张...");
        File validateData = new File(DATASET_PATH_BASE + "/mnist_png/testing");
        FileSplit validateSplit = new FileSplit(validateData, NativeImageLoader.ALLOWED_FORMATS, random);
        ImageRecordReader validateRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
        validateRecordReader.initialize(validateSplit);
        DataSetIterator validateIter = new RecordReaderDataSetIterator(validateRecordReader, batchSize, 1, outputNum);
        validateIter.setPreProcessor(normalization);

        MultiLayerNetwork net = new MultiLayerNetwork(buildConfiguration(height, width, channels, outputNum));
        net.init();
        // Training monitor: logs the score every 10 iterations.
        net.setListeners(new ScoreIterationListener(10));
        // Optional web UI monitoring of the training process:
        //UIServer uiServer = UIServer.getInstance();
        //FileStatsStorage statsStorage = new FileStatsStorage(new File("D:\\ai-webui.dat"));
        //uiServer.attach(statsStorage);
        //net.setListeners(new StatsListener(statsStorage));
        log.info("网络参数个数{}", net.numParams());

        long startTime = System.currentTimeMillis();
        for (int i = 0; i < nEpochs; i++) {
            log.info("Epoch={}", i);
            net.fit(trainIter);
            Evaluation eval = net.evaluate(validateIter);
            log.info(eval.stats());
            trainIter.reset();
            validateIter.reset();
        }
        log.info("训练耗时{}毫秒", System.currentTimeMillis() - startTime);

        // Save the model, then restore it from the same file to exercise the full
        // "load model -> load image -> predict" inference path.
        File ministModelPath = new File(DATASET_PATH_BASE + "/ministModel.zip");
        ModelSerializer.writeModel(net, ministModelPath, true);
        MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(ministModelPath);

        predictTestImages(network, normalization, height, width, channels);
    }

    /**
     * Builds the LeNet-style configuration: two convolution + max-pooling stages, a
     * 500-unit ReLU dense layer and a softmax output layer.
     *
     * @param height    image height in pixels
     * @param width     image width in pixels
     * @param channels  number of image channels
     * @param outputNum number of output classes
     * @return the network configuration
     */
    private static MultiLayerConfiguration buildConfiguration(int height, int width, int channels, int outputNum) {
        // With 60k training images and a batch size of 64 there are roughly 1000
        // iterations per epoch. The learning rate starts larger and is lowered at the
        // listed iteration counts; a per-epoch schedule would also be possible.
        MapSchedule learningRateSchedule = new MapSchedule.Builder(ScheduleType.ITERATION)
                .add(0, 0.06)
                .add(200, 0.05)
                .add(600, 0.028)
                .add(800, 0.006)
                .add(1000, 0.001)
                .build();

        return new NeuralNetConfiguration.Builder()
                .seed(1)
                .l2(0.0005)
                .updater(new Nesterovs(learningRateSchedule))
                // Author's note: this optimizer failed to converge for a long time.
                //.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(new ConvolutionLayer.Builder(5, 5)
                        .nIn(channels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new DenseLayer.Builder().activation(Activation.RELU)
                        .nOut(500)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image
                .build();
    }

    /**
     * Predicts every image under the test directory and logs the real versus predicted
     * digit. A wrong prediction or an unreadable file is logged as a warning and does
     * not stop the run.
     *
     * @param network       the restored network used for prediction
     * @param normalization the same scaler that was attached to the training iterator
     * @param height        image height in pixels
     * @param width         image width in pixels
     * @param channels      number of image channels
     */
    private static void predictTestImages(MultiLayerNetwork network, DataNormalization normalization,
                                          int height, int width, int channels) {
        NativeImageLoader imageLoader = new NativeImageLoader(height, width, channels);
        File testDir = new File(DATASET_PATH_BASE + "/mnist_png/testing");
        // Sequential on purpose: the loader and the network are shared instances, so
        // they are not used from several threads at once.
        for (File file : FileUtils.listFiles(testDir, null, true)) {
            try {
                INDArray matrix = imageLoader.asMatrix(file);
                // The network was trained on scaled pixels, so raw loader output must
                // be scaled the same way before prediction.
                normalization.transform(matrix);
                INDArray output = network.output(matrix);
                // Take the most probable class.
                int predictedValue = Nd4j.argMax(output, 1).getInt(0);
                // Images are stored in one folder per digit, so the parent folder name
                // is the ground-truth value.
                String realValue = file.getParentFile().getName();
                log.info("真实值:{},预测值:{}", realValue, predictedValue);
                Assert.isTrue(predictedValue == Integer.parseInt(realValue), file.getAbsolutePath() + "预测错误");
            } catch (Exception e) {
                log.warn(e.getMessage(), e);
            }
        }
    }
}
更多示例可参考 deeplearning4j-examples 项目的 README 文档