Machine Learning Notes - A First Look at Deeplearning4j, a Java Deep Learning Framework

I. Deeplearning4j Overview

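        Deeplearning4j is a suite of tools for running deep learning on the JVM. The project describes itself as the only framework that lets you train models from Java while interoperating with the Python ecosystem through a mix of Python execution via its CPython bindings, model import support, and interop with other runtimes such as tensorflow-java and onnxruntime.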

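        Use cases include importing and retraining models (PyTorch, TensorFlow, Keras) and deploying them in JVM microservice environments, on mobile devices, on IoT devices, and on Apache Spark. It is a good complement to a Python environment: models built in Python can be run on, deployed to, or packaged for other environments.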

        All projects in the DL4J ecosystem support Windows, Linux, and macOS. Hardware support includes CUDA GPUs (10.0, 10.1, 10.2; not on OS X), x86 CPUs (x86_64, avx2, avx512), ARM CPUs (arm, arm64, armhf), and PowerPC (ppc64le).

II. Deeplearning4j Modules

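        DL4J: a high-level API for building multi-layer networks and computation graphs from a variety of layers, including custom ones. It supports importing Keras models from h5 files, including tf.keras models (as of 1.0.0-M2), and distributed training on Apache Spark.

        ND4J: a general-purpose linear algebra library with more than 500 mathematical, linear algebra, and deep learning operations. ND4J is built on LibND4J, a highly optimized C++ codebase, and provides CPU (AVX2/512) and GPU (CUDA) support and acceleration through libraries such as OpenBLAS, OneDNN (MKL-DNN), cuDNN, and cuBLAS.

        SameDiff: part of the ND4J library, SameDiff is the project's automatic differentiation and deep learning framework. It uses a graph-based (define then run) approach similar to TensorFlow graph mode. Eager execution (as in TensorFlow 2.x eager mode or PyTorch) is planned. SameDiff supports importing models in the TensorFlow frozen model format, .pb (protobuf). Import of ONNX, TensorFlow SavedModel, and Keras models is planned. Deeplearning4j also has full SameDiff support, which makes it easy to write custom layers and loss functions.

        DataVec: ETL for machine learning data in a wide variety of formats and sources (HDFS, Spark, images, video, audio, CSV, Excel, etc.).

        Arbiter: a library for hyperparameter search.

        LibND4J: the C++ library that underpins everything. For more on how the JVM accesses native arrays and operations, see JavaCPP.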
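        The "define then run" approach mentioned for SameDiff can be illustrated without any library. The following is a toy sketch in plain Java, not SameDiff's API: every name in it (`Node`, `placeholder`, `add`, `mul`, `eval`) is hypothetical. The graph is described first, and values are supplied only when it is executed.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.DoubleBinaryOperator;

/** Toy define-then-run expression graph (hypothetical names, not SameDiff's API). */
final class DefineThenRunSketch {

    /** A graph node: either a named placeholder or a binary operation. */
    static final class Node {
        final String name;               // non-null only for placeholders
        final Node left, right;          // non-null only for operations
        final DoubleBinaryOperator op;   // non-null only for operations

        private Node(String name, Node left, Node right, DoubleBinaryOperator op) {
            this.name = name;
            this.left = left;
            this.right = right;
            this.op = op;
        }

        /** Run phase: evaluate this node using the supplied placeholder values. */
        double eval(Map<String, Double> feed) {
            if (name != null) {
                Double v = feed.get(name);
                if (v == null) {
                    throw new IllegalArgumentException("No value for placeholder " + name);
                }
                return v;
            }
            return op.applyAsDouble(left.eval(feed), right.eval(feed));
        }
    }

    static Node placeholder(String name) { return new Node(name, null, null, null); }
    static Node add(Node a, Node b) { return new Node(null, a, b, (x, y) -> x + y); }
    static Node mul(Node a, Node b) { return new Node(null, a, b, (x, y) -> x * y); }

    public static void main(String[] args) {
        // Define phase: describe y = w * x + b without computing anything.
        Node x = placeholder("x"), w = placeholder("w"), b = placeholder("b");
        Node y = add(mul(w, x), b);

        // Run phase: the same graph can be executed with different inputs.
        Map<String, Double> feed = new HashMap<>();
        feed.put("x", 2.0);
        feed.put("w", 3.0);
        feed.put("b", 1.0);
        System.out.println(y.eval(feed)); // 3 * 2 + 1, prints 7.0
    }
}
```

        Because the graph exists as data before it runs, a framework built this way can inspect it, differentiate it, or import it from another format, which is the property the SameDiff description relies on.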

III. Configuring Deeplearning4j in Maven



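<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
	xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<parent>
		<groupId>org.springframework.boot</groupId>
		<artifactId>spring-boot-starter-parent</artifactId>
		<version>2.6.4</version>
		<relativePath/>
	</parent>
	<groupId>com.algorithm</groupId>
	<artifactId>demo</artifactId>
	<version>0.0.1-SNAPSHOT</version>
	<name>demo</name>
	<description>Demo project for Spring Boot</description>
	<properties>
		<dl4j-master.version>1.0.0-M2</dl4j-master.version>
		<java.version>1.8</java.version>
	</properties>
	<dependencies>
		<dependency>
			<groupId>org.deeplearning4j</groupId>
			<artifactId>deeplearning4j-core</artifactId>
			<version>${dl4j-master.version}</version>
		</dependency>
		<dependency>
			<groupId>org.nd4j</groupId>
			<artifactId>nd4j-native</artifactId>
			<version>${dl4j-master.version}</version>
		</dependency>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter</artifactId>
		</dependency>
		<dependency>
			<groupId>jfree</groupId>
			<artifactId>jfreechart</artifactId>
			<version>1.0.13</version>
		</dependency>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-test</artifactId>
			<scope>test</scope>
		</dependency>
	</dependencies>

	<build>
		<plugins>
			<plugin>
				<groupId>org.springframework.boot</groupId>
				<artifactId>spring-boot-maven-plugin</artifactId>
			</plugin>
		</plugins>
	</build>
</project>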


IV. Linear Data Classification Example

1. Reference code

        LinearDataClassifier.java

package com.algorithm.demo.dl4jexamples;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import com.algorithm.demo.dl4jexamples.utils.DownloaderUtility;
import com.algorithm.demo.dl4jexamples.utils.PlotUtil;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

import java.io.File;
import java.util.concurrent.TimeUnit;

/**
 * "Linear" Data Classification Example
 * 
 * Based on the data from Jason Baldridge:
 * https://github.com/jasonbaldridge/try-tf/tree/master/simdata
 *
 * @author Josh Patterson
 * @author Alex Black (added plots)
 */
@SuppressWarnings("DuplicatedCode")
public class LinearDataClassifier {

    public static boolean visualize = true;
    public static String dataLocalPath;

    public static void main(String[] args) throws Exception {
        int seed = 123;
        double learningRate = 0.01;
        int batchSize = 50;
        int nEpochs = 30;

        int numInputs = 2;
        int numOutputs = 2;
        int numHiddenNodes = 20;

        dataLocalPath = DownloaderUtility.CLASSIFICATIONDATA.Download();

        //Load the training data
        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(new File(dataLocalPath, "linear_data_train.csv")));
        DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, 2);

        //Load the evaluation data
        RecordReader rrTest = new CSVRecordReader();
        rrTest.initialize(new FileSplit(new File(dataLocalPath, "linear_data_eval.csv")));
        DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2);

        //Create the multi-layer network configuration
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(learningRate, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .activation(Activation.RELU)
                        .build())
                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .build();

        //Initialize the network
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(10));  //Print score every 10 parameter updates
        //Train the model
        model.fit(trainIter, nEpochs);

        //Evaluate the model
        System.out.println("Evaluate model....");
        Evaluation eval = new Evaluation(numOutputs);
        while (testIter.hasNext()) {
            DataSet t = testIter.next();
            INDArray features = t.getFeatures();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }
        //An alternate way to do the above loop
        //Evaluation evalResults = model.evaluate(testIter);

        //Print the evaluation statistics
        System.out.println(eval.stats());

        System.out.println("\n****************Example finished********************");
        //Training complete

        //The code below is only for plotting the data and visualizing the predictions
        generateVisuals(model, trainIter, testIter);
    }

    public static void generateVisuals(MultiLayerNetwork model, DataSetIterator trainIter, DataSetIterator testIter) throws Exception {
        if (visualize) {
            double xMin = 0;
            double xMax = 1.0;
            double yMin = -0.2;
            double yMax = 0.8;
            int nPointsPerAxis = 100;

            //Generate x,y points that span the whole range of features
            INDArray allXYPoints = PlotUtil.generatePointsOnGraph(xMin, xMax, yMin, yMax, nPointsPerAxis);
            //Get train data and plot with predictions
            PlotUtil.plotTrainingData(model, trainIter, allXYPoints, nPointsPerAxis);
            TimeUnit.SECONDS.sleep(3);
            //Get test data, run the test data through the network to generate predictions, and plot those predictions:
            PlotUtil.plotTestData(model, testIter, allXYPoints, nPointsPerAxis);
        }
    }
}
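
        The output layer above pairs `Activation.SOFTMAX` with `LossFunction.NEGATIVELOGLIKELIHOOD`. To make that pairing concrete, here is a plain-Java sketch of the textbook definitions for a single example. It is not DL4J's implementation, which operates on batched `INDArray` data; the class name, method names, and input scores are made up for illustration.

```java
/** Textbook softmax and negative log-likelihood for one example (not DL4J code). */
final class SoftmaxNllSketch {

    /** Converts raw scores (logits) into probabilities that sum to 1. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max); // subtracting max avoids overflow
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Negative log-likelihood of the true class given the predicted probabilities. */
    static double negativeLogLikelihood(double[] probabilities, int trueClass) {
        return -Math.log(probabilities[trueClass]);
    }

    public static void main(String[] args) {
        double[] logits = {2.0, 0.5}; // made-up scores for the two classes
        double[] p = softmax(logits);
        System.out.printf("p(class 0)=%.4f, p(class 1)=%.4f%n", p[0], p[1]);
        System.out.printf("loss if true class is 0: %.4f%n", negativeLogLikelihood(p, 0));
        System.out.printf("loss if true class is 1: %.4f%n", negativeLogLikelihood(p, 1));
    }
}
```

        The loss is small when the network assigns high probability to the correct class and grows without bound as that probability approaches zero, which is what training tries to push down.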

        PlotUtil.java, the plotting utility

package com.algorithm.demo.dl4jexamples.utils;

import org.deeplearning4j.datasets.iterator.utilty.ListDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.AxisLocation;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.block.BlockBorder;
import org.jfree.chart.plot.DatasetRenderingOrder;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.GrayPaintScale;
import org.jfree.chart.renderer.PaintScale;
import org.jfree.chart.renderer.xy.XYBlockRenderer;
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
import org.jfree.chart.title.PaintScaleLegend;
import org.jfree.data.xy.*;
import org.jfree.ui.RectangleEdge;
import org.jfree.ui.RectangleInsets;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

import javax.swing.*;
import java.awt.*;
import java.util.ArrayList;
import java.util.List;

/**
 * Simple plotting methods for the MLPClassifier quickstartexamples
 *
 * @author Alex Black
 */
public class PlotUtil {

    /**
     * Plot the training data. Assume 2d input, classification output
     *
     * @param model         Model to use to get predictions
     * @param trainIter     DataSet Iterator
     * @param backgroundIn  sets of x,y points in input space, plotted in the background
     * @param nDivisions    Number of points (per axis, for the backgroundIn/backgroundOut arrays)
     */
    public static void plotTrainingData(MultiLayerNetwork model, DataSetIterator trainIter, INDArray backgroundIn, int nDivisions) {
        double[] mins = backgroundIn.min(0).data().asDouble();
        double[] maxs = backgroundIn.max(0).data().asDouble();

        DataSet ds = allBatches(trainIter);
        INDArray backgroundOut = model.output(backgroundIn);

        XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
        JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTrain(ds.getFeatures(), ds.getLabels())));

        JFrame f = new JFrame();
        f.add(panel);
        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
        f.pack();
        f.setTitle("Training Data");

        f.setVisible(true);
        f.setLocation(0, 0);
    }

    /**
     * Plot the test data. Assume 2d input, classification output
     *
     * @param model         Model to use to get predictions
     * @param testIter      Test Iterator
     * @param backgroundIn  sets of x,y points in input space, plotted in the background
     * @param nDivisions    Number of points (per axis, for the backgroundIn/backgroundOut arrays)
     */
    public static void plotTestData(MultiLayerNetwork model, DataSetIterator testIter, INDArray backgroundIn, int nDivisions) {

        double[] mins = backgroundIn.min(0).data().asDouble();
        double[] maxs = backgroundIn.max(0).data().asDouble();

        INDArray backgroundOut = model.output(backgroundIn);
        XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
        DataSet ds = allBatches(testIter);
        INDArray predicted = model.output(ds.getFeatures());
        JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTest(ds.getFeatures(), ds.getLabels(), predicted)));

        JFrame f = new JFrame();
        f.add(panel);
        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
        f.pack();
        f.setTitle("Test Data");

        f.setVisible(true);
        f.setLocationRelativeTo(null);
        //f.setLocation(100,100);

    }


    /**
     * Create data for the background data set
     */
    private static XYZDataset createBackgroundData(INDArray backgroundIn, INDArray backgroundOut) {
        int nRows = backgroundIn.rows();
        double[] xValues = new double[nRows];
        double[] yValues = new double[nRows];
        double[] zValues = new double[nRows];
        for (int i = 0; i < nRows; i++) {
            xValues[i] = backgroundIn.getDouble(i, 0);
            yValues[i] = backgroundIn.getDouble(i, 1);
            zValues[i] = backgroundOut.getDouble(i, 0);

        }

        DefaultXYZDataset dataset = new DefaultXYZDataset();
        dataset.addSeries("Series 1",
                new double[][]{xValues, yValues, zValues});
        return dataset;
    }

    //Training data
    private static XYDataset createDataSetTrain(INDArray features, INDArray labels) {
        int nRows = features.rows();

        int nClasses = 2; // Two classes, matching the network's two softmax outputs.

        XYSeries[] series = new XYSeries[nClasses];
        for (int i = 0; i < series.length; i++) series[i] = new XYSeries("Class " + i);
        INDArray argMax = Nd4j.getExecutioner().exec(new ArgMax(new INDArray[]{labels},false,new int[]{1}))[0];
        for (int i = 0; i < nRows; i++) {
            int classIdx = (int) argMax.getDouble(i);
            series[classIdx].add(features.getDouble(i, 0), features.getDouble(i, 1));
        }

        XYSeriesCollection c = new XYSeriesCollection();
        for (XYSeries s : series) c.addSeries(s);
        return c;
    }

    //Test data
    private static XYDataset createDataSetTest(INDArray features, INDArray labels, INDArray predicted) {
        int nRows = features.rows();

        int nClasses = 2; // Two classes, matching the network's two softmax outputs.

        XYSeries[] series = new XYSeries[nClasses * nClasses];
        int[] series_index = new int[]{0, 3, 2, 1}; //little hack to make the charts look consistent.
        for (int i = 0; i < nClasses * nClasses; i++) {
            int trueClass = i / nClasses;
            int predClass = i % nClasses;
            String label = "actual=" + trueClass + ", pred=" + predClass;
            series[series_index[i]] = new XYSeries(label);
        }
        INDArray actualIdx = labels.argMax(1);
        INDArray predictedIdx = predicted.argMax(1);
        for (int i = 0; i < nRows; i++) {
            int classIdx = actualIdx.getInt(i);
            int predIdx = predictedIdx.getInt(i);
            int idx = series_index[classIdx * nClasses + predIdx];
            series[idx].add(features.getDouble(i, 0), features.getDouble(i, 1));
        }

        XYSeriesCollection c = new XYSeriesCollection();
        for (XYSeries s : series) c.addSeries(s);
        return c;
    }

    private static JFreeChart createChart(XYZDataset dataset, double[] mins, double[] maxs, int nPoints, XYDataset xyData) {
        NumberAxis xAxis = new NumberAxis("X");
        xAxis.setRange(mins[0], maxs[0]);


        NumberAxis yAxis = new NumberAxis("Y");
        yAxis.setRange(mins[1], maxs[1]);

        XYBlockRenderer renderer = new XYBlockRenderer();
        renderer.setBlockWidth((maxs[0] - mins[0]) / (nPoints - 1));
        renderer.setBlockHeight((maxs[1] - mins[1]) / (nPoints - 1));
        PaintScale scale = new GrayPaintScale(0, 1.0);
        renderer.setPaintScale(scale);
        XYPlot plot = new XYPlot(dataset, xAxis, yAxis, renderer);
        plot.setBackgroundPaint(Color.lightGray);
        plot.setDomainGridlinesVisible(false);
        plot.setRangeGridlinesVisible(false);
        plot.setAxisOffset(new RectangleInsets(5, 5, 5, 5));
        JFreeChart chart = new JFreeChart("", plot);
        chart.getXYPlot().getRenderer().setSeriesVisibleInLegend(0, false);


        NumberAxis scaleAxis = new NumberAxis("Probability (class 1)");
        scaleAxis.setAxisLinePaint(Color.white);
        scaleAxis.setTickMarkPaint(Color.white);
        scaleAxis.setTickLabelFont(new Font("Dialog", Font.PLAIN, 7));
        PaintScaleLegend legend = new PaintScaleLegend(new GrayPaintScale(),
                scaleAxis);
        legend.setStripOutlineVisible(false);
        legend.setSubdivisionCount(20);
        legend.setAxisLocation(AxisLocation.BOTTOM_OR_LEFT);
        legend.setAxisOffset(5.0);
        legend.setMargin(new RectangleInsets(5, 5, 5, 5));
        legend.setFrame(new BlockBorder(Color.red));
        legend.setPadding(new RectangleInsets(10, 10, 10, 10));
        legend.setStripWidth(10);
        legend.setPosition(RectangleEdge.LEFT);
        chart.addSubtitle(legend);

        ChartUtilities.applyCurrentTheme(chart);

        plot.setDataset(1, xyData);
        XYLineAndShapeRenderer renderer2 = new XYLineAndShapeRenderer();
        renderer2.setBaseLinesVisible(false);
        plot.setRenderer(1, renderer2);

        plot.setDatasetRenderingOrder(DatasetRenderingOrder.FORWARD);

        return chart;
    }

    public static INDArray generatePointsOnGraph(double xMin, double xMax, double yMin, double yMax, int nPointsPerAxis) {
        //generate all the x,y points
        double[][] evalPoints = new double[nPointsPerAxis * nPointsPerAxis][2];
        int count = 0;
        for (int i = 0; i < nPointsPerAxis; i++) {
            for (int j = 0; j < nPointsPerAxis; j++) {
                double x = i * (xMax - xMin) / (nPointsPerAxis - 1) + xMin;
                double y = j * (yMax - yMin) / (nPointsPerAxis - 1) + yMin;

                evalPoints[count][0] = x;
                evalPoints[count][1] = y;

                count++;
            }
        }

        return Nd4j.create(evalPoints);
    }

    /**
     * This is to collect all the data and return it as one minibatch. Obviously only for use here with small datasets
     *
     * @param iter iterator whose batches are merged; it is reset before and after use
     * @return a single DataSet containing every example from the iterator
     */
    private static DataSet allBatches(DataSetIterator iter) {

        List<DataSet> fullSet = new ArrayList<>();
        iter.reset();
        while (iter.hasNext()) {
            List<DataSet> miniBatchList = iter.next().asList();
            fullSet.addAll(miniBatchList);
        }
        iter.reset();
        return new ListDataSetIterator<>(fullSet, fullSet.size()).next();
    }

}
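
        `createDataSetTest` groups test points by their (actual, predicted) class pair using the flat index `classIdx * nClasses + predIdx`. The same index can be used to count a confusion matrix. The sketch below does this in plain Java with made-up labels. The class and method names are hypothetical, and it ignores the `series_index` reordering, which the source describes as a hack to keep the charts consistent.

```java
/** Counts (actual, predicted) pairs with the flat index used in createDataSetTest. */
final class ConfusionCountsSketch {

    /** Returns counts[actual * nClasses + predicted] for the given label arrays. */
    static int[] count(int[] actual, int[] predicted, int nClasses) {
        if (actual.length != predicted.length) {
            throw new IllegalArgumentException("actual and predicted must have the same length");
        }
        int[] counts = new int[nClasses * nClasses];
        for (int i = 0; i < actual.length; i++) {
            counts[actual[i] * nClasses + predicted[i]]++;
        }
        return counts;
    }

    /** Accuracy = correctly classified examples (diagonal cells) / all examples. */
    static double accuracy(int[] counts, int nClasses) {
        int correct = 0;
        int total = 0;
        for (int i = 0; i < counts.length; i++) {
            total += counts[i];
            if (i / nClasses == i % nClasses) {
                correct += counts[i];
            }
        }
        return total == 0 ? 0.0 : (double) correct / total;
    }

    public static void main(String[] args) {
        int[] actual    = {0, 0, 1, 1, 1}; // made-up labels for illustration
        int[] predicted = {0, 1, 1, 1, 0};
        int nClasses = 2;
        int[] counts = count(actual, predicted, nClasses);
        for (int i = 0; i < counts.length; i++) {
            System.out.println("actual=" + i / nClasses + ", pred=" + i % nClasses + ": " + counts[i]);
        }
        System.out.println("accuracy=" + accuracy(counts, nClasses));
    }
}
```

        In the chart, the two diagonal pairs (actual equals predicted) are the correctly classified points, and the other two series are the misclassifications.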

        DownloaderUtility.java, the download utility class

package com.algorithm.demo.dl4jexamples.utils;

import org.apache.commons.io.FilenameUtils;
import org.nd4j.common.resources.Downloader;

import java.io.File;
import java.net.URL;

/**
 * Given a base url and a zipped file name downloads contents to a specified directory under ~/dl4j-examples-data
 * Will check md5 sum of downloaded file
 *
 * Sample Usage with an instantiation DATAEXAMPLE(baseurl,"DataExamples.zip","data-dir",md5,size):
 *
 * DATAEXAMPLE.Download() & DATAEXAMPLE.Download(true)
 * Will download DataExamples.zip from baseurl/DataExamples.zip to a temp directory,
 * Unzip it to ~/dl4j-examples-data/data-dir
 * Return the string "~/dl4j-examples-data/data-dir/DataExamples"
 *
 * DATAEXAMPLE.Download(false)
 * will perform the same download and unzip as above
 * But returns the string "~/dl4j-examples-data/data-dir" instead
 *
 * @author susaneraly
 */
public enum DownloaderUtility {

    IRISDATA("IrisData.zip", "datavec-examples", "bb49e38bb91089634d7ef37ad8e430b8", "1KB"),
    ANIMALS("animals.zip", "dl4j-examples", "1976a1f2b61191d2906e4f615246d63e", "820KB"),
    ANOMALYSEQUENCEDATA("anomalysequencedata.zip", "dl4j-examples", "51bb7c50e265edec3a241a2d7cce0e73", "3MB"),
    CAPTCHAIMAGE("captchaImage.zip", "dl4j-examples", "1d159c9587fdbb1cbfd66f0d62380e61", "42MB"),
    CLASSIFICATIONDATA("classification.zip", "dl4j-examples", "dba31e5838fe15993579edbf1c60c355", "77KB"),
    DATAEXAMPLES("DataExamples.zip", "dl4j-examples", "e4de9c6f19aaae21fed45bfe2a730cbb", "2MB"),
    LOTTERYDATA("lottery.zip", "dl4j-examples", "1e54ac1210e39c948aa55417efee193a", "2MB"),
    NEWSDATA("NewsData.zip", "dl4j-examples", "0d08e902faabe6b8bfe5ecdd78af9f64", "21MB"),
    NLPDATA("nlp.zip", "dl4j-examples", "1ac7cd7ca08f13402f0e3b83e20c0512", "91MB"),
    PREDICTGENDERDATA("PredictGender.zip", "dl4j-examples", "42a3fec42afa798217e0b8687667257e", "3MB"),
    STYLETRANSFER("styletransfer.zip", "dl4j-examples", "b2b90834d667679d7ee3dfb1f40abe94", "3MB"),
    VIDEOEXAMPLE("video.zip","dl4j-examples", "56274eb6329a848dce3e20631abc6752", "8.5MB");

    private final String BASE_URL;
    private final String DATA_FOLDER;
    private final String ZIP_FILE;
    private final String MD5;
    private final String DATA_SIZE;
    private static final String AZURE_BLOB_URL = "https://dl4jdata.blob.core.windows.net/dl4j-examples";

    /**
     * For use with resources uploaded to Azure blob storage.
     *
     * @param zipFile    Name of zipfile. Should be a zip of a single directory with the same name
     * @param dataFolder The folder to extract to under ~/dl4j-examples-data
     * @param md5        of zipfile
     * @param dataSize   of zipfile
     */
    DownloaderUtility(String zipFile, String dataFolder, String md5, String dataSize) {
        this(AZURE_BLOB_URL + "/" + dataFolder, zipFile, dataFolder, md5, dataSize);
    }

    /**
     * Downloads a zip file from a base url to a specified directory under the user's home directory
     *
     * @param baseURL    URL of file
     * @param zipFile    Name of zipfile to download from baseURL i.e baseURL+"/"+zipFile gives full URL
     * @param dataFolder The folder to extract to under ~/dl4j-examples-data
     * @param md5        of zipfile
     * @param dataSize   of zipfile
     */
    DownloaderUtility(String baseURL, String zipFile, String dataFolder, String md5, String dataSize) {
        BASE_URL = baseURL;
        DATA_FOLDER = dataFolder;
        ZIP_FILE = zipFile;
        MD5 = md5;
        DATA_SIZE = dataSize;
    }

    public String Download() throws Exception {
        return Download(true);
    }

    public String Download(boolean returnSubFolder) throws Exception {
        String dataURL = BASE_URL + "/" + ZIP_FILE;
        String downloadPath = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), ZIP_FILE);
        String extractDir = FilenameUtils.concat(System.getProperty("user.home"), "dl4j-examples-data/" + DATA_FOLDER);
        if (!new File(extractDir).exists())
            new File(extractDir).mkdirs();
        String dataPathLocal = extractDir;
        if (returnSubFolder) {
            String resourceName = ZIP_FILE.substring(0, ZIP_FILE.lastIndexOf(".zip"));
            dataPathLocal = FilenameUtils.concat(extractDir, resourceName);
        }
        int downloadRetries = 10;
        if (!new File(dataPathLocal).exists() || new File(dataPathLocal).list().length == 0) {
            System.out.println("_______________________________________________________________________");
            System.out.println("Downloading data (" + DATA_SIZE + ") and extracting to \n\t" + dataPathLocal);
            System.out.println("_______________________________________________________________________");
            Downloader.downloadAndExtract("files",
                    new URL(dataURL),
                    new File(downloadPath),
                    new File(extractDir),
                    MD5,
                    downloadRetries);
        } else {
            System.out.println("_______________________________________________________________________");
            System.out.println("Example data present in \n\t" + dataPathLocal);
            System.out.println("_______________________________________________________________________");
        }
        return dataPathLocal;
    }
}
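
        The only part of `Download` that the classifier depends on is the returned path. The sketch below reproduces just that path logic with `java.nio.file`, so the effect of `returnSubFolder` can be checked without downloading anything. The class name `DataPathSketch` and the method `resolve` are hypothetical, and the base directory is passed in explicitly rather than read from `user.home`.

```java
import java.nio.file.Path;
import java.nio.file.Paths;

/** Mirrors only the path logic of DownloaderUtility.Download; nothing is downloaded. */
final class DataPathSketch {

    /**
     * @param home            base directory (the original reads the user.home system property)
     * @param dataFolder      folder under dl4j-examples-data, e.g. "dl4j-examples"
     * @param zipFile         archive name, e.g. "classification.zip"
     * @param returnSubFolder whether to append the archive name without ".zip"
     */
    static Path resolve(String home, String dataFolder, String zipFile, boolean returnSubFolder) {
        Path extractDir = Paths.get(home, "dl4j-examples-data", dataFolder);
        if (!returnSubFolder) {
            return extractDir;
        }
        String resourceName = zipFile.substring(0, zipFile.lastIndexOf(".zip"));
        return extractDir.resolve(resourceName);
    }

    public static void main(String[] args) {
        String home = System.getProperty("user.home");
        // Same arguments as the CLASSIFICATIONDATA entry of the enum.
        System.out.println(resolve(home, "dl4j-examples", "classification.zip", true));
        System.out.println(resolve(home, "dl4j-examples", "classification.zip", false));
    }
}
```

        With the `CLASSIFICATIONDATA` arguments and `returnSubFolder` set to true, the result ends in `dl4j-examples-data/dl4j-examples/classification`, which is the directory where `LinearDataClassifier` looks for `linear_data_train.csv` and `linear_data_eval.csv`.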

2. Results

[Figure 1: screenshot of the run results]

         For anyone already familiar with Java, it is quite comfortable to use.
