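Deeplearning4j is a suite of tools for running deep learning on the JVM. It is the only framework that lets you train models from Java while interoperating with the Python ecosystem. It does this through a mix of Python execution via our cpython bindings, model import support, and interop with other runtimes such as tensorflow-java and onnxruntime.
Use cases include importing and retraining models (PyTorch, TensorFlow, Keras) and deploying them in JVM microservice environments, mobile devices, IoT, and Apache Spark. It is a good complement to your Python environment: models built in Python can be run here, or deployed to and packaged for other environments.
All projects in the DL4J ecosystem support Windows, Linux, and macOS. Hardware support includes CUDA GPUs (10.0, 10.1, 10.2; not available on macOS), x86 CPUs (x86_64, avx2, avx512), ARM CPUs (arm, arm64, armhf), and PowerPC (ppc64le).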
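DL4J: a high-level API for building multi-layer networks and computation graphs from a variety of layers, including custom layers. It supports importing Keras models from h5, including tf.keras models (as of 1.0.0-M2), and supports distributed training on Apache Spark.
ND4J: a general-purpose linear algebra library with more than 500 mathematical, linear algebra, and deep learning operations. ND4J is built on LibND4J, a highly optimized C++ codebase that provides CPU (AVX2/512) and GPU (CUDA) support and acceleration through libraries such as OpenBLAS, OneDNN (MKL-DNN), cuDNN, and cuBLAS.
SameDiff: part of the ND4J library, SameDiff is our automatic differentiation / deep learning framework. SameDiff uses a graph-based (define-then-run) approach, similar to TensorFlow graph mode. Eager graph execution (as in TensorFlow 2.x eager mode and PyTorch) is planned. SameDiff supports importing TensorFlow frozen models in .pb (protobuf) format; import of ONNX, TensorFlow SavedModel, and Keras models is planned. Deeplearning4j also has full SameDiff support, which makes it easy to write custom layers and loss functions.
DataVec: ETL for machine learning data in a wide variety of formats and files (HDFS, Spark, images, video, audio, CSV, Excel, etc.).
Arbiter: a hyperparameter search library.
LibND4J: the C++ library that underpins everything. For more on how the JVM accesses native arrays and operations, see JavaCPP.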
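To make the DataVec ETL step concrete: in the classifier below, `CSVRecordReader` reads `linear_data_train.csv`, and `RecordReaderDataSetIterator(rr, batchSize, 0, 2)` treats column 0 as the class label (2 possible classes) and the remaining columns as features. The following plain-Java sketch, with no DataVec dependency, mimics that transformation so you can see what a minibatch roughly contains. The class `CsvToOneHot`, its methods, and the sample rows are hypothetical illustrations, not DataVec's API or implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical plain-Java illustration of the CSV -> minibatch ETL step that
 * CSVRecordReader + RecordReaderDataSetIterator(rr, batchSize, labelIndex, numClasses)
 * perform conceptually. NOT DataVec's implementation.
 */
public class CsvToOneHot {

    /** One minibatch: a feature row and a one-hot label row per example. */
    public static final class Batch {
        public final double[][] features;
        public final double[][] labels;

        Batch(double[][] features, double[][] labels) {
            this.features = features;
            this.labels = labels;
        }
    }

    /** Splits CSV lines into minibatches of at most batchSize examples. */
    public static List<Batch> toBatches(List<String> csvLines, int batchSize, int labelIndex, int numClasses) {
        List<Batch> batches = new ArrayList<>();
        for (int start = 0; start < csvLines.size(); start += batchSize) {
            int end = Math.min(start + batchSize, csvLines.size());
            double[][] features = new double[end - start][];
            double[][] labels = new double[end - start][numClasses]; // zero-filled
            for (int r = start; r < end; r++) {
                String[] cols = csvLines.get(r).split(",");
                // The label column holds a class index such as 0 or 1
                int classIdx = (int) Double.parseDouble(cols[labelIndex].trim());
                labels[r - start][classIdx] = 1.0; // one-hot encoding
                // Every other column becomes a feature, in original order
                double[] row = new double[cols.length - 1];
                int k = 0;
                for (int c = 0; c < cols.length; c++) {
                    if (c != labelIndex) {
                        row[k++] = Double.parseDouble(cols[c].trim());
                    }
                }
                features[r - start] = row;
            }
            batches.add(new Batch(features, labels));
        }
        return batches;
    }

    public static void main(String[] args) {
        // Made-up rows in "label,x,y" layout; NOT taken from linear_data_train.csv
        List<String> lines = Arrays.asList("0,0.10,0.20", "1,0.90,0.40", "1,0.70,0.10");
        List<Batch> batches = toBatches(lines, 2, 0, 2);
        System.out.println(batches.size());                              // 2 (sizes 2 and 1)
        System.out.println(Arrays.toString(batches.get(0).labels[1]));   // [0.0, 1.0]
        System.out.println(Arrays.toString(batches.get(0).features[1])); // [0.9, 0.4]
    }
}
```

The real iterator returns ND4J `DataSet` objects rather than Java arrays, but the layout (features minus the label column, labels one-hot encoded) is the idea this sketch is meant to show.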
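pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.6.4</version>
        <relativePath/>
    </parent>
    <groupId>com.algorithm</groupId>
    <artifactId>demo</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>demo</name>
    <description>Demo project for Spring Boot</description>

    <properties>
        <dl4j-master.version>1.0.0-M2</dl4j-master.version>
        <java.version>1.8</java.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter</artifactId>
        </dependency>
        <dependency>
            <groupId>jfree</groupId>
            <artifactId>jfreechart</artifactId>
            <version>1.0.13</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>
</project>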
LinearDataClassifier.java
package com.algorithm.demo.dl4jexamples;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import com.algorithm.demo.dl4jexamples.utils.DownloaderUtility;
import com.algorithm.demo.dl4jexamples.utils.PlotUtil;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import java.io.File;
import java.util.concurrent.TimeUnit;
/**
 * "Linear" Data Classification Example
 *
 * Based on the data from Jason Baldridge:
 * https://github.com/jasonbaldridge/try-tf/tree/master/simdata
 *
 * @author Josh Patterson
 * @author Alex Black (added plots)
 */
@SuppressWarnings("DuplicatedCode")
public class LinearDataClassifier {

    public static boolean visualize = true;
    public static String dataLocalPath;

    public static void main(String[] args) throws Exception {
        int seed = 123;
        double learningRate = 0.01;
        int batchSize = 50;
        int nEpochs = 30;

        int numInputs = 2;
        int numOutputs = 2;
        int numHiddenNodes = 20;

        dataLocalPath = DownloaderUtility.CLASSIFICATIONDATA.Download();

        // Load the training data (column 0 is the label, 2 classes)
        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(new File(dataLocalPath, "linear_data_train.csv")));
        DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, 2);

        // Load the evaluation data
        RecordReader rrTest = new CSVRecordReader();
        rrTest.initialize(new FileSplit(new File(dataLocalPath, "linear_data_eval.csv")));
        DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2);

        // Build the multi-layer network configuration
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(learningRate, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .activation(Activation.RELU)
                        .build())
                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .build();

        // Initialize the network
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(10)); // Print score every 10 parameter updates

        // Train for nEpochs passes over the training data
        model.fit(trainIter, nEpochs);

        // Evaluate on the held-out data
        System.out.println("Evaluate model....");
        Evaluation eval = new Evaluation(numOutputs);
        while (testIter.hasNext()) {
            DataSet t = testIter.next();
            INDArray features = t.getFeatures();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }
        // An alternative to the loop above:
        // Evaluation evalResults = model.evaluate(testIter);

        // Print the evaluation statistics
        System.out.println(eval.stats());

        System.out.println("\n****************Example finished********************");
        // Training is complete.
        // The code below only plots the data and the model's predictions.
        generateVisuals(model, trainIter, testIter);
    }

    public static void generateVisuals(MultiLayerNetwork model, DataSetIterator trainIter, DataSetIterator testIter) throws Exception {
        if (visualize) {
            double xMin = 0;
            double xMax = 1.0;
            double yMin = -0.2;
            double yMax = 0.8;
            int nPointsPerAxis = 100;

            // Generate x,y points that span the whole range of features
            INDArray allXYPoints = PlotUtil.generatePointsOnGraph(xMin, xMax, yMin, yMax, nPointsPerAxis);
            // Get train data and plot with predictions
            PlotUtil.plotTrainingData(model, trainIter, allXYPoints, nPointsPerAxis);
            TimeUnit.SECONDS.sleep(3);
            // Get test data, run the test data through the network to generate predictions, and plot those predictions
            PlotUtil.plotTestData(model, testIter, allXYPoints, nPointsPerAxis);
        }
    }
}
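The output layer above pairs `Activation.SOFTMAX` with `LossFunction.NEGATIVELOGLIKELIHOOD`. For one example with output-layer pre-activations z, softmax gives p_k = exp(z_k) / Σ_j exp(z_j), and with a one-hot label y the loss is -Σ_k y_k · log(p_k); the predicted class is the index of the largest p_k, which is what `Evaluation` compares against the label. The plain-Java sketch below computes these three quantities for a single example. It illustrates the math only, is not DL4J's implementation, and the class name `SoftmaxNll` and the input values are hypothetical.

```java
import java.util.Locale;

/** Hypothetical illustration of softmax + negative log-likelihood for one example. */
public class SoftmaxNll {

    /** Numerically stable softmax: subtract the max logit before exponentiating. */
    public static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) max = Math.max(max, z);
        double sum = 0.0;
        double[] p = new double[logits.length];
        for (int k = 0; k < logits.length; k++) {
            p[k] = Math.exp(logits[k] - max);
            sum += p[k];
        }
        for (int k = 0; k < p.length; k++) p[k] /= sum;
        return p;
    }

    /** Negative log-likelihood of a one-hot label under probabilities p. */
    public static double nll(double[] p, double[] oneHot) {
        double loss = 0.0;
        for (int k = 0; k < p.length; k++) {
            if (oneHot[k] != 0.0) loss -= oneHot[k] * Math.log(p[k]);
        }
        return loss;
    }

    /** Index of the largest probability, i.e. the predicted class. */
    public static int argMax(double[] p) {
        int best = 0;
        for (int k = 1; k < p.length; k++) if (p[k] > p[best]) best = k;
        return best;
    }

    public static void main(String[] args) {
        double[] logits = {2.0, 0.0}; // illustrative pre-activations of the 2 output units
        double[] label = {1.0, 0.0};  // true class is 0
        double[] p = softmax(logits);
        System.out.println(String.format(Locale.ROOT, "p = [%.4f, %.4f]", p[0], p[1])); // p = [0.8808, 0.1192]
        System.out.println(String.format(Locale.ROOT, "loss = %.4f", nll(p, label)));   // loss = 0.1269
        System.out.println("predicted class = " + argMax(p));                          // predicted class = 0
    }
}
```

Subtracting the maximum logit does not change the result mathematically but avoids overflow in `Math.exp` for large inputs.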
PlotUtil.java, the plotting utility
package com.algorithm.demo.dl4jexamples.utils;
import org.deeplearning4j.datasets.iterator.utilty.ListDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.AxisLocation;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.block.BlockBorder;
import org.jfree.chart.plot.DatasetRenderingOrder;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.GrayPaintScale;
import org.jfree.chart.renderer.PaintScale;
import org.jfree.chart.renderer.xy.XYBlockRenderer;
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
import org.jfree.chart.title.PaintScaleLegend;
import org.jfree.data.xy.*;
import org.jfree.ui.RectangleEdge;
import org.jfree.ui.RectangleInsets;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import javax.swing.*;
import java.awt.*;
import java.util.ArrayList;
import java.util.List;
/**
 * Simple plotting methods for the MLPClassifier quickstart examples
 *
 * @author Alex Black
 */
public class PlotUtil {

    /**
     * Plot the training data. Assume 2d input, classification output
     *
     * @param model        Model to use to get predictions
     * @param trainIter    DataSet Iterator
     * @param backgroundIn sets of x,y points in input space, plotted in the background
     * @param nDivisions   Number of points (per axis, for the backgroundIn/backgroundOut arrays)
     */
    public static void plotTrainingData(MultiLayerNetwork model, DataSetIterator trainIter, INDArray backgroundIn, int nDivisions) {
        double[] mins = backgroundIn.min(0).data().asDouble();
        double[] maxs = backgroundIn.max(0).data().asDouble();

        DataSet ds = allBatches(trainIter);
        INDArray backgroundOut = model.output(backgroundIn);

        XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
        JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTrain(ds.getFeatures(), ds.getLabels())));

        JFrame f = new JFrame();
        f.add(panel);
        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
        f.pack();
        f.setTitle("Training Data");
        f.setVisible(true);
        f.setLocation(0, 0);
    }

    /**
     * Plot the test data together with the model's predictions. Assume 2d input, classification output
     *
     * @param model        Model to use to get predictions
     * @param testIter     Test Iterator
     * @param backgroundIn sets of x,y points in input space, plotted in the background
     * @param nDivisions   Number of points (per axis, for the backgroundIn/backgroundOut arrays)
     */
    public static void plotTestData(MultiLayerNetwork model, DataSetIterator testIter, INDArray backgroundIn, int nDivisions) {
        double[] mins = backgroundIn.min(0).data().asDouble();
        double[] maxs = backgroundIn.max(0).data().asDouble();

        INDArray backgroundOut = model.output(backgroundIn);
        XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
        DataSet ds = allBatches(testIter);
        INDArray predicted = model.output(ds.getFeatures());
        JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTest(ds.getFeatures(), ds.getLabels(), predicted)));

        JFrame f = new JFrame();
        f.add(panel);
        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
        f.pack();
        f.setTitle("Test Data");
        f.setVisible(true);
        f.setLocationRelativeTo(null);
        //f.setLocation(100,100);
    }
    /**
     * Create data for the background data set: one (x, y, z) triple per grid point,
     * where z is the model's predicted probability used for the grey-scale shading
     */
    private static XYZDataset createBackgroundData(INDArray backgroundIn, INDArray backgroundOut) {
        int nRows = backgroundIn.rows();
        double[] xValues = new double[nRows];
        double[] yValues = new double[nRows];
        double[] zValues = new double[nRows];
        for (int i = 0; i < nRows; i++) {
            xValues[i] = backgroundIn.getDouble(i, 0);
            yValues[i] = backgroundIn.getDouble(i, 1);
            // The softmax output has 2 columns: column 1 is P(class 1), which matches the
            // "Probability (class 1)" legend in createChart (column 0 would be P(class 0))
            zValues[i] = backgroundOut.getDouble(i, 1);
        }

        DefaultXYZDataset dataset = new DefaultXYZDataset();
        dataset.addSeries("Series 1",
                new double[][]{xValues, yValues, zValues});
        return dataset;
    }

    // Training data: one scatter series per class
    private static XYDataset createDataSetTrain(INDArray features, INDArray labels) {
        int nRows = features.rows();

        int nClasses = 2; // Two classes; labels are one-hot rows with 2 columns (softmax output)

        XYSeries[] series = new XYSeries[nClasses];
        for (int i = 0; i < series.length; i++) series[i] = new XYSeries("Class " + i);
        // Column index of the 1 in each one-hot label row = class index
        INDArray argMax = Nd4j.getExecutioner().exec(new ArgMax(new INDArray[]{labels}, false, new int[]{1}))[0];
        for (int i = 0; i < nRows; i++) {
            int classIdx = (int) argMax.getDouble(i);
            series[classIdx].add(features.getDouble(i, 0), features.getDouble(i, 1));
        }

        XYSeriesCollection c = new XYSeriesCollection();
        for (XYSeries s : series) c.addSeries(s);
        return c;
    }

    // Test data: one scatter series per (actual class, predicted class) pair
    private static XYDataset createDataSetTest(INDArray features, INDArray labels, INDArray predicted) {
        int nRows = features.rows();

        int nClasses = 2; // Two classes; labels are one-hot rows with 2 columns (softmax output)

        XYSeries[] series = new XYSeries[nClasses * nClasses];
        int[] series_index = new int[]{0, 3, 2, 1}; // little hack to make the charts look consistent
        for (int i = 0; i < nClasses * nClasses; i++) {
            int trueClass = i / nClasses;
            int predClass = i % nClasses;
            String label = "actual=" + trueClass + ", pred=" + predClass;
            series[series_index[i]] = new XYSeries(label);
        }
        INDArray actualIdx = labels.argMax(1);
        INDArray predictedIdx = predicted.argMax(1);
        for (int i = 0; i < nRows; i++) {
            int classIdx = actualIdx.getInt(i);
            int predIdx = predictedIdx.getInt(i);
            int idx = series_index[classIdx * nClasses + predIdx];
            series[idx].add(features.getDouble(i, 0), features.getDouble(i, 1));
        }

        XYSeriesCollection c = new XYSeriesCollection();
        for (XYSeries s : series) c.addSeries(s);
        return c;
    }
    private static JFreeChart createChart(XYZDataset dataset, double[] mins, double[] maxs, int nPoints, XYDataset xyData) {
        NumberAxis xAxis = new NumberAxis("X");
        xAxis.setRange(mins[0], maxs[0]);

        NumberAxis yAxis = new NumberAxis("Y");
        yAxis.setRange(mins[1], maxs[1]);

        // Background layer: grey-scale blocks, one per grid point, shaded by its z value
        XYBlockRenderer renderer = new XYBlockRenderer();
        renderer.setBlockWidth((maxs[0] - mins[0]) / (nPoints - 1));
        renderer.setBlockHeight((maxs[1] - mins[1]) / (nPoints - 1));
        PaintScale scale = new GrayPaintScale(0, 1.0);
        renderer.setPaintScale(scale);

        XYPlot plot = new XYPlot(dataset, xAxis, yAxis, renderer);
        plot.setBackgroundPaint(Color.lightGray);
        plot.setDomainGridlinesVisible(false);
        plot.setRangeGridlinesVisible(false);
        plot.setAxisOffset(new RectangleInsets(5, 5, 5, 5));

        JFreeChart chart = new JFreeChart("", plot);
        chart.getXYPlot().getRenderer().setSeriesVisibleInLegend(0, false);

        // Colour-scale legend on the left edge
        NumberAxis scaleAxis = new NumberAxis("Probability (class 1)");
        scaleAxis.setAxisLinePaint(Color.white);
        scaleAxis.setTickMarkPaint(Color.white);
        scaleAxis.setTickLabelFont(new Font("Dialog", Font.PLAIN, 7));
        PaintScaleLegend legend = new PaintScaleLegend(new GrayPaintScale(),
                scaleAxis);
        legend.setStripOutlineVisible(false);
        legend.setSubdivisionCount(20);
        legend.setAxisLocation(AxisLocation.BOTTOM_OR_LEFT);
        legend.setAxisOffset(5.0);
        legend.setMargin(new RectangleInsets(5, 5, 5, 5));
        legend.setFrame(new BlockBorder(Color.red));
        legend.setPadding(new RectangleInsets(10, 10, 10, 10));
        legend.setStripWidth(10);
        legend.setPosition(RectangleEdge.LEFT);
        chart.addSubtitle(legend);

        ChartUtilities.applyCurrentTheme(chart);

        // Foreground layer (dataset 1): the data points as markers without connecting lines
        plot.setDataset(1, xyData);
        XYLineAndShapeRenderer renderer2 = new XYLineAndShapeRenderer();
        renderer2.setBaseLinesVisible(false);
        plot.setRenderer(1, renderer2);

        plot.setDatasetRenderingOrder(DatasetRenderingOrder.FORWARD);

        return chart;
    }
    public static INDArray generatePointsOnGraph(double xMin, double xMax, double yMin, double yMax, int nPointsPerAxis) {
        // Generate an nPointsPerAxis x nPointsPerAxis grid of (x, y) points
        // covering [xMin, xMax] x [yMin, yMax], one point per row
        double[][] evalPoints = new double[nPointsPerAxis * nPointsPerAxis][2];
        int count = 0;
        for (int i = 0; i < nPointsPerAxis; i++) {
            for (int j = 0; j < nPointsPerAxis; j++) {
                double x = i * (xMax - xMin) / (nPointsPerAxis - 1) + xMin;
                double y = j * (yMax - yMin) / (nPointsPerAxis - 1) + yMin;

                evalPoints[count][0] = x;
                evalPoints[count][1] = y;

                count++;
            }
        }

        return Nd4j.create(evalPoints);
    }

    /**
     * Collects every minibatch from the iterator and returns them as a single DataSet.
     * Only suitable for small datasets like the ones used here.
     *
     * @param iter iterator to drain; it is reset before and after reading
     * @return one DataSet containing all examples
     */
    private static DataSet allBatches(DataSetIterator iter) {
        List<DataSet> fullSet = new ArrayList<>();
        iter.reset();
        while (iter.hasNext()) {
            List<DataSet> miniBatchList = iter.next().asList();
            fullSet.addAll(miniBatchList);
        }
        iter.reset();
        return new ListDataSetIterator<>(fullSet, fullSet.size()).next();
    }
}
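In `createDataSetTest`, each test point goes into one of nClasses × nClasses = 4 series keyed by `classIdx * nClasses + predIdx`, and `series_index = {0, 3, 2, 1}` permutes those slots. The effect is that the two correctly classified series land in slots 0 and 1, in the same order as "Class 0" and "Class 1" in the training plot, which is presumably what the "make the charts look consistent" comment refers to, since JFreeChart assigns default colours by series position. The stand-alone sketch below reproduces just that bookkeeping without JFreeChart so the mapping can be checked; `SeriesIndexDemo` and its methods are hypothetical names.

```java
/**
 * Hypothetical stand-alone illustration of the series bookkeeping in
 * PlotUtil.createDataSetTest (no JFreeChart needed).
 */
public class SeriesIndexDemo {

    static final int N_CLASSES = 2;
    // Same permutation as series_index in createDataSetTest
    static final int[] SERIES_INDEX = {0, 3, 2, 1};

    /** Chart series slot for a point with the given actual and predicted class. */
    public static int seriesFor(int actual, int predicted) {
        return SERIES_INDEX[actual * N_CLASSES + predicted];
    }

    /** Label stored in each series slot, built the same way createDataSetTest builds them. */
    public static String[] seriesLabels() {
        String[] labels = new String[N_CLASSES * N_CLASSES];
        for (int i = 0; i < labels.length; i++) {
            int trueClass = i / N_CLASSES;
            int predClass = i % N_CLASSES;
            labels[SERIES_INDEX[i]] = "actual=" + trueClass + ", pred=" + predClass;
        }
        return labels;
    }

    public static void main(String[] args) {
        String[] labels = seriesLabels();
        for (int s = 0; s < labels.length; s++) {
            System.out.println(s + ": " + labels[s]);
        }
        // Prints:
        // 0: actual=0, pred=0
        // 1: actual=1, pred=1
        // 2: actual=1, pred=0
        // 3: actual=0, pred=1
    }
}
```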
DownloaderUtility.java, the download utility class
package com.algorithm.demo.dl4jexamples.utils;

import org.apache.commons.io.FilenameUtils;
import org.nd4j.common.resources.Downloader;

import java.io.File;
import java.net.URL;

/**
 * Given a base url and a zipped file name, downloads contents to a specified directory under ~/dl4j-examples-data
 * Will check md5 sum of downloaded file
 *
 * Sample Usage with an instantiation DATAEXAMPLE(baseurl,"DataExamples.zip","data-dir",md5,size):
 *
 * DATAEXAMPLE.Download() & DATAEXAMPLE.Download(true)
 * Will download DataExamples.zip from baseurl/DataExamples.zip to a temp directory,
 * Unzip it to ~/dl4j-examples-data/data-dir
 * Return the string "~/dl4j-examples-data/data-dir/DataExamples"
 *
 * DATAEXAMPLE.Download(false)
 * will perform the same download and unzip as above
 * But returns the string "~/dl4j-examples-data/data-dir" instead
 *
 * @author susaneraly
 */
public enum DownloaderUtility {

    IRISDATA("IrisData.zip", "datavec-examples", "bb49e38bb91089634d7ef37ad8e430b8", "1KB"),
    ANIMALS("animals.zip", "dl4j-examples", "1976a1f2b61191d2906e4f615246d63e", "820KB"),
    ANOMALYSEQUENCEDATA("anomalysequencedata.zip", "dl4j-examples", "51bb7c50e265edec3a241a2d7cce0e73", "3MB"),
    CAPTCHAIMAGE("captchaImage.zip", "dl4j-examples", "1d159c9587fdbb1cbfd66f0d62380e61", "42MB"),
    CLASSIFICATIONDATA("classification.zip", "dl4j-examples", "dba31e5838fe15993579edbf1c60c355", "77KB"),
    DATAEXAMPLES("DataExamples.zip", "dl4j-examples", "e4de9c6f19aaae21fed45bfe2a730cbb", "2MB"),
    LOTTERYDATA("lottery.zip", "dl4j-examples", "1e54ac1210e39c948aa55417efee193a", "2MB"),
    NEWSDATA("NewsData.zip", "dl4j-examples", "0d08e902faabe6b8bfe5ecdd78af9f64", "21MB"),
    NLPDATA("nlp.zip", "dl4j-examples", "1ac7cd7ca08f13402f0e3b83e20c0512", "91MB"),
    PREDICTGENDERDATA("PredictGender.zip", "dl4j-examples", "42a3fec42afa798217e0b8687667257e", "3MB"),
    STYLETRANSFER("styletransfer.zip", "dl4j-examples", "b2b90834d667679d7ee3dfb1f40abe94", "3MB"),
    VIDEOEXAMPLE("video.zip", "dl4j-examples", "56274eb6329a848dce3e20631abc6752", "8.5MB");

    private final String BASE_URL;
    private final String DATA_FOLDER;
    private final String ZIP_FILE;
    private final String MD5;
    private final String DATA_SIZE;
    private static final String AZURE_BLOB_URL = "https://dl4jdata.blob.core.windows.net/dl4j-examples";

    /**
     * For use with resources uploaded to Azure blob storage.
     *
     * @param zipFile    Name of zipfile. Should be a zip of a single directory with the same name
     * @param dataFolder The folder to extract to under ~/dl4j-examples-data
     * @param md5        of zipfile
     * @param dataSize   of zipfile
     */
    DownloaderUtility(String zipFile, String dataFolder, String md5, String dataSize) {
        this(AZURE_BLOB_URL + "/" + dataFolder, zipFile, dataFolder, md5, dataSize);
    }

    /**
     * Downloads a zip file from a base url to a specified directory under the user's home directory
     *
     * @param baseURL    URL of file
     * @param zipFile    Name of zipfile to download from baseURL i.e baseURL+"/"+zipFile gives full URL
     * @param dataFolder The folder to extract to under ~/dl4j-examples-data
     * @param md5        of zipfile
     * @param dataSize   of zipfile
     */
    DownloaderUtility(String baseURL, String zipFile, String dataFolder, String md5, String dataSize) {
        BASE_URL = baseURL;
        DATA_FOLDER = dataFolder;
        ZIP_FILE = zipFile;
        MD5 = md5;
        DATA_SIZE = dataSize;
    }

    public String Download() throws Exception {
        return Download(true);
    }

    public String Download(boolean returnSubFolder) throws Exception {
        String dataURL = BASE_URL + "/" + ZIP_FILE;
        String downloadPath = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), ZIP_FILE);
        String extractDir = FilenameUtils.concat(System.getProperty("user.home"), "dl4j-examples-data/" + DATA_FOLDER);
        if (!new File(extractDir).exists())
            new File(extractDir).mkdirs();
        String dataPathLocal = extractDir;
        if (returnSubFolder) {
            String resourceName = ZIP_FILE.substring(0, ZIP_FILE.lastIndexOf(".zip"));
            dataPathLocal = FilenameUtils.concat(extractDir, resourceName);
        }
        int downloadRetries = 10;
        // Download only if the target folder is missing or empty
        if (!new File(dataPathLocal).exists() || new File(dataPathLocal).list().length == 0) {
            System.out.println("_______________________________________________________________________");
            System.out.println("Downloading data (" + DATA_SIZE + ") and extracting to \n\t" + dataPathLocal);
            System.out.println("_______________________________________________________________________");
            Downloader.downloadAndExtract("files",
                    new URL(dataURL),
                    new File(downloadPath),
                    new File(extractDir),
                    MD5,
                    downloadRetries);
        } else {
            System.out.println("_______________________________________________________________________");
            System.out.println("Example data present in \n\t" + dataPathLocal);
            System.out.println("_______________________________________________________________________");
        }
        return dataPathLocal;
    }
}
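`DownloaderUtility` hands each archive's MD5 string and `downloadRetries = 10` to ND4J's `Downloader.downloadAndExtract`, which, per the class comment, checks the downloaded file's md5 sum. To show what such a check amounts to, here is a standalone JDK-only sketch that hashes a file with `java.security.MessageDigest` and compares it against an expected hex digest. `Md5Check` is a hypothetical helper written for illustration; it is not the code inside ND4J's `Downloader`.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

/** Hypothetical helper: verify a file against an expected MD5 hex digest. */
public class Md5Check {

    /** Returns the lowercase hex MD5 digest of a file, read in 8 KB chunks. */
    public static String md5Hex(Path file) throws IOException, NoSuchAlgorithmException {
        MessageDigest md = MessageDigest.getInstance("MD5");
        try (InputStream in = Files.newInputStream(file)) {
            byte[] buf = new byte[8192];
            int n;
            while ((n = in.read(buf)) != -1) {
                md.update(buf, 0, n);
            }
        }
        StringBuilder sb = new StringBuilder();
        for (byte b : md.digest()) {
            sb.append(String.format("%02x", b & 0xff)); // two hex chars per byte
        }
        return sb.toString();
    }

    /** True if the file's digest equals the expected one (hex case ignored). */
    public static boolean matches(Path file, String expectedMd5) throws IOException, NoSuchAlgorithmException {
        return md5Hex(file).equalsIgnoreCase(expectedMd5);
    }

    public static void main(String[] args) throws Exception {
        if (args.length != 2) {
            System.out.println("usage: java Md5Check <file> <expected-md5>");
            return;
        }
        Path file = Paths.get(args[0]);
        System.out.println(md5Hex(file) + (matches(file, args[1]) ? " OK" : " MISMATCH"));
    }
}
```

For example, running it on a copy of `classification.zip` with the `CLASSIFICATIONDATA` digest `dba31e5838fe15993579edbf1c60c355` should print `OK` if the archive is intact. Streaming the file in chunks keeps memory use constant even for the larger archives such as `nlp.zip` (91MB).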
For anyone familiar with Java, it is quite comfortable to use.